diff --git a/aehmc/trajectory.py b/aehmc/trajectory.py index 3cfeb6a..e6f35f3 100644 --- a/aehmc/trajectory.py +++ b/aehmc/trajectory.py @@ -387,17 +387,16 @@ def expand_once( # the states explored during the final expansion. acceptance_probability = at.exp(new_proposal[3]) / subtrajectory_length - # Update the proposal. - # If the termination criterion is reached in the subtree or if a - # divergence occurs we reject this subtree's proposal. We - # nevertheless update the sum of the logarithm of the acceptance - # probabilities to serve as an estimate for dual averaging. - updated_weight = at.logaddexp(proposal[2], new_proposal[2]) + # Update the proposal + # + # We do not accept proposals that come from diverging or turning subtrajectories. + # However the definition of the acceptance probability is such that the + # acceptance probability needs to be computed across the entire trajectory. updated_proposal = ( proposal[0], proposal[1], - updated_weight, - new_proposal[3] + proposal[3], + proposal[2], + at.logaddexp(new_proposal[3], proposal[3]), ) sampled_proposal = where_proposal(