Skip to content

Commit

Permalink
Lint flax.nnx.jit docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 authored Nov 11, 2024
1 parent d31f290 commit a57cabe
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions flax/nnx/transforms/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,13 @@ def jit(
abstracted_axes: tp.Optional[tp.Any] = None,
) -> F | tp.Callable[[F], F]:
"""
Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
A "lifted" version of ``jax.jit`` that can handle ``nnx.Modules`` / graph nodes as
arguments.
Learn more in `Flax NNX vs JAX Transformations <https://flax.readthedocs.io/en/latest/guides/jax_and_nnx_transforms.html>`_.
Args:
fun: Function to be jitted. ``fun`` should be a pure function, as
fun: A function to be JIT-compiled. ``fun`` should be a pure function, as
side-effects may only be executed once.
The arguments and return value of ``fun`` should be arrays,
Expand All @@ -186,7 +188,7 @@ def jit(
JAX keeps a weak reference to ``fun`` for use as a compilation cache key,
so the object ``fun`` must be weakly-referenceable. Most :class:`Callable`
objects will already satisfy this requirement.
in_shardings: Pytree of structure matching that of arguments to ``fun``,
in_shardings: A JAX pytree of structure matching that of arguments to ``fun``,
with all actual arguments replaced by resource assignment specifications.
It is also valid to specify a pytree prefix (e.g. one value in place of a
whole subtree), in which case the leaves get broadcast to all values in
Expand All @@ -209,9 +211,9 @@ def jit(
The size of every dimension has to be a multiple of the total number of
resources assigned to it. This is similar to pjit's in_shardings.
out_shardings: Like ``in_shardings``, but specifies resource
assignment for function outputs. This is similar to pjit's
out_shardings.
out_shardings: Similar to ``in_shardings``, but specifies resource
assignment for function outputs. This is similar to JAX ``pjit``
``out_shardings``.
The ``out_shardings`` argument is optional. If not specified, :py:func:`jax.jit`
will use GSPMD's sharding propagation to figure out what the sharding of the
Expand All @@ -223,7 +225,7 @@ def jit(
any Python object.
Static arguments should be hashable, meaning both ``__hash__`` and
``__eq__`` are implemented, and immutable. Calling the jitted function
``__eq__`` are implemented, and immutable. Calling the JIT-compiled function
with different values for these constants will trigger recompilation.
Arguments that are not arrays or containers thereof must be marked as
static.
Expand Down Expand Up @@ -262,18 +264,18 @@ def jit(
be donated.
For more details on buffer donation see the
`FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
`JAX FAQ <https://jax.readthedocs.io/en/latest/faq.html#buffer-donation>`_.
donate_argnames: An optional string or collection of strings specifying
which named arguments are donated to the computation. See the
comment on ``donate_argnums`` for details. If not
provided but ``donate_argnums`` is set, the default is based on calling
``inspect.signature(fun)`` to find corresponding named arguments.
keep_unused: If `False` (the default), arguments that JAX determines to be
unused by `fun` *may* be dropped from resulting compiled XLA executables.
keep_unused: If ``False`` (the default), arguments that JAX determines to be
unused by ``fun`` *may* be dropped from resulting compiled XLA executables.
Such arguments will not be transferred to the device nor provided to the
underlying executable. If `True`, unused arguments will not be pruned.
device: This is an experimental feature and the API is likely to change.
Optional, the Device the jitted function will run on. (Available devices
Optional, the Device the JIT-compiled function will run on. (Available devices
can be retrieved via :py:func:`jax.devices`.) The default is inherited
from XLA's DeviceAssignment logic and is usually to use
``jax.devices()[0]``.
Expand All @@ -282,7 +284,7 @@ def jit(
``'tpu'``.
inline: Specify whether this function should be inlined into enclosing
jaxprs (rather than being represented as an application of the xla_call
primitive with its own subjaxpr). Default False.
primitive with its own subjaxpr). Default ``False``.
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation.
Expand Down

0 comments on commit a57cabe

Please sign in to comment.