-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow for passing of backend and gradient_backend to nutpie #7535
Conversation
Minimal implementation of #7498 |
pymc/sampling/mcmc.py
Outdated
compiled_model = nutpie.compile_pymc_model(model) | ||
compiled_model = nutpie.compile_pymc_model( | ||
model, | ||
backend=nuts_sampler_kwargs.pop("backend", None), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Depending on the outcome of pymc-devs/nutpie#151 it may be:
backend=nuts_sampler_kwargs.pop("backend", None), | |
backend=nuts_sampler_kwargs.pop("backend", "numba"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As it happens, I submitted a nitpick PR to fix this. It should be None
because None
is handled in the function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Otherwise you can do something like:
compile_kwargs = {}
for kwarg in ("backend", "gradient_backend"):
if kwarg in nuts_sampler_kwargs:
compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg)
compile_pymc_model(..., **compile_kwargs)
This way you don't have to guess the default of nutpie nor change anything there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do. The arg default still needs to be changed in nutpie one way or the other, as I discuss there.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7535 +/- ##
==========================================
- Coverage 92.88% 92.85% -0.03%
==========================================
Files 105 105
Lines 17587 17591 +4
==========================================
Hits 16335 16335
- Misses 1252 1256 +4
|
…s#7535) * Allow for passing of backend and gradient_backend to nutpie * Extract nutpie compiler args explicitly
Description
There is no means for passing
backend
andgradient_backend
tocompile_pymc_model
fromsample
, so that JAX can be used as a backend. This just looks for these arguments innuts_sampler_kwargs
and passes them if they exist.Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7535.org.readthedocs.build/en/7535/