From 19a4abf063041ffc51109abe0207c24ff6c7c49e Mon Sep 17 00:00:00 2001 From: IvyZX Date: Tue, 19 Nov 2024 15:12:57 -0800 Subject: [PATCH] Fix fori_loop and while_loop on multiple modules --- flax/nnx/transforms/iteration.py | 6 ++---- tests/nnx/transforms_test.py | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 6318bbe0b5..c9a3c1c4be 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1329,8 +1329,8 @@ def __call__(self, pure_val): def _add_fake_index_mapping(tree: tp.Any): + global_index_mapping = {} # for the whole context, over all inputs def per_node_state(ns: extract.NodeStates | tp.Any): - global_index_mapping = {} if not isinstance(ns, extract.NodeStates) or not isinstance( ns._graphdef, graph.NodeDef ): @@ -1339,10 +1339,8 @@ def per_node_state(ns: extract.NodeStates | tp.Any): def per_node_def(nd: graph.NodeDef | graph.NodeRef): if nd.index >= 0: global_index_mapping[nd.index] = nd.index - if isinstance(nd, graph.NodeRef): return - for sub_nd in nd.subgraphs.values(): per_node_def(sub_nd) for l in nd.leaves.values(): @@ -1480,7 +1478,7 @@ def __call__(self, i, pure_val): "have the same reference and pytree structure, but they differ. " "If the mismatch comes from `index_mapping` field, you might " "have modified reference structure within the body function, " - "which is not allowed." + "which is not allowed. " f"Detail of the mismatch: \n {str(e)}") raise ValueError(msg) diff --git a/tests/nnx/transforms_test.py b/tests/nnx/transforms_test.py index c0ded037e9..736da9acf0 100644 --- a/tests/nnx/transforms_test.py +++ b/tests/nnx/transforms_test.py @@ -2969,6 +2969,22 @@ def rollout(d: D): d.a.params.value, np.full((10,), 10, dtype=int) ) + def test_loops_multiple_modules(self): + class Foo(nnx.Module): + def __init__(self): + self.param = nnx.Param(jnp.zeros((1,))) + def __call__(self, x): + return self.param + + def loop_fn(inputs): + return inputs + while_loop_fn = lambda inputs: (*loop_fn(inputs[:-1]), inputs[-1]-1) + fori_loop_fn = lambda i, inputs: loop_fn(inputs) + a = Foo() + b = Foo() + nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2)) + nnx.fori_loop(0, 2, fori_loop_fn, (a, b)) + class TestSplitMergeInputs(absltest.TestCase): def test_split_inputs(self):