diff --git a/blackjax/mcmc/termination.py b/blackjax/mcmc/termination.py index 8432054e4..24e17c3a5 100644 --- a/blackjax/mcmc/termination.py +++ b/blackjax/mcmc/termination.py @@ -17,6 +17,7 @@ import jax import jax.numpy as jnp +from blackjax.mcmc.metrics import CheckTurning from blackjax.types import Array @@ -27,7 +28,7 @@ class IterativeUTurnState(NamedTuple): idx_max: int -def iterative_uturn_numpyro(is_turning): +def iterative_uturn_numpyro(is_turning: CheckTurning): """Numpyro style dynamic U-Turn criterion.""" def new_state(chain_state, max_num_doublings) -> IterativeUTurnState: diff --git a/blackjax/mcmc/trajectory.py b/blackjax/mcmc/trajectory.py index 6338acc2b..6deeb9bef 100644 --- a/blackjax/mcmc/trajectory.py +++ b/blackjax/mcmc/trajectory.py @@ -13,7 +13,8 @@ # limitations under the License. """Procedures to build trajectories for algorithms in the HMC family. -To propose a new state, algorithms in the HMC family generally proceed by :cite:p:`betancourt2017conceptual`: +To propose a new state, algorithms in the HMC family generally proceed by +:cite:p:`betancourt2017conceptual`: 1. Sampling a trajectory starting from the initial point; 2. Sampling a new state from this sampled trajectory. @@ -299,10 +300,11 @@ def dynamic_recursive_integration( """Integrate a trajectory and update the proposal recursively in Python until the termination criterion is met. - This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with multinomial sampling. - The implemenation here is mostly for validating the progressive implementation - to make sure the two are equivalent. The recursive implementation should not - be used for actually sampling as it cannot be jitted and thus likely slow. + This is the implementation of Algorithm 6 from :cite:p:`hoffman2014no` with + multinomial sampling. The implemenation here is mostly for validating the + progressive implementation to make sure the two are equivalent. The recursive + implementation should not be used for actually sampling as it cannot be jitted and + thus likely slow. Parameters ---------- @@ -313,9 +315,11 @@ def dynamic_recursive_integration( uturn_check_fn Determines whether the termination criterion has been met. divergence_threshold - Value of the difference of energy between two consecutive states above which we say a transition is divergent. + Value of the difference of energy between two consecutive states above which we + say a transition is divergent. use_robust_uturn_check - Bool to indicate whether to perform additional U turn check between two trajectory. + Bool to indicate whether to perform additional U turn check between two + trajectory. """ _, generate_proposal = proposal_generator(hmc_energy(kinetic_energy)) @@ -348,7 +352,8 @@ def buildtree_integrate( step_size The step size of the symplectic integrator. initial_energy - Initial energy H0 of the HMC step (not to confused with the initial energy of the subtree) + Initial energy H0 of the HMC step (not to confused with the initial energy + of the subtree) """ if tree_depth == 0: @@ -561,9 +566,9 @@ def expand_once(loop_state): # Update the proposal # # We do not accept proposals that come from diverging or turning - # subtrajectories. However the definition of the acceptance - # probability is such that the acceptance probability needs to be - # computed across the entire trajectory. + # subtrajectories. However the definition of the acceptance probability is + # such that the acceptance probability needs to be computed across the + # entire trajectory. def update_sum_log_p_accept(inputs): _, proposal, new_proposal = inputs return Proposal( diff --git a/requirements-doc.txt b/requirements-doc.txt index dc8781831..338073a88 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -12,7 +12,7 @@ myst_nb>=1.0.0 numba numpyro optax -oryx @ git+https://github.com/jax-ml/oryx.git@main # remove after oryx release +oryx pymc scikit-learn sphinx