diff --git a/tests/ensemble_test.py b/tests/ensemble_test.py index a9fb02235..837d02e74 100644 --- a/tests/ensemble_test.py +++ b/tests/ensemble_test.py @@ -4,7 +4,7 @@ from __future__ import annotations import pytest -from torch import eye, ones, zeros +from torch import eye, ones, randn_like, tensor, zeros from torch.distributions import MultivariateNormal from sbi.inference import NLE_A, NPE_C, NRE_A @@ -149,3 +149,40 @@ def simulator(theta): samples = posterior.sample_batched((10,), ones(x_o_batch_dim, num_dim)) assert samples.shape == (10, x_o_batch_dim, num_dim), "Sample shape wrong" + + +@pytest.mark.parametrize( + "weights, expected_exception", + [ + (None, None), + ([0.3, 0.7], None), + (tensor([0.4, 0.6]), None), + ({"w1": 0.5, "w2": 0.5}, TypeError), + (0.5, TypeError), + ((0.5, 0.5), TypeError), + ], +) +def test_ensemble_posterior_weights(weights, expected_exception): + """Test EnsemblePosterior weight handling for valid and invalid formats.""" + num_dim = 2 + ensemble_size = 2 + num_simulations = 50 + x_o = zeros(1, num_dim) + + prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim)) + + posteriors = [] + for _ in range(ensemble_size): + theta = prior.sample((num_simulations,)) + x = theta + 0.1 * randn_like(theta) + inferer = NPE_C(prior) + inferer.append_simulations(theta, x).train(max_num_epochs=1) + posteriors.append(inferer.build_posterior()) + + if expected_exception: + with pytest.raises(expected_exception): + EnsemblePosterior(posteriors, weights=weights) + else: + posterior = EnsemblePosterior(posteriors, weights=weights) + posterior.set_default_x(x_o) + _ = posterior.sample((2,))