Skip to content

Commit

Permalink
Merge pull request #4371 from 8bitmp3:update-nnx-while_loop
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 696921986
  • Loading branch information
Flax Authors committed Nov 15, 2024
2 parents d13f047 + 5495f4e commit 9147a7c
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions flax/nnx/transforms/iteration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1412,10 +1412,10 @@ def __call__(self, pure_val):
def while_loop(cond_fun: tp.Callable[[T], tp.Any],
body_fun: tp.Callable[[T], T],
init_val: T) -> T:
"""NNX transform of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.
"""A Flax NNX transformation of `jax.lax.while_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.while_loop.html>`_.
Caution: for the NNX internal reference tracing mechanism to work, you cannot
change the variable reference structure of `init_val` inside `body_fun`.
change the variable reference structure of ``init_val`` inside ``body_fun``.
Example::
Expand All @@ -1433,12 +1433,12 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any],
Args:
cond_fun: a function for the continue condition of the while loop, taking a
single input of type `T` and outputting a boolean.
body_fun: a function that takes an input of type `T` and outputs an `T`.
Note that both data and modules of `T` must have the same reference
cond_fun: A function for the continue condition of the while loop, taking a
single input of type ``T`` and outputting a boolean.
body_fun: A function that takes an input of type ``T`` and outputs an ``T``.
Note that both data and modules of ``T`` must have the same reference
structure between inputs and outputs.
init_val: the initial input for cond_fun and body_fun. Must be of type `T`.
init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of type ``T``.
"""

Expand Down Expand Up @@ -1543,4 +1543,4 @@ def fori_loop(lower: int, upper: int,
ForiLoopBodyFn(body_fun), pure_init_val,
unroll=unroll)
out = extract.from_tree(pure_out, ctxtag='fori_loop')
return out
return out

0 comments on commit 9147a7c

Please sign in to comment.