Skip to content

Commit

Permalink
update test_sampling with barker api
Browse files Browse the repository at this point in the history
the mass matrix is now an optional argument in barker.
  • Loading branch information
ismael-mendoza committed Sep 24, 2024
1 parent be32caf commit fd35a51
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,9 +496,7 @@ def test_barker(self):
)
logposterior_fn = lambda x: logposterior_fn_(**x)

inv_mass_matrix = jnp.eye(2) # no effect on original test

barker = blackjax.barker_proposal(logposterior_fn, 1e-1, inv_mass_matrix)
barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
state = barker.init({"coefs": 1.0, "log_scale": 1.0})

_, states = run_inference_algorithm(
Expand Down Expand Up @@ -889,7 +887,7 @@ def test_mala(self):
@chex.all_variants(with_pmap=False)
def test_barker(self):
inference_algorithm = blackjax.barker_proposal(
self.normal_logprob, step_size=1.5, inverse_mass_matrix=jnp.eye(1)
self.normal_logprob, step_size=1.5
)
initial_state = inference_algorithm.init(jnp.array(1.0))
self.univariate_normal_test_case(
Expand Down Expand Up @@ -926,7 +924,7 @@ def test_barker(self):
},
{
"algorithm": blackjax.barker_proposal,
"parameters": {"step_size": 0.5, "inverse_mass_matrix": jnp.eye(2)},
"parameters": {"step_size": 0.5},
"is_mass_matrix_diagonal": None,
},
]
Expand Down

0 comments on commit fd35a51

Please sign in to comment.