From 230ff26475209594d32d21bed0962962324871c3 Mon Sep 17 00:00:00 2001 From: Ismael Mendoza Date: Wed, 2 Oct 2024 14:59:42 -0500 Subject: [PATCH] fix test to not use _barker_sample_nd --- tests/mcmc/test_barker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/mcmc/test_barker.py b/tests/mcmc/test_barker.py index 375b0b10a..65571b61c 100644 --- a/tests/mcmc/test_barker.py +++ b/tests/mcmc/test_barker.py @@ -9,7 +9,7 @@ import blackjax from blackjax.mcmc import metrics -from blackjax.mcmc.barker import _barker_pdf, _barker_sample_nd +from blackjax.mcmc.barker import _barker_pdf, _barker_sample from blackjax.util import run_inference_algorithm @@ -27,7 +27,7 @@ def test_nd(self, seed): metric = metrics.default_metric(jnp.eye(4)) keys = jax.random.split(key, n_samples) - samples = jax.vmap(lambda k: _barker_sample_nd(k, m, a, scale, metric))(keys) + samples = jax.vmap(lambda k: _barker_sample(k, m, a, scale, metric))(keys) # Check that the emprical mean and the mean computed as sum(x * p(x) dx) are close _test_samples_vs_pdf(samples, lambda x: _barker_pdf(x, m, a, scale))