From 31f276fb0daaa9ff4e63389d2b3daed2d870cce3 Mon Sep 17 00:00:00 2001 From: Chris Fonnesbeck Date: Fri, 11 Oct 2024 14:05:22 -0500 Subject: [PATCH] Allow for passing of backend and gradient_backend to nutpie (#7535) * Allow for passing of backend and gradient_backend to nutpie * Extract nutpie compiler args explicitly --- pymc/sampling/mcmc.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index 4b26bb51c8..4ee79607b7 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -305,7 +305,14 @@ def _sample_external_nuts( "`var_names` are currently ignored by the nutpie sampler", UserWarning, ) - compiled_model = nutpie.compile_pymc_model(model) + compile_kwargs = {} + for kwarg in ("backend", "gradient_backend"): + if kwarg in nuts_sampler_kwargs: + compile_kwargs[kwarg] = nuts_sampler_kwargs.pop(kwarg) + compiled_model = nutpie.compile_pymc_model( + model, + **compile_kwargs, + ) t_start = time.time() idata = nutpie.sample( compiled_model,