Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix fori_loop and while_loop on multiple modules #4390

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1329,8 +1329,8 @@ def __call__(self, pure_val):


def _add_fake_index_mapping(tree: tp.Any):
global_index_mapping = {} # for the whole context, over all inputs
def per_node_state(ns: extract.NodeStates | tp.Any):
global_index_mapping = {}
if not isinstance(ns, extract.NodeStates) or not isinstance(
ns._graphdef, graph.NodeDef
):
Expand All @@ -1339,10 +1339,8 @@ def per_node_state(ns: extract.NodeStates | tp.Any):
def per_node_def(nd: graph.NodeDef | graph.NodeRef):
if nd.index >= 0:
global_index_mapping[nd.index] = nd.index

if isinstance(nd, graph.NodeRef):
return

for sub_nd in nd.subgraphs.values():
per_node_def(sub_nd)
for l in nd.leaves.values():
Expand Down Expand Up @@ -1480,7 +1478,7 @@ def __call__(self, i, pure_val):
"have the same reference and pytree structure, but they differ. "
"If the mismatch comes from `index_mapping` field, you might "
"have modified reference structure within the body function, "
"which is not allowed."
"which is not allowed. "
f"Detail of the mismatch: \n {str(e)}")
raise ValueError(msg)

Expand Down
16 changes: 16 additions & 0 deletions tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2969,6 +2969,22 @@ def rollout(d: D):
d.a.params.value, np.full((10,), 10, dtype=int)
)

def test_loops_multiple_modules(self):
class Foo(nnx.Module):
def __init__(self):
self.param = nnx.Param(jnp.zeros((1,)))
def __call__(self, x):
return self.param

def loop_fn(inputs):
return inputs
while_loop_fn = lambda inputs: (*loop_fn(inputs[:-1]), inputs[-1]-1)
fori_loop_fn = lambda i, inputs: loop_fn(inputs)
a = Foo()
b = Foo()
nnx.while_loop(lambda input: input[-1] > 0, while_loop_fn, (a, b, 2))
nnx.fori_loop(0, 2, fori_loop_fn, (a, b))


class TestSplitMergeInputs(absltest.TestCase):
def test_split_inputs(self):
Expand Down
Loading