diff --git a/sbi/samplers/rejection/rejection.py b/sbi/samplers/rejection/rejection.py
index 5c1d5ffc5..35492ca47 100644
--- a/sbi/samplers/rejection/rejection.py
+++ b/sbi/samplers/rejection/rejection.py
@@ -362,6 +362,4 @@ def accept_reject_sample(
         samples.shape[0] == num_samples
     ), "Number of accepted samples must match required samples."
 
-    # NOTE: Restriction prior does currently require a float as return for the
-    # acceptance rate, which is why we for now also return the minimum acceptance rate.
-    return samples, as_tensor(min_acceptance_rate)
+    return samples, as_tensor(acceptance_rate)
diff --git a/sbi/utils/restriction_estimator.py b/sbi/utils/restriction_estimator.py
index 83bd436f7..e895d0d1a 100644
--- a/sbi/utils/restriction_estimator.py
+++ b/sbi/utils/restriction_estimator.py
@@ -692,6 +692,11 @@ def sample(
             max_sampling_batch_size=max_sampling_batch_size,
             alternative_method="sample_with='sir'",
         )
+        # NOTE: This currently requires a float acceptance rate. A previous version
+        # of accept_reject_sample returned a float. In favour of batched sampling,
+        # it now returns a tensor.
+        acceptance_rate = acceptance_rate.min().item()
+
         if save_acceptance_rate:
             self.acceptance_rate = torch.as_tensor(acceptance_rate)
         if print_rejected_frac:
diff --git a/tests/posterior_nn_test.py b/tests/posterior_nn_test.py
index 6b0352f2f..b236e21e0 100644
--- a/tests/posterior_nn_test.py
+++ b/tests/posterior_nn_test.py
@@ -6,7 +6,7 @@
 import pytest
 import torch
 from torch import eye, ones, zeros
-from torch.distributions import MultivariateNormal
+from torch.distributions import Independent, MultivariateNormal, Uniform
 
 from sbi.inference import (
     NLE_A,
@@ -98,13 +98,20 @@ def test_importance_posterior_sample_log_prob(snplre_method: type):
 @pytest.mark.parametrize("snpe_method", [NPE_A, NPE_C])
 @pytest.mark.parametrize("x_o_batch_dim", (0, 1, 2))
+@pytest.mark.parametrize("prior", ("mvn", "uniform"))
 def test_batched_sample_log_prob_with_different_x(
-    snpe_method: type, x_o_batch_dim: bool
+    snpe_method: type,
+    x_o_batch_dim: bool,
+    prior: str,
 ):
     num_dim = 2
     num_simulations = 1000
 
-    prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+    # We also want to test on bounded support, which will invoke leakage correction.
+    if prior == "mvn":
+        prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
+    elif prior == "uniform":
+        prior = Independent(Uniform(-1.0 * ones(num_dim), 1.0 * ones(num_dim)), 1)
     simulator = diagonal_linear_gaussian
 
     inference = snpe_method(prior=prior)
@@ -116,6 +123,7 @@ def test_batched_sample_log_prob_with_different_x(
 
     posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)
 
+    torch.manual_seed(0)
     samples = posterior.sample_batched((10,), x_o)
     batched_log_probs = posterior.log_prob_batched(samples, x_o)
 
@@ -126,6 +134,20 @@
     ), "Sample shape wrong"
     assert batched_log_probs.shape == (10, max(x_o_batch_dim, 1)), "logprob shape wrong"
 
+    # Test consistency with non-batched log_prob.
+    # NOTE: The leakage factor is an MC estimate, so we need to relax the tolerance here.
+    if x_o_batch_dim == 0:
+        log_probs = posterior.log_prob(samples, x=x_o)
+        assert torch.allclose(
+            log_probs, batched_log_probs[:, 0], atol=1e-1, rtol=1e-1
+        ), "Batched log probs different from non-batched log probs"
+    else:
+        for idx in range(x_o_batch_dim):
+            log_probs = posterior.log_prob(samples[:, idx], x=x_o[idx])
+            assert torch.allclose(
+                log_probs, batched_log_probs[:, idx], atol=1e-1, rtol=1e-1
+            ), "Batched log probs different from non-batched log probs"
+
 
 @pytest.mark.mcmc
 @pytest.mark.parametrize("snlre_method", [NLE_A, NRE_A, NRE_B, NRE_C, NPE_C])