From f77297f556a8ddf1c017ce7fffea0816d8b086de Mon Sep 17 00:00:00 2001 From: ksnxr <70186663+ksnxr@users.noreply.github.com> Date: Sun, 31 Mar 2024 10:54:38 +0300 Subject: [PATCH] Fix MALA transition energy (#653) * Fix MALA transition energy * Use a different logic. --- blackjax/mcmc/mala.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/blackjax/mcmc/mala.py b/blackjax/mcmc/mala.py index 9690bc7f5..f6dd7c106 100644 --- a/blackjax/mcmc/mala.py +++ b/blackjax/mcmc/mala.py @@ -80,14 +80,14 @@ def transition_energy(state, new_state, step_size): """Transition energy to go from `state` to `new_state`""" theta = jax.tree_util.tree_map( lambda new_x, x, g: new_x - x - step_size * g, - new_state.position, state.position, - state.logdensity_grad, + new_state.position, + new_state.logdensity_grad, ) theta_dot = jax.tree_util.tree_reduce( operator.add, jax.tree_util.tree_map(lambda x: jnp.sum(x * x), theta) ) - return -state.logdensity + 0.25 * (1.0 / step_size) * theta_dot + return -new_state.logdensity + 0.25 * (1.0 / step_size) * theta_dot compute_acceptance_ratio = proposal.compute_asymmetric_acceptance_ratio( transition_energy