diff --git a/flax/nnx/transforms/autodiff.py b/flax/nnx/transforms/autodiff.py index 7547504c5a..861869ae40 100644 --- a/flax/nnx/transforms/autodiff.py +++ b/flax/nnx/transforms/autodiff.py @@ -873,17 +873,17 @@ def remat( ), ) ) - """A "lifted" version of the `jax.checkpoint `__ (a.k.a. ``jax.remat``). ``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for - example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus - how they are recomputed during the backward pass, trading off memory and FLOPs. + example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus + how they are recomputed during the backward pass, trading off memory and FLOPs. Learn more in `Flax NNX vs JAX Transformations `_. To learn about ``jax.remat``, go to JAX's - `fundamentals of jax.checkpoint `_ - and `practical notes `_. + `fundamentals of jax.checkpoint `_ + and `practical notes `_. """ +