From b87f401c6f6714991fd52b343972c496c34b27e2 Mon Sep 17 00:00:00 2001 From: Jan Boelts Date: Thu, 4 Apr 2024 08:37:45 +0200 Subject: [PATCH] fix deprecated nuts and hmc kwargs. --- tests/inference_on_device_test.py | 2 +- tests/linearGaussian_snle_test.py | 8 +++----- tests/linearGaussian_snre_test.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/inference_on_device_test.py b/tests/inference_on_device_test.py index dd87da969..7df1ef1d6 100644 --- a/tests/inference_on_device_test.py +++ b/tests/inference_on_device_test.py @@ -58,7 +58,7 @@ pytest.param(SNRE_B, "resnet", "slice", marks=pytest.mark.mcmc), (SNRE_C, "resnet", "rejection"), (SNRE_C, "resnet", "importance"), - pytest.param(SNRE_C, "resnet", "nuts", marks=pytest.mark.mcmc), + pytest.param(SNRE_C, "resnet", "nuts_pymc", marks=pytest.mark.mcmc), ], ) @pytest.mark.parametrize( diff --git a/tests/linearGaussian_snle_test.py b/tests/linearGaussian_snle_test.py index 8a9fea093..51440d034 100644 --- a/tests/linearGaussian_snle_test.py +++ b/tests/linearGaussian_snle_test.py @@ -406,11 +406,9 @@ def simulator(theta): pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc), pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc), pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc), - pytest.param("slice", "gaussian", marks=pytest.mark.mcmc), - pytest.param("slice", "uniform", marks=pytest.mark.mcmc), - pytest.param("nuts", "gaussian", marks=pytest.mark.mcmc), - pytest.param("nuts", "uniform", marks=pytest.mark.mcmc), - pytest.param("hmc", "gaussian", marks=pytest.mark.mcmc), + pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc), + pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc), + pytest.param("hmc_pymc", "gaussian", marks=pytest.mark.mcmc), ("rejection", "uniform"), ("rejection", "gaussian"), ("rKL", "uniform"), diff --git a/tests/linearGaussian_snre_test.py b/tests/linearGaussian_snre_test.py index 086c7ce73..34338e7a6 100644 --- a/tests/linearGaussian_snre_test.py +++ b/tests/linearGaussian_snre_test.py @@ -328,8 +328,8 @@ def simulator(theta): pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc), pytest.param("slice", "gaussian", marks=pytest.mark.mcmc), pytest.param("slice", "uniform", marks=pytest.mark.mcmc), - pytest.param("nuts", "gaussian", marks=pytest.mark.mcmc), - pytest.param("nuts", "uniform", marks=pytest.mark.mcmc), + pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc), + pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc), pytest.param("hmc", "gaussian", marks=pytest.mark.mcmc), ("rejection", "uniform"), ("rejection", "gaussian"),