From 5495f4e492ec6c309cb7e03e98468a2e62e8802d Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:46:37 +0000 Subject: [PATCH] Lint flax.nnx.while_loop docstring --- flax/nnx/transforms/iteration.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/flax/nnx/transforms/iteration.py b/flax/nnx/transforms/iteration.py index 20366c3e1f..7eab09126b 100644 --- a/flax/nnx/transforms/iteration.py +++ b/flax/nnx/transforms/iteration.py @@ -1406,10 +1406,10 @@ def __call__(self, pure_val): def while_loop(cond_fun: tp.Callable[[T], tp.Any], body_fun: tp.Callable[[T], T], init_val: T) -> T: - """NNX transform of `jax.lax.while_loop `_. + """A Flax NNX transformation of `jax.lax.while_loop `_. Caution: for the NNX internal reference tracing mechanism to work, you cannot - change the variable reference structure of `init_val` inside `body_fun`. + change the variable reference structure of ``init_val`` inside ``body_fun``. Example:: @@ -1427,12 +1427,12 @@ def while_loop(cond_fun: tp.Callable[[T], tp.Any], Args: - cond_fun: a function for the continue condition of the while loop, taking a - single input of type `T` and outputting a boolean. - body_fun: a function that takes an input of type `T` and outputs an `T`. - Note that both data and modules of `T` must have the same reference + cond_fun: A function for the continue condition of the while loop, taking a + single input of type ``T`` and outputting a boolean. + body_fun: A function that takes an input of type ``T`` and outputs an ``T``. + Note that both data and modules of ``T`` must have the same reference structure between inputs and outputs. - init_val: the initial input for cond_fun and body_fun. Must be of type `T`. + init_val: The initial input for ``cond_fun`` and ``body_fun``. Must be of type ``T``. """ @@ -1537,4 +1537,4 @@ def fori_loop(lower: int, upper: int, ForiLoopBodyFn(body_fun), pure_init_val, unroll=unroll) out = extract.from_tree(pure_out, ctxtag='fori_loop') - return out \ No newline at end of file + return out