Streamline run_inference_algorithm and the streaming average #713

Closed · wants to merge 8 commits
4 changes: 2 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
from jax.flatten_util import ravel_pytree

from blackjax.diagnostics import effective_sample_size
from blackjax.util import pytree_size, streaming_average_update
from blackjax.util import incremental_value_update, pytree_size


class MCLMCAdaptationState(NamedTuple):
@@ -199,7 +199,7 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average_update(
streaming_avg = incremental_value_update(
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=streaming_avg,
weight=(1 - mask) * success * params.step_size,
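For context on the renamed helper: the quantity maintained here is a weight-accumulating streaming mean of x and x², with per-step weights. A minimal sketch of the update rule, using hypothetical stand-in values for the positions and weights (not taken from this PR):

.. code::

    import jax.numpy as jnp

    # Hypothetical stand-ins for the adaptation loop's inputs.
    xs = jnp.array([0.5, -1.0, 2.0])       # flattened positions
    weights = jnp.array([0.1, 0.2, 0.1])   # e.g. (1 - mask) * success * step_size

    # running sum of weights and running [E[x], E[x^2]]
    total, average = 0.0, jnp.zeros(2)
    for x, w in zip(xs, weights):
        value = jnp.array([x, jnp.square(x)])
        average = (total * average + w * value) / (total + w)
        total = total + w

    # The streaming result matches the batch weighted mean.
    assert jnp.allclose(average[0], jnp.sum(weights * xs) / jnp.sum(weights))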
152 changes: 93 additions & 59 deletions blackjax/util.py
@@ -3,7 +3,6 @@
from functools import partial
from typing import Callable, Union

import jax
import jax.numpy as jnp
from jax import jit, lax
from jax.flatten_util import ravel_pytree
@@ -149,9 +148,7 @@ def run_inference_algorithm(
initial_state: ArrayLikeTree = None,
initial_position: ArrayLikeTree = None,
progress_bar: bool = False,
transform: Callable = lambda x: x,
return_state_history=True,
expectation: Callable = lambda x: x,
transform: Callable = lambda state, info: (state, info),
) -> tuple:
"""Wrapper to run an inference algorithm.

@@ -166,104 +163,141 @@
initial_state
The initial state of the inference algorithm.
initial_position
The initial position of the inference algorithm. This is used when the initial
state is not provided.
The initial position of the inference algorithm. This is used when the initial state is not provided.
inference_algorithm
One of blackjax's sampling algorithms or variational inference algorithms.
num_steps
Number of MCMC steps.
progress_bar
Whether to display a progress bar.
transform
A transformation of the trace of states to be returned. This is useful for
A transformation of the trace of states (and info) to be returned. This is useful for
computing deterministic variables, or returning a subset of the states.
By default, the states are returned as is.
expectation
A function that computes the expectation of the state. This is done
incrementally, so it doesn't require storing all the states.
return_state_history
if False, `run_inference_algorithm` will return only an incrementally computed
average of the value of transform, instead of the full set of samples. This
is useful when memory is a bottleneck.

Returns
-------
If return_state_history is True:
1. The final state.
2. The trace of the state.
2. The trace of transform(state).
3. The trace of the info of the inference algorithm for diagnostics.
If return_state_history is False:
1. The expectation of the transformed state over the chain, computed incrementally.
2. The final state of the inference algorithm.
"""

if initial_state is None and initial_position is None:
raise ValueError(
"Either `initial_state` or `initial_position` must be provided."
)
raise ValueError("Either initial_state or initial_position must be provided.")
if initial_state is not None and initial_position is not None:
raise ValueError(
"Only one of `initial_state` or `initial_position` must be provided."
"Only one of initial_state or initial_position must be provided."
)

if initial_state is None:
rng_key, init_key = split(rng_key, 2)
rng_key, init_key = split(rng_key, 2)
if initial_position is not None:
initial_state = inference_algorithm.init(initial_position, init_key)

keys = split(rng_key, num_steps)

def one_step(average_and_state, xs, return_state):
def one_step(state, xs):
_, rng_key = xs
average, state = average_and_state
state, info = inference_algorithm.step(rng_key, state)
average = streaming_average_update(expectation(transform(state)), average)
if return_state:
return (average, state), (transform(state), info)
else:
return (average, state), None

one_step = jax.jit(partial(one_step, return_state=return_state_history))
return state, transform(state, info)

if progress_bar:
one_step = progress_bar_scan(num_steps)(one_step)

xs = (jnp.arange(num_steps), keys)
((_, average), final_state), history = lax.scan(
one_step, ((0, expectation(transform(initial_state))), initial_state), xs
)
xs = jnp.arange(num_steps), keys
final_state, history = lax.scan(one_step, initial_state, xs)
return final_state, history

if not return_state_history:
return average, transform(final_state)
else:
state_history, info_history = history
return transform(final_state), state_history, info_history

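With this refactor, run_inference_algorithm returns (final_state, history), where history stacks whatever transform(state, info) returns at each step. A minimal sketch of the new calling convention, assuming a MALA kernel on a toy Gaussian target (the target and settings are illustrative, not from this PR):

.. code::

    import jax
    import jax.numpy as jnp
    import blackjax
    from blackjax.util import run_inference_algorithm

    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
    alg = blackjax.mala(logdensity_fn, 1e-2)

    final_state, (positions, infos) = run_inference_algorithm(
        rng_key=jax.random.PRNGKey(0),
        initial_position=jnp.zeros(2),
        inference_algorithm=alg,
        num_steps=1_000,
        # transform receives (state, info); its output is what gets traced
        transform=lambda state, info: (state.position, info),
    )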
def store_only_expectation_values(
sampling_algorithm,
state_transform=lambda x: x,
incremental_value_transform=lambda x: x,
):
"""Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same
kernel but only stores the streaming expectation values of some observables, not the full states; to save memory.

It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample.

Example:

.. code::

init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0),3)
model = StandardNormal(2)
initial_position = model.sample_init(init_key)
initial_state = blackjax.mcmc.mclmc.init(
position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key
)
integrator_type = "mclachlan"
L = 1.0
step_size = 0.1
num_steps = 4

integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
# assumed here: construct the base MCLMC sampler whose kernel the wrapper reuses
sampling_alg = blackjax.mclmc(
model.logdensity_fn,
L=L,
step_size=step_size,
integrator=integrator,
)
state_transform = lambda state: state.position
memory_efficient_sampling_alg, transform = store_only_expectation_values(
sampling_algorithm=sampling_alg,
state_transform=state_transform)

initial_state = memory_efficient_sampling_alg.init(initial_state)

final_state, trace_at_every_step = run_inference_algorithm(
rng_key=run_key,
initial_state=initial_state,
inference_algorithm=memory_efficient_sampling_alg,
num_steps=num_steps,
transform=transform,
progress_bar=True,
)
"""

def streaming_average_update(
current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
def init_fn(state):
averaging_state = (0.0, state_transform(state))
return (state, averaging_state)

def update_fn(rng_key, state_and_incremental_val):
state, averaging_state = state_and_incremental_val
state, info = sampling_algorithm.step(
rng_key, state
) # update the state with the sampling algorithm
averaging_state = incremental_value_update(
state_transform(state), averaging_state
) # update the expectation value with the running average
return (state, averaging_state), info

def transform(state_and_incremental_val, info):
(state, (_, incremental_value)) = state_and_incremental_val
return incremental_value_transform(incremental_value), info

return SamplingAlgorithm(init_fn, update_fn), transform


def incremental_value_update(
expectation, incremental_val, weight=1.0, zero_prevention=0.0
):
"""Compute the streaming average of a function O(x) using a weight.
Parameters:
----------
current_value
the current value of the function that we want to take average of
previous_weight_and_average
tuple of (previous_weight, previous_average) where previous_weight is the
sum of weights and average is the current estimated average
expectation
the value of the expectation at the current timestep
incremental_val
tuple of (total, average) where total is the sum of weights and average is the current average
weight
weight of the current state
zero_prevention
small value to prevent division by zero
Returns:
----------
new total weight and streaming average
new (total weight, streaming average) tuple
"""
previous_weight, previous_average = previous_weight_and_average
current_weight = previous_weight + weight
current_average = jax.tree.map(
lambda x, avg: (previous_weight * avg + weight * x)
/ (current_weight + zero_prevention),
current_value,
previous_average,
)

flat_expectation, unravel_fn = ravel_pytree(expectation)
total, average = incremental_val
flat_average, _ = ravel_pytree(average)
average = (total * flat_average + weight * flat_expectation) / (
total + weight + zero_prevention
)
return current_weight, current_average
total += weight
incremental_val = (total, unravel_fn(average))
return incremental_val
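A quick check of the new ravel_pytree-based update: the running average keeps the pytree structure of the per-step expectation, and with unit weights it reduces to the ordinary running mean (values below are illustrative):

.. code::

    import jax
    import jax.numpy as jnp
    from blackjax.util import incremental_value_update

    values = [
        {"coefs": jnp.array([1.0, 2.0]), "log_scale": jnp.array(0.0)},
        {"coefs": jnp.array([3.0, 4.0]), "log_scale": jnp.array(2.0)},
    ]

    # (total weight, running average), starting from zero
    incremental_val = (0.0, jax.tree.map(jnp.zeros_like, values[0]))
    for v in values:
        incremental_val = incremental_value_update(v, incremental_val)

    total, average = incremental_val
    # after two unit-weight updates, the average is the plain mean
    assert jnp.allclose(average["coefs"], jnp.array([2.0, 3.0]))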
2 changes: 1 addition & 1 deletion tests/adaptation/test_adaptation.py
@@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters):
algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
_, _, infos = jax.vmap(
_, (_, infos) = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
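The new unpacking follows from the changed return shape: with the default transform (lambda state, info: (state, info)), the traced history is a (state_history, info_history) tuple. In sketch form, with illustrative names:

.. code::

    final_state, (state_history, info_history) = run_inference_algorithm(
        rng_key=rng_key,
        initial_state=initial_state,
        inference_algorithm=algorithm,
        num_steps=100,
    )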
28 changes: 18 additions & 10 deletions tests/mcmc/test_sampling.py
@@ -135,12 +135,12 @@ def run_mclmc(
sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
)

_, samples, _ = run_inference_algorithm(
_, samples = run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=sampling_alg,
num_steps=num_steps,
transform=lambda x: x.position,
transform=lambda state, info: state.position,
)

return samples
@@ -223,10 +223,11 @@ def test_mala(self):

mala = blackjax.mala(logposterior_fn, 1e-5)
state = mala.init({"coefs": 1.0, "log_scale": 1.0})
_, states, _ = run_inference_algorithm(
_, states = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=mala,
transform=lambda state, info: state.position,
num_steps=10_000,
)

@@ -375,11 +376,12 @@ def test_pathfinder_adaptation(
)
inference_algorithm = algorithm(logposterior_fn, **parameters)

_, states, _ = run_inference_algorithm(
_, states = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,
num_steps=num_sampling_steps,
transform=lambda state, info: state.position,
)

coefs_samples = states.position["coefs"]
@@ -418,11 +420,12 @@ def test_meads(self):
inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
_, states = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state.position,
num_steps=100,
)
)(chain_keys, last_states)
@@ -465,11 +468,12 @@ def test_chees(self, jitter_generator):
inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)

chain_keys = jax.random.split(inference_key, num_chains)
_, states, _ = jax.vmap(
_, states = jax.vmap(
lambda key, state: run_inference_algorithm(
rng_key=key,
initial_state=state,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state.position,
num_steps=100,
)
)(chain_keys, last_states)
@@ -494,10 +498,11 @@ def test_barker(self):
barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
state = barker.init({"coefs": 1.0, "log_scale": 1.0})

_, states, _ = run_inference_algorithm(
_, states = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=barker,
transform=lambda state, info: state.position,
num_steps=10_000,
)

@@ -679,10 +684,11 @@ def test_latent_gaussian(self):

initial_state = inference_algorithm.init(jnp.zeros((1,)))

_, states, _ = self.variant(
_, states = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state.position,
num_steps=self.sampling_steps,
),
)(rng_key=self.key, initial_state=initial_state)
@@ -724,10 +730,11 @@ def univariate_normal_test_case(
**kwargs,
):
inference_key, orbit_key = jax.random.split(rng_key)
_, states, _ = self.variant(
_, states = self.variant(
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state.position,
num_steps=num_sampling_steps,
**kwargs,
)
@@ -997,10 +1004,11 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
functools.partial(
run_inference_algorithm,
inference_algorithm=inference_algorithm,
transform=lambda state, info: state.position,
num_steps=2_000,
)
)
_, states, _ = inference_loop_multiple_chains(
_, states = inference_loop_multiple_chains(
rng_key=multi_chain_sample_key, initial_state=initial_states
)

2 changes: 1 addition & 1 deletion tests/test_benchmarks.py
@@ -48,7 +48,7 @@ def run_regression(algorithm, **parameters):
)
inference_algorithm = algorithm(logdensity_fn, **parameters)

_, states, _ = run_inference_algorithm(
_, (states, _) = run_inference_algorithm(
rng_key=inference_key,
initial_state=state,
inference_algorithm=inference_algorithm,