Skip to content

Commit

Permalink
test: add xfailing test for MDN bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Aug 19, 2024
1 parent 74fae49 commit aff94b0
Showing 1 changed file with 74 additions and 1 deletion.
75 changes: 74 additions & 1 deletion tests/posterior_nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
SNRE_C,
DirectPosterior,
)
from sbi.simulators.linear_gaussian import diagonal_linear_gaussian
from sbi.simulators.linear_gaussian import (
diagonal_linear_gaussian,
linear_gaussian,
true_posterior_linear_gaussian_mvn_prior,
)
from sbi.utils.diagnostics_utils import get_posterior_samples_on_batch
from tests.test_utils import check_c2st


@pytest.mark.parametrize("snpe_method", [SNPE_A, SNPE_C])
Expand Down Expand Up @@ -204,3 +210,70 @@ def test_batched_mcmc_sample_log_prob_with_different_x(
assert torch.allclose(
samples_m, samples_sep_m, atol=0.2, rtol=0.2
), "Batched sampling is not consistent with separate sampling."


@pytest.mark.slow
@pytest.mark.parametrize(
"density_estimator",
[
pytest.param(
"mdn",
marks=pytest.mark.xfail(
raises=AssertionError, reason="Due to MDN bug in pyknos", strict=True
),
),
"maf",
"zuko_nsf",
],
)
def test_batched_sampling_and_logprob_accuracy(density_estimator: str):
"""Test with two different observations and compare to sequential methods."""
num_dim = 2
num_simulations = 2000
num_samples = 1000
sample_shape = (num_samples,)
xos = torch.stack((-1.0 * ones(num_dim), 1.0 * ones(num_dim)))
num_xos = xos.shape[0]

prior = MultivariateNormal(loc=zeros(num_dim), covariance_matrix=eye(num_dim))
likelihood_shift = -1.0 * ones(num_dim)
likelihood_cov = 0.3 * eye(num_dim)
prior_mean = zeros(num_dim)
prior_cov = eye(num_dim)
prior = MultivariateNormal(loc=prior_mean, covariance_matrix=prior_cov)

def simulator(theta):
return linear_gaussian(theta, likelihood_shift, likelihood_cov)

inference = SNPE_C(
prior=prior, show_progress_bars=False, density_estimator=density_estimator
)
theta = prior.sample((num_simulations,))
x = simulator(theta)
posterior_estimator = inference.append_simulations(theta, x).train()

posterior = DirectPosterior(posterior_estimator=posterior_estimator, prior=prior)

samples_batched = get_posterior_samples_on_batch(
xos, posterior, sample_shape, use_batched_sampling=True
)
log_probs_batched = posterior.log_prob_batched(samples_batched, xos)

# check c2st for each xos
for idx in range(0, num_xos):
gt_posterior = true_posterior_linear_gaussian_mvn_prior(
xos[idx], likelihood_shift, likelihood_cov, prior_mean, prior_cov
)
target_samples = gt_posterior.sample((num_samples,))
check_c2st(
target_samples,
samples_batched[:, idx],
alg=f"c2st-batch-vs-non-batch-{density_estimator}-x-idx{idx}",
)

target_log_probs = gt_posterior.log_prob(samples_batched[:, idx])
log_probs = posterior.log_prob(samples_batched[:, idx], xos[idx])
assert torch.allclose(log_probs, log_probs_batched[:, idx])
assert torch.allclose(
target_log_probs.exp(), log_probs.exp(), atol=0.4, rtol=0.4
), "Batched log probs are not consistent with non-batched log probs."

0 comments on commit aff94b0

Please sign in to comment.