TransformState doesn't work with torch.vmap #83
Labels: enhancement: New feature or request (beyond just a new method)
Currently you cannot do something like:
I think we need to register TransformState as a pytree node with torch.utils._pytree, following pytorch/functorch#475.
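A minimal sketch of what that registration could look like, assuming PyTorch 2.x. The dataclass below is a hypothetical stand-in for TransformState (its real fields are not shown in this issue), and `_flatten_state`, `_unflatten_state` and `step` are illustrative names. The idea is that torch.vmap splits its inputs and outputs into tensor leaves via torch.utils._pytree, so once the class is registered, every tensor field gets batched along dim 0.

```python
from dataclasses import dataclass, fields

import torch
import torch.utils._pytree as pytree


@dataclass
class TransformState:
    # Hypothetical stand-in: the real TransformState's fields may differ.
    params: torch.Tensor
    aux: torch.Tensor


def _flatten_state(state):
    # Leaves are the tensor fields; the context records the field names
    # so the object can be rebuilt with the same layout.
    names = tuple(f.name for f in fields(state))
    return [getattr(state, name) for name in names], names


def _unflatten_state(values, names):
    # Rebuild the dataclass from (possibly batched) leaves.
    return TransformState(**dict(zip(names, values)))


# Newer PyTorch exposes register_pytree_node; older 2.x releases only have
# the private _register_pytree_node, which takes the same first three args.
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(TransformState, _flatten_state, _unflatten_state)


def step(state: TransformState) -> TransformState:
    # Runs per example: params has shape (3,) and aux is a scalar here.
    return TransformState(params=state.params * 2, aux=state.aux + 1)


batched = TransformState(params=torch.ones(4, 3), aux=torch.zeros(4))
out = torch.vmap(step)(batched)  # default in_dims=0 applies to every leaf
print(type(out).__name__, out.params.shape, out.aux.shape)
```

Without the `_register(...)` call, torch.vmap sees the dataclass itself as a single non-tensor leaf and rejects it. One possible alternative design is to define TransformState as a `typing.NamedTuple`, which torch.utils._pytree already knows how to flatten without explicit registration.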