Skip to content

Commit

Permalink
fix deprecated nuts and hmc kwargs.
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb committed Apr 4, 2024
1 parent 6e0e98a commit b87f401
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion tests/inference_on_device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
pytest.param(SNRE_B, "resnet", "slice", marks=pytest.mark.mcmc),
(SNRE_C, "resnet", "rejection"),
(SNRE_C, "resnet", "importance"),
pytest.param(SNRE_C, "resnet", "nuts", marks=pytest.mark.mcmc),
pytest.param(SNRE_C, "resnet", "nuts_pymc", marks=pytest.mark.mcmc),
],
)
@pytest.mark.parametrize(
Expand Down
8 changes: 3 additions & 5 deletions tests/linearGaussian_snle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,9 @@ def simulator(theta):
pytest.param("slice_np", "uniform", marks=pytest.mark.mcmc),
pytest.param("slice_np_vectorized", "gaussian", marks=pytest.mark.mcmc),
pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
pytest.param("slice", "gaussian", marks=pytest.mark.mcmc),
pytest.param("slice", "uniform", marks=pytest.mark.mcmc),
pytest.param("nuts", "gaussian", marks=pytest.mark.mcmc),
pytest.param("nuts", "uniform", marks=pytest.mark.mcmc),
pytest.param("hmc", "gaussian", marks=pytest.mark.mcmc),
pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc),
pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
pytest.param("hmc_pymc", "gaussian", marks=pytest.mark.mcmc),
("rejection", "uniform"),
("rejection", "gaussian"),
("rKL", "uniform"),
Expand Down
4 changes: 2 additions & 2 deletions tests/linearGaussian_snre_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,8 @@ def simulator(theta):
pytest.param("slice_np_vectorized", "uniform", marks=pytest.mark.mcmc),
pytest.param("slice", "gaussian", marks=pytest.mark.mcmc),
pytest.param("slice", "uniform", marks=pytest.mark.mcmc),
pytest.param("nuts", "gaussian", marks=pytest.mark.mcmc),
pytest.param("nuts", "uniform", marks=pytest.mark.mcmc),
pytest.param("nuts_pymc", "gaussian", marks=pytest.mark.mcmc),
pytest.param("nuts_pyro", "uniform", marks=pytest.mark.mcmc),
pytest.param("hmc", "gaussian", marks=pytest.mark.mcmc),
("rejection", "uniform"),
("rejection", "gaussian"),
Expand Down

0 comments on commit b87f401

Please sign in to comment.