Skip to content

Commit

Permalink
Use logarithm of the sum of acceptance probs
Browse files Browse the repository at this point in the history
It is more numerically stable for small values of the probability than
summing the probabilities directly.
  • Loading branch information
rlouf committed Feb 8, 2022
1 parent 1de1eb5 commit ed25a07
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 7 deletions.
16 changes: 10 additions & 6 deletions aehmc/proposals.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def update(initial_energy, state):
is_transition_divergent = at.abs_(delta_energy) > divergence_threshold

weight = delta_energy
p_accept = at.clip(at.exp(delta_energy), 0.0, 1.0)
log_p_accept = at.where(
at.gt(delta_energy, 0),
at.as_tensor(0, dtype=delta_energy.dtype),
delta_energy,
)

return (state, new_energy, weight, p_accept), is_transition_divergent
return (state, new_energy, weight, log_p_accept), is_transition_divergent

return update

Expand Down Expand Up @@ -127,11 +131,11 @@ def maybe_update_proposal(
do_accept: bool, proposal: ProposalStateType, new_proposal: ProposalStateType
) -> ProposalStateType:
"""Return either proposal depending on the boolean `do_accept`"""
state, energy, weight, sum_p_accept = proposal
new_state, new_energy, new_weight, new_sum_p_accept = new_proposal
state, energy, weight, log_sum_p_accept = proposal
new_state, new_energy, new_weight, new_log_sum_p_accept = new_proposal

updated_weight = at.logaddexp(weight, new_weight)
updated_sum_p_accept = sum_p_accept + new_sum_p_accept
updated_log_sum_p_accept = at.logaddexp(log_sum_p_accept, new_log_sum_p_accept)

updated_q = at.where(do_accept, new_state[0], state[0])
updated_p = at.where(do_accept, new_state[1], state[1])
Expand All @@ -143,5 +147,5 @@ def maybe_update_proposal(
(updated_q, updated_p, updated_potential_energy, updated_potential_energy_grad),
updated_energy,
updated_weight,
updated_sum_p_accept,
updated_log_sum_p_accept,
)
2 changes: 1 addition & 1 deletion aehmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def expand_once(
# Compute the pseudo-acceptance probability for the NUTS algorithm.
# It can be understood as the average acceptance probability MC would give to
# the states explored during the final expansion.
acceptance_probability = new_proposal[3] / subtrajectory_length
acceptance_probability = at.exp(new_proposal[3]) / subtrajectory_length

# Update the proposal.
# If the termination criterion is reached in the subtree or if a
Expand Down

0 comments on commit ed25a07

Please sign in to comment.