Skip to content

Commit

Permalink
Add flax.nnx.remat docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 authored Nov 11, 2024
1 parent 3d8bf7b commit 404c12f
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions flax/nnx/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,17 +873,17 @@ def remat(
),
)
)

"""A "lifted" version of the `jax.checkpoint <https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html>`__
(a.k.a. ``jax.remat``).
``flax.nnx.remat``, similar to ``jax.checkpoint`` can provide control over, for
example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
how they are recomputed during the backward pass, trading off memory and FLOPs.
example, how ``flax.nnx.grad`` values are computed and saved during the forward pass versus
how they are recomputed during the backward pass, trading off memory and FLOPs.
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
To learn about ``jax.remat``, go to JAX's
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
`fundamentals of jax.checkpoint <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#fundamentals-of-jax-checkpoint>`_
and `practical notes <https://jax.readthedocs.io/en/latest/notebooks/autodiff_remat.html#practical-notes>`_.
"""

0 comments on commit 404c12f

Please sign in to comment.