Skip to content

Commit

Permalink
bug fix for MNLE on gpu.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 13, 2022
1 parent 5736080 commit d81dc81
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
8 changes: 5 additions & 3 deletions sbi/neural_nets/mnle.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,16 +361,18 @@ def log_prob_iid(self, x: Tensor, theta: Tensor) -> Tensor:
x_cont_repeated, x_disc_repeated = _separate_x(x_repeated)
x_cont, x_disc = _separate_x(x)

log_prob_per_cat = torch.zeros(self.discrete_net.num_categories, batch_size)
# repeat categories for parameters
repeated_categories = torch.repeat_interleave(
torch.arange(self.discrete_net.num_categories - 1), batch_size, dim=0
)
# repeat parameters for categories
repeated_theta = theta.repeat(self.discrete_net.num_categories - 1, 1)
log_prob_per_cat = torch.zeros(self.discrete_net.num_categories, batch_size).to(
net_device
)
log_prob_per_cat[:-1, :] = self.discrete_net.log_prob(
repeated_categories,
repeated_theta,
repeated_categories.to(net_device),
repeated_theta.to(net_device),
).reshape(-1, batch_size)
# infer the last category logprob from sum to one.
log_prob_per_cat[-1, :] = torch.log(1 - log_prob_per_cat[:-1, :].exp().sum(0))
Expand Down
2 changes: 1 addition & 1 deletion tests/mnle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_mnle_on_device(device):

# Test sampling on device.
posterior = trainer.build_posterior()
posterior.sample((1,), x=x[0], show_progress_bars=False)
posterior.sample((1,), x=x[0], show_progress_bars=False, mcmc_method="nuts")


@pytest.mark.parametrize("sampler", ("mcmc", "rejection", "vi"))
Expand Down

0 comments on commit d81dc81

Please sign in to comment.