diff --git a/aehmc/proposals.py b/aehmc/proposals.py index a92cfd6..b3e5e6b 100644 --- a/aehmc/proposals.py +++ b/aehmc/proposals.py @@ -43,9 +43,13 @@ def update(initial_energy, state): is_transition_divergent = at.abs_(delta_energy) > divergence_threshold weight = delta_energy - p_accept = at.clip(at.exp(delta_energy), 0.0, 1.0) + log_p_accept = at.where( + at.gt(delta_energy, 0), + at.as_tensor(0, dtype=delta_energy.dtype), + delta_energy, + ) - return (state, new_energy, weight, p_accept), is_transition_divergent + return (state, new_energy, weight, log_p_accept), is_transition_divergent return update @@ -127,11 +131,11 @@ def maybe_update_proposal( do_accept: bool, proposal: ProposalStateType, new_proposal: ProposalStateType ) -> ProposalStateType: """Return either proposal depending on the boolean `do_accept`""" - state, energy, weight, sum_p_accept = proposal - new_state, new_energy, new_weight, new_sum_p_accept = new_proposal + state, energy, weight, log_sum_p_accept = proposal + new_state, new_energy, new_weight, new_log_sum_p_accept = new_proposal updated_weight = at.logaddexp(weight, new_weight) - updated_sum_p_accept = sum_p_accept + new_sum_p_accept + updated_log_sum_p_accept = at.logaddexp(log_sum_p_accept, new_log_sum_p_accept) updated_q = at.where(do_accept, new_state[0], state[0]) updated_p = at.where(do_accept, new_state[1], state[1]) @@ -143,5 +147,5 @@ def maybe_update_proposal( (updated_q, updated_p, updated_potential_energy, updated_potential_energy_grad), updated_energy, updated_weight, - updated_sum_p_accept, + updated_log_sum_p_accept, ) diff --git a/aehmc/trajectory.py b/aehmc/trajectory.py index 8177822..e7446ff 100644 --- a/aehmc/trajectory.py +++ b/aehmc/trajectory.py @@ -387,7 +387,7 @@ def expand_once( # Compute the pseudo-acceptance probability for the NUTS algorithm. # It can be understood as the average acceptance probability MC would give to # the states explored during the final expansion. - acceptance_probability = new_proposal[3] / subtrajectory_length + acceptance_probability = at.exp(new_proposal[3]) / subtrajectory_length # Update the proposal. # If the termination criterion is reached in the subtree or if a