Skip to content

Commit

Permalink
Merge pull request #4390 from IvyZX:loop
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 698482509
  • Loading branch information
Flax Authors committed Nov 20, 2024
2 parents d89c955 + 19a4abf commit 4e25898
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
6 changes: 2 additions & 4 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,8 +1329,8 @@ def __call__(self, pure_val):


def _add_fake_index_mapping(tree: tp.Any):
global_index_mapping = {} # for the whole context, over all inputs
def per_node_state(ns: extract.NodeStates | tp.Any):
global_index_mapping = {}
if not isinstance(ns, extract.NodeStates) or not isinstance(
ns._graphdef, graph.NodeDef
):
Expand All @@ -1339,10 +1339,8 @@ def per_node_state(ns: extract.NodeStates | tp.Any):
def per_node_def(nd: graph.NodeDef | graph.NodeRef):
if nd.index >= 0:
global_index_mapping[nd.index] = nd.index

if isinstance(nd, graph.NodeRef):
return

for sub_nd in nd.subgraphs.values():
per_node_def(sub_nd)
for l in nd.leaves.values():
Expand Down Expand Up @@ -1480,7 +1478,7 @@ def __call__(self, i, pure_val):
"have the same reference and pytree structure, but they differ. "
"If the mismatch comes from `index_mapping` field, you might "
"have modified reference structure within the body function, "
"which is not allowed."
"which is not allowed. "
f"Detail of the mismatch: \n {str(e)}")
raise ValueError(msg)

Expand Down
16 changes: 16 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2969,6 +2969,22 @@ def rollout(d: D):
d.a.params.value, np.full((10,), 10, dtype=int)
)

def test_loops_multiple_modules(self):
class Foo(nnx.Module):
def __init__(self):
self.param = nnx.Param(jnp.zeros((1,)))
def __call__(self, x):
return self.param

def loop_fn(inputs):
return inputs
while_loop_fn = lambda inputs: (*loop_fn(inputs[:-1]), inputs[-1]-1)
fori_loop_fn = lambda i, inputs: loop_fn(inputs)
a = Foo()
b = Foo()
nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2))
nnx.fori_loop(0, 2, fori_loop_fn, (a, b))


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
Expand Down

0 comments on commit 4e25898

Please sign in to comment.