Add grid_map utility for managing JAX parallelization/vectorization #355
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
grid_map
is mapping utility for mapping a function over a "grid" of argument values. E.g. for two arraysa
andb
,This works for an arbitrary number of arguments, and more generally works on JAX PyTrees using standard conventions for mapping over PyTrees (explained in the function doc string). I.e. for
a
andb
generalPyTree
s with appropriate leaf shape (all leaves must have equal leading dimension size for the mapping makes sense), the above expression holds so long as we interpretfor
v
being any of the indexed objects in the previous expression. I.e.v[idx]
is the PyTree resulting from indexing every leaf ofv
withidx
(and we are slightly abusing notation to allowidx
to be a multi-index).Aside from the usefulness of the above form of mapping, the main point of
grid_map
is to offer control over how the evaluations off
get parallelized. Under the hood, it utilizes JAX'sxmap
to execute the mapping using a combination of device parallelization and vectorization (which they describe in their documentation as an "interpolation" betweenpmap
andvmap
). It makes natural default choices based on the device types, and a user can directly control these with optional arguments.For Dynamics this will be of use for internally controlling parallelization in the package, whether using CPU or GPU, and generally will be useful for users to have direct access to.
Details and comments
This is currently a work in progress. This helper function was written to control parallelization/mapping in some research projects, and its design/implementation will need to be revisited for integration into Dynamics.
General to do:
Design questions:
max_vmap_size
anddevices
)? The other option is to allow the user to specify these optional arguments anywhere "internal" parallelization is an option (e.g. JAX-based pulse simulation).non_jax_argnums
argument being kind of clunky - can revisit whether we want to include this.Update (23/4/24):
xmap
has been deprecated, so the inner-workings ofgrid_map
will need to be rethought. I think shard map may be the natural replacement.