Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

minor clean up #635

Merged
merged 2 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion blackjax/mcmc/termination.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import jax
import jax.numpy as jnp

from blackjax.mcmc.metrics import CheckTurning
from blackjax.types import Array


Expand All @@ -27,7 +28,7 @@ class IterativeUTurnState(NamedTuple):
idx_max: int


def iterative_uturn_numpyro(is_turning):
def iterative_uturn_numpyro(is_turning: CheckTurning):
"""Numpyro style dynamic U-Turn criterion."""

def new_state(chain_state, max_num_doublings) -> IterativeUTurnState:
Expand Down
27 changes: 16 additions & 11 deletions blackjax/mcmc/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
# limitations under the License.
"""Procedures to build trajectories for algorithms in the HMC family.

To propose a new state, algorithms in the HMC family generally proceed by :cite:p:`betancourt2017conceptual`:
To propose a new state, algorithms in the HMC family generally proceed by
:cite:p:`betancourt2017conceptual`:

1. Sampling a trajectory starting from the initial point;
2. Sampling a new state from this sampled trajectory.
Expand Down Expand Up @@ -299,10 +300,11 @@ def dynamic_recursive_integration(
"""Integrate a trajectory and update the proposal recursively in Python
until the termination criterion is met.

This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with multinomial sampling.
The implemenation here is mostly for validating the progressive implementation
to make sure the two are equivalent. The recursive implementation should not
be used for actually sampling as it cannot be jitted and thus likely slow.
This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with
multinomial sampling. The implemenation here is mostly for validating the
progressive implementation to make sure the two are equivalent. The recursive
implementation should not be used for actually sampling as it cannot be jitted and
thus likely slow.

Parameters
----------
Expand All @@ -313,9 +315,11 @@ def dynamic_recursive_integration(
uturn_check_fn
Determines whether the termination criterion has been met.
divergence_threshold
Value of the difference of energy between two consecutive states above which we say a transition is divergent.
Value of the difference of energy between two consecutive states above which we
say a transition is divergent.
use_robust_uturn_check
Bool to indicate whether to perform additional U turn check between two trajectory.
Bool to indicate whether to perform additional U turn check between two
trajectory.

"""
_, generate_proposal = proposal_generator(hmc_energy(kinetic_energy))
Expand Down Expand Up @@ -348,7 +352,8 @@ def buildtree_integrate(
step_size
The step size of the symplectic integrator.
initial_energy
Initial energy H0 of the HMC step (not to confused with the initial energy of the subtree)
Initial energy H0 of the HMC step (not to confused with the initial energy
of the subtree)

"""
if tree_depth == 0:
Expand Down Expand Up @@ -561,9 +566,9 @@ def expand_once(loop_state):
# Update the proposal
#
# We do not accept proposals that come from diverging or turning
# subtrajectories. However the definition of the acceptance
# probability is such that the acceptance probability needs to be
# computed across the entire trajectory.
# subtrajectories. However the definition of the acceptance probability is
# such that the acceptance probability needs to be computed across the
# entire trajectory.
def update_sum_log_p_accept(inputs):
_, proposal, new_proposal = inputs
return Proposal(
Expand Down
2 changes: 1 addition & 1 deletion requirements-doc.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ myst_nb>=1.0.0
numba
numpyro
optax
oryx @ git+https://github.com/jax-ml/oryx.git@main # remove after oryx release
oryx
pymc
scikit-learn
sphinx
Expand Down