TransformState doesn't work with torch.vmap #83

Open
SamDuffield opened this issue May 1, 2024 · 1 comment · Fixed by #106

Labels
enhancement New feature or request (beyond just a new method)

Comments

@SamDuffield
Contributor

Currently you cannot do something like the following:

```python
import torch
import posteriors

# log_posterior: the user-defined log posterior function (not shown here)
n_ensemble = 10

sghmc_transform = posteriors.sgmcmc.sghmc.build(
    log_posterior, lr=5e-2, alpha=1.0, temperature=0.0
)

states = torch.vmap(sghmc_transform.init, randomness='different')(torch.randn(n_ensemble, 2))
# ValueError: vmap(functools.partial(<function init at 0x3027cf420>, momenta=None), ...):
# `functools.partial(<function init at 0x3027cf420>, momenta=None)` must only return Tensors,
# got type <class 'posteriors.sgmcmc.sghmc.SGHMCState'>. Did you mean to set out_dim= to None for output?
```

I think we need to register `TransformState` as a pytree node with `torch.utils._pytree`, following pytorch/functorch#475.

SamDuffield added the `enhancement` and `help wanted` labels on May 1, 2024
@SamDuffield
Contributor Author

To fully support this, I think we might need to enforce `aux` to be a `TensorTree`.
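
A minimal sketch of why `aux` matters, using plain dicts (already pytree nodes in `torch.utils._pytree`) as a hypothetical stand-in for the state. With the default `out_dims=0`, `torch.vmap` requires every leaf of the returned pytree to be a Tensor, so a tensor-valued `aux` batches cleanly while a Python float in `aux` raises the same kind of `ValueError` as in the original report, even though the container itself is a valid pytree.

```python
import torch


def init_tensor_aux(params: torch.Tensor) -> dict:
    # Every leaf is a Tensor, so vmap can add a batch dimension to each one.
    return {"params": params, "aux": params.norm()}


def init_float_aux(params: torch.Tensor) -> dict:
    # Hypothetical non-tensor aux: a Python float leaf cannot be batched.
    return {"params": params, "aux": 1.0}


ok = torch.vmap(init_tensor_aux)(torch.randn(10, 2))

try:
    torch.vmap(init_float_aux)(torch.randn(10, 2))
    failed = False
except ValueError:
    # vmap rejects the float leaf ("must only return Tensors").
    failed = True

print(ok["params"].shape, ok["aux"].shape, failed)
```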

SamDuffield reopened this on Aug 2, 2024
SamDuffield removed the `help wanted` label on Aug 2, 2024