diff --git a/tests/sampling/test_mcmc_external.py b/tests/sampling/test_mcmc_external.py index e5c9393c01..23c95b161a 100644 --- a/tests/sampling/test_mcmc_external.py +++ b/tests/sampling/test_mcmc_external.py @@ -118,6 +118,7 @@ def test_numba_backend_options(pymc_model, recwarn, backend): def test_invalid_nutpie_backend_raises(pymc_model): + pytest.importorskip("nutpie") with pytest.raises(ValueError, match='Expected one of "numba" or "jax"; found "invalid"'): with pymc_model: sample(nuts_sampler="nutpie[invalid]", random_seed=123, chains=2, tune=500, draws=500)