Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor formatting #685

Merged
merged 4 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions blackjax/mcmc/integrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,11 @@ def euclidean_integrator(
yoshida_coefficients = [b1, a1, b2, a2, b2, a1, b1]
yoshida = generate_euclidean_integrator(yoshida_coefficients)

"""11 stage Omelyan integrator [I.P. Omelyan, I.M. Mryglod and R. Folk, Comput. Phys. Commun. 151 (2003) 272.],
4MN5FV in [Takaishi, Tetsuya, and Philippe De Forcrand. "Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD." Physical Review E 73.3 (2006): 036706.]
popular in LQCD"""
"""
Eleven-stage palindromic symplectic integrator derived in :cite:p:`omelyan2003symplectic`.

Popular in LQCD, see also :cite:p:`takaishi2006testing`.
"""
b1 = 0.08398315262876693
a1 = 0.2539785108410595
b2 = 0.6822365335719091
Expand Down
24 changes: 16 additions & 8 deletions blackjax/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utility functions for BlackJax."""

from functools import partial
from typing import Callable, Union

Expand Down Expand Up @@ -165,7 +166,8 @@ def run_inference_algorithm(
initial_state
The initial state of the inference algorithm.
initial_position
The initial position of the inference algorithm. This is used when the initial state is not provided.
The initial position of the inference algorithm. This is used when the initial
state is not provided.
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps
Expand All @@ -177,9 +179,12 @@ def run_inference_algorithm(
computing determinstic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function that computes the expectation of the state. This is done incrementally, so doesn't require storing all the states.
A function that computes the expectation of the state. This is done
incrementally, so doesn't require storing all the states.
return_state_history
if False, `run_inference_algorithm` will only return an expectation of the value of transform, and return that average instead of the full set of samples. This is useful when memory is a bottleneck.
if False, `run_inference_algorithm` will only return an expectation of the value
of transform, and return that average instead of the full set of samples. This
is useful when memory is a bottleneck.

Returns
-------
Expand All @@ -193,14 +198,16 @@ def run_inference_algorithm(
"""

if initial_state is None and initial_position is None:
raise ValueError("Either initial_state or initial_position must be provided.")
raise ValueError(
"Either `initial_state` or `initial_position` must be provided."
)
if initial_state is not None and initial_position is not None:
raise ValueError(
"Only one of initial_state or initial_position must be provided."
"Only one of `initial_state` or `initial_position` must be provided."
)

rng_key, init_key = split(rng_key, 2)
if initial_position is not None:
if initial_state is None:
rng_key, init_key = split(rng_key, 2)
initial_state = inference_algorithm.init(initial_position, init_key)

keys = split(rng_key, num_steps)
Expand Down Expand Up @@ -241,7 +248,8 @@ def streaming_average_update(
expectation
the value of the expectation at the current timestep
streaming_avg
tuple of (total, average) where total is the sum of weights and average is the current average
tuple of (total, average) where total is the sum of weights and average is
the current average
weight
weight of the current state
zero_prevention
Expand Down
22 changes: 22 additions & 0 deletions docs/refs.bib
Original file line number Diff line number Diff line change
Expand Up @@ -423,3 +423,25 @@ @misc{huang2021schrodingerfollmer
archivePrefix={arXiv},
primaryClass={stat.CO}
}

@article{omelyan2003symplectic,
title={Symplectic analytically integrable decomposition algorithms: classification, derivation, and application to molecular dynamics, quantum and celestial mechanics simulations},
author={Omelyan, IP and Mryglod, IM and Folk, R},
journal={Computer Physics Communications},
volume={151},
number={3},
pages={272--314},
year={2003},
publisher={Elsevier}
}

@article{takaishi2006testing,
title={Testing and tuning symplectic integrators for the hybrid Monte Carlo algorithm in lattice QCD},
author={Takaishi, Tetsuya and De Forcrand, Philippe},
journal={Physical Review E},
volume={73},
number={3},
pages={036706},
year={2006},
publisher={APS}
}
2 changes: 1 addition & 1 deletion tests/mcmc/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,7 @@ def test_random_walk(self):

@chex.all_variants(with_pmap=False)
def test_mala(self):
inference_algorithm = blackjax.mala(self.normal_logprob, step_size=1e-1)
inference_algorithm = blackjax.mala(self.normal_logprob, step_size=0.2)
initial_state = inference_algorithm.init(jnp.array(1.0))
self.univariate_normal_test_case(
inference_algorithm, self.key, initial_state, 45000, 5_000
Expand Down
Loading