Add grid_map utility for managing JAX parallelization/vectorization #355

Draft
wants to merge 3 commits into
base: main

DanPuzzuoli
Collaborator

@DanPuzzuoli DanPuzzuoli commented Apr 10, 2024

Summary

grid_map is a utility for mapping a function over a "grid" of argument values. E.g., for two arrays a and b,

grid_map(f, a, b)[i, j] = f(a[i], b[j])

This works for an arbitrary number of arguments, and more generally on JAX PyTrees, using the standard conventions for mapping over PyTrees (explained in the function docstring). That is, for general PyTrees a and b with appropriate leaf shapes (all leaves of a given PyTree must have the same leading dimension size for the mapping to make sense), the above expression holds so long as we interpret

v[idx] = tree_map(lambda x: x[idx], v)

for v being any of the indexed objects in the previous expression. I.e. v[idx] is the PyTree resulting from indexing every leaf of v with idx (and we are slightly abusing notation to allow idx to be a multi-index).
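As a rough illustration of the semantics described above, here is a minimal sketch built from nested jax.vmap calls. It is not the implementation in this PR (which uses xmap); the names grid_map_reference and index_tree are hypothetical and exist only for this example.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def grid_map_reference(f, *args):
    """Hypothetical reference: out[i, j, ...] = f(args[0][i], args[1][j], ...)."""
    n = len(args)
    mapped = f
    # Wrap from the last argument to the first, so that the first argument
    # ends up as the outermost (leading) axis of the output.
    for k in reversed(range(n)):
        in_axes = [None] * n
        in_axes[k] = 0  # map only over argument k at this level
        mapped = jax.vmap(mapped, in_axes=tuple(in_axes))
    return mapped(*args)


def index_tree(v, idx):
    """v[idx] in the PR's notation: index every leaf of the PyTree v with idx."""
    return tree_map(lambda x: x[idx], v)


# Plain arrays: out[i, j] == f(a[i], b[j]).
a = jnp.arange(3.0)
b = jnp.arange(4.0)
out = grid_map_reference(lambda x, y: 10 * x + y, a, b)

# PyTrees: every leaf of pa has leading size 3, every leaf of pb has leading size 4.
pa = {"amp": jnp.arange(3.0), "phase": jnp.ones((3, 2))}
pb = (jnp.arange(4.0),)


def g(p, q):
    return {"total": p["amp"] + p["phase"].sum() + q[0]}


grid = grid_map_reference(g, pa, pb)
lhs = index_tree(grid, (1, 2))                     # grid[1, 2]
rhs = g(index_tree(pa, 1), index_tree(pb, 2))      # g(pa[1], pb[2])
```

Here lhs and rhs agree leaf by leaf, which is the identity stated above with idx allowed to be a multi-index.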

Aside from the usefulness of this form of mapping, the main point of grid_map is to offer control over how the evaluations of f are parallelized. Under the hood, it uses JAX's xmap to execute the mapping with a combination of device parallelization and vectorization (which the JAX documentation describes as an "interpolation" between pmap and vmap). It makes natural default choices based on the device types, and a user can override these with optional arguments.
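To give a feel for the kind of trade-off involved, the sketch below caps how many evaluations are vectorized at once and runs the resulting chunks sequentially. This is only one plausible reading of a "maximum vmap size" setting, not how this PR implements it: the helper name chunked_vmap is hypothetical, it uses jax.lax.map rather than xmap, and it assumes f returns a single array.

```python
import jax
import jax.numpy as jnp


def chunked_vmap(f, xs, max_vmap_size):
    """Hypothetical helper: vmap f over chunks of at most max_vmap_size elements."""
    n = xs.shape[0]
    if n % max_vmap_size != 0:
        # Kept simple for the sketch; a real version would pad or handle a remainder.
        raise ValueError("leading dimension must be divisible by max_vmap_size")
    # Reshape (n, ...) -> (n_chunks, max_vmap_size, ...).
    chunks = xs.reshape((n // max_vmap_size, max_vmap_size) + xs.shape[1:])
    # Vectorize within a chunk, loop sequentially over chunks.
    out = jax.lax.map(jax.vmap(f), chunks)
    # Flatten the chunk axes back into one leading axis of size n.
    return out.reshape((n,) + out.shape[2:])


xs = jnp.arange(8.0)
ys = chunked_vmap(lambda x: x ** 2, xs, max_vmap_size=4)
```

Smaller chunks lower peak memory at the cost of more sequential steps; spreading the chunks across devices instead of looping over them is the device-parallel half of the picture.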

For Dynamics, this will be useful for controlling parallelization internally in the package, whether on CPU or GPU, and it will also be useful for users to have direct access to.

Details and comments

This is currently a work in progress. This helper function was written to control parallelization/mapping in some research projects, and its design/implementation will need to be revisited for integration into Dynamics.

General to do:

  • Needs to be rewritten to insulate JAX imports, and raise an error if JAX isn't installed.
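One common way to handle the to-do item above is to defer the import and check for the package at call time. The following is a generic sketch of that pattern, not code from this PR; the decorator name requires_module and the example function are hypothetical.

```python
import functools
import importlib.util


def requires_module(module_name):
    """Hypothetical decorator: raise ImportError at call time if a module is missing."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # find_spec returns None when a top-level module cannot be found.
            if importlib.util.find_spec(module_name) is None:
                raise ImportError(
                    f"{func.__name__} requires '{module_name}', which is not installed."
                )
            return func(*args, **kwargs)
        return wrapper
    return decorator


@requires_module("jax")
def double_with_jax(x):
    # Imported lazily, so importing the enclosing module never requires JAX.
    import jax.numpy as jnp
    return 2 * jnp.asarray(x)
```

With this pattern, the module containing the function can be imported without JAX, and the error only appears when a JAX-dependent function is actually called.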

Design questions:

  • Should we add global settings that a user can specify to control "default" parallelization behaviour (max_vmap_size and devices)? The other option is to allow the user to specify these optional arguments anywhere "internal" parallelization is an option (e.g. JAX-based pulse simulation).
  • I recall the non_jax_argnums argument being kind of clunky - can revisit whether we want to include this.
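For the first design question, the two options are not mutually exclusive: global defaults can be combined with per-call overrides. The sketch below shows one way that could look. All names here (ParallelDefaults, set_parallel_defaults, resolve_parallel_options) are hypothetical, and only the option names max_vmap_size and devices come from the discussion above.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass(frozen=True)
class ParallelDefaults:
    """Hypothetical container for package-wide parallelization defaults."""
    max_vmap_size: Optional[int] = None
    devices: Optional[Tuple] = None


_DEFAULTS = ParallelDefaults()


def set_parallel_defaults(**updates):
    """Update the global defaults; unknown option names raise TypeError."""
    global _DEFAULTS
    _DEFAULTS = replace(_DEFAULTS, **updates)


def resolve_parallel_options(max_vmap_size=None, devices=None):
    """Per-call arguments win; anything left as None falls back to the globals."""
    return ParallelDefaults(
        max_vmap_size=_DEFAULTS.max_vmap_size if max_vmap_size is None else max_vmap_size,
        devices=_DEFAULTS.devices if devices is None else devices,
    )


set_parallel_defaults(max_vmap_size=64)
from_globals = resolve_parallel_options()
overridden = resolve_parallel_options(max_vmap_size=8)
```

Any function that parallelizes internally would call resolve_parallel_options with its own optional arguments, so users can set a default once and still override it locally.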

Update (23/4/24):

  • It appears that xmap has been deprecated, so the inner workings of grid_map will need to be rethought. I think shard_map may be the natural replacement.
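As a starting point for that rethink, here is a minimal sketch of a two-argument grid evaluation with shard_map: the first argument is split across devices and each device vectorizes its local block with vmap. This is an assumption about how a replacement could look, not a design from this PR. The import location of shard_map has moved between JAX versions, hence the fallback, and the sketch requires the leading size of a to be divisible by the device count.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX versions
except ImportError:
    from jax.experimental.shard_map import shard_map  # older JAX versions

devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("dev",))


def f(x, y):
    return 10 * x + y


def local_grid(a_shard, b_full):
    # On each device: vectorize over the local slice of a and all of b.
    return jax.vmap(jax.vmap(f, in_axes=(None, 0)), in_axes=(0, None))(a_shard, b_full)


sharded_grid = shard_map(
    local_grid,
    mesh=mesh,
    in_specs=(P("dev"), P()),   # split a across devices, replicate b
    out_specs=P("dev"),         # concatenate the per-device blocks along axis 0
)

a = jnp.arange(2.0 * len(devices))  # leading size divisible by the device count
b = jnp.arange(4.0)
out = sharded_grid(a, b)
```

On a single device this reduces to the nested vmap; with several devices, each one computes a horizontal band of the grid.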

