diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index 76a016242..312b4aeea 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -20,7 +20,7 @@
 from jax.flatten_util import ravel_pytree

 from blackjax.diagnostics import effective_sample_size
-from blackjax.util import pytree_size, streaming_average_update
+from blackjax.util import incremental_value_update, pytree_size


 class MCLMCAdaptationState(NamedTuple):
@@ -199,7 +199,7 @@ def step(iteration_state, weight_and_key):
             x = ravel_pytree(state.position)[0]
             # update the running average of x, x^2
-            streaming_avg = streaming_average_update(
+            streaming_avg = incremental_value_update(
                 current_value=jnp.array([x, jnp.square(x)]),
                 previous_weight_and_average=streaming_avg,
                 weight=(1 - mask) * success * params.step_size,
diff --git a/blackjax/util.py b/blackjax/util.py
index cdb9f4c91..d761189cf 100644
--- a/blackjax/util.py
+++ b/blackjax/util.py
@@ -3,7 +3,6 @@
 from functools import partial
 from typing import Callable, Union

-import jax
 import jax.numpy as jnp
 from jax import jit, lax
 from jax.flatten_util import ravel_pytree
@@ -149,9 +148,7 @@ def run_inference_algorithm(
     initial_state: ArrayLikeTree = None,
     initial_position: ArrayLikeTree = None,
     progress_bar: bool = False,
-    transform: Callable = lambda x: x,
-    return_state_history=True,
-    expectation: Callable = lambda x: x,
+    transform: Callable = lambda state, info: (state, info),
 ) -> tuple:
     """Wrapper to run an inference algorithm.

@@ -166,8 +163,7 @@
     initial_state
         The initial state of the inference algorithm.
     initial_position
-        The initial position of the inference algorithm. This is used when the initial
-        state is not provided.
+        The initial position of the inference algorithm. This is used when the initial state is not provided.
     inference_algorithm
         One of blackjax's sampling algorithms or variational inference algorithms.
     num_steps
@@ -175,95 +171,133 @@
     progress_bar
         Whether to display a progress bar.
     transform
-        A transformation of the trace of states to be returned. This is useful for
+        A transformation of the trace of states (and info) to be returned. This is useful for
        computing deterministic variables, or returning a subset of the states. By
        default, the states are returned as is.
-    expectation
-        A function that computes the expectation of the state. This is done
-        incrementally, so doesn't require storing all the states.
-    return_state_history
-        if False, `run_inference_algorithm` will only return an expectation of the value
-        of transform, and return that average instead of the full set of samples. This
-        is useful when memory is a bottleneck.

     Returns
     -------
-    If return_state_history is True:
     1. The final state.
-    2. The trace of the state.
+    2. The trace of transform(state, info).
     3. The trace of the info of the inference algorithm for diagnostics.
-    If return_state_history is False:
-    1. This is the expectation of state over the chain. Otherwise the final state.
-    2. The final state of the inference algorithm.
     """

     if initial_state is None and initial_position is None:
-        raise ValueError(
-            "Either `initial_state` or `initial_position` must be provided."
-        )
+        raise ValueError("Either initial_state or initial_position must be provided.")
     if initial_state is not None and initial_position is not None:
         raise ValueError(
-            "Only one of `initial_state` or `initial_position` must be provided."
+ "Only one of initial_state or initial_position must be provided." ) - if initial_state is None: - rng_key, init_key = split(rng_key, 2) + rng_key, init_key = split(rng_key, 2) + if initial_position is not None: initial_state = inference_algorithm.init(initial_position, init_key) keys = split(rng_key, num_steps) - def one_step(average_and_state, xs, return_state): + def one_step(state, xs): _, rng_key = xs - average, state = average_and_state state, info = inference_algorithm.step(rng_key, state) - average = streaming_average_update(expectation(transform(state)), average) - if return_state: - return (average, state), (transform(state), info) - else: - return (average, state), None - - one_step = jax.jit(partial(one_step, return_state=return_state_history)) + return state, transform(state, info) if progress_bar: one_step = progress_bar_scan(num_steps)(one_step) - xs = (jnp.arange(num_steps), keys) - ((_, average), final_state), history = lax.scan( - one_step, ((0, expectation(transform(initial_state))), initial_state), xs - ) + xs = jnp.arange(num_steps), keys + final_state, history = lax.scan(one_step, initial_state, xs) + return final_state, history - if not return_state_history: - return average, transform(final_state) - else: - state_history, info_history = history - return transform(final_state), state_history, info_history +def store_only_expectation_values( + sampling_algorithm, + state_transform=lambda x: x, + incremental_value_transform=lambda x: x, +): + """Takes a sampling algorithm and constructs from it a new sampling algorithm object. The new sampling algorithm has the same + kernel but only stores the streaming expectation values of some observables, not the full states; to save memory. + + It saves incremental_value_transform(E[state_transform(x)]) at each step i, where expectation is computed with samples up to i-th sample. + + Example: + + .. 
+
+        init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)
+        model = StandardNormal(2)
+        initial_position = model.sample_init(init_key)
+        initial_state = blackjax.mcmc.mclmc.init(
+            position=initial_position, logdensity_fn=model.logdensity_fn, rng_key=state_key
+        )
+        integrator_type = "mclachlan"
+        L = 1.0
+        step_size = 0.1
+        num_steps = 4

+        integrator = map_integrator_type_to_integrator['mclmc'][integrator_type]
+        state_transform = lambda state: state.position
+        sampling_alg = blackjax.mclmc(
+            model.logdensity_fn, L=L, step_size=step_size, integrator=integrator
+        )
+        memory_efficient_sampling_alg, transform = store_only_expectation_values(
+            sampling_algorithm=sampling_alg,
+            state_transform=state_transform)
+
+        initial_state = memory_efficient_sampling_alg.init(initial_state)
+
+        final_state, trace_at_every_step = run_inference_algorithm(
+            rng_key=run_key,
+            initial_state=initial_state,
+            inference_algorithm=memory_efficient_sampling_alg,
+            num_steps=num_steps,
+            transform=transform,
+            progress_bar=True,
+        )
+    """

-def streaming_average_update(
-    current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0
+    def init_fn(state):
+        averaging_state = (0.0, state_transform(state))
+        return (state, averaging_state)
+
+    def update_fn(rng_key, state_and_incremental_val):
+        state, averaging_state = state_and_incremental_val
+        state, info = sampling_algorithm.step(
+            rng_key, state
+        )  # update the state with the sampling algorithm
+        averaging_state = incremental_value_update(
+            state_transform(state), averaging_state
+        )  # fold the new value into the running average
+        return (state, averaging_state), info
+
+    def transform(state_and_incremental_val, info):
+        (state, (_, incremental_value)) = state_and_incremental_val
+        return incremental_value_transform(incremental_value), info
+
+    return SamplingAlgorithm(init_fn, update_fn), transform
+
+
+def incremental_value_update(
+    expectation, incremental_val, weight=1.0, zero_prevention=0.0
 ):
     """Compute the streaming average of a function O(x) using a weight.
     Parameters:
     ----------
-    current_value
-        the current value of the function that we want to take average of
-    previous_weight_and_average
-        tuple of (previous_weight, previous_average) where previous_weight is the
-        sum of weights and average is the current estimated average
+    expectation
+        the new value of the quantity being averaged, at the current timestep
+    incremental_val
+        tuple of (total, average), where total is the running sum of weights and
+        average is the current running average
     weight
         weight of the current state
     zero_prevention
         small value to prevent division by zero
     Returns:
     ----------
-    new total weight and streaming average
+    the updated (total, average) tuple
     """
-    previous_weight, previous_average = previous_weight_and_average
-    current_weight = previous_weight + weight
-    current_average = jax.tree.map(
-        lambda x, avg: (previous_weight * avg + weight * x)
-        / (current_weight + zero_prevention),
-        current_value,
-        previous_average,
+
+    flat_expectation, unravel_fn = ravel_pytree(expectation)
+    total, average = incremental_val
+    flat_average, _ = ravel_pytree(average)
+    average = (total * flat_average + weight * flat_expectation) / (
+        total + weight + zero_prevention
     )
-    return current_weight, current_average
+    total += weight
+    incremental_val = (total, unravel_fn(average))
+    return incremental_val
diff --git a/tests/adaptation/test_adaptation.py b/tests/adaptation/test_adaptation.py
index 68751bee8..4b34511be 100644
--- a/tests/adaptation/test_adaptation.py
+++ b/tests/adaptation/test_adaptation.py
@@ -90,7 +90,7 @@ def test_chees_adaptation(adaptation_filters):
     algorithm = blackjax.dynamic_hmc(logprob_fn, **parameters)

     chain_keys = jax.random.split(inference_key, num_chains)
-    _, _, infos = jax.vmap(
+    _, (_, infos) = jax.vmap(
         lambda key, state: run_inference_algorithm(
             rng_key=key,
             initial_state=state,
diff --git a/tests/mcmc/test_sampling.py b/tests/mcmc/test_sampling.py
index 18a07625b..1471196c1 100644
--- a/tests/mcmc/test_sampling.py
+++ b/tests/mcmc/test_sampling.py
@@ -135,12 +135,12 @@ def run_mclmc(
         sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
     )

-    _, samples, _ = run_inference_algorithm(
+    _, samples = run_inference_algorithm(
         rng_key=run_key,
         initial_state=blackjax_state_after_tuning,
         inference_algorithm=sampling_alg,
         num_steps=num_steps,
-        transform=lambda x: x.position,
+        transform=lambda state, info: state.position,
     )

     return samples
@@ -223,10 +223,11 @@ def test_mala(self):
         mala = blackjax.mala(logposterior_fn, 1e-5)
         state = mala.init({"coefs": 1.0, "log_scale": 1.0})

-        _, states, _ = run_inference_algorithm(
+        _, states = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=mala,
+            transform=lambda state, info: state.position,
             num_steps=10_000,
         )

@@ -375,11 +376,12 @@ def test_pathfinder_adaptation(
         )
         inference_algorithm = algorithm(logposterior_fn, **parameters)

-        _, states, _ = run_inference_algorithm(
+        _, states = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=inference_algorithm,
             num_steps=num_sampling_steps,
+            transform=lambda state, info: state.position,
         )

         coefs_samples = states.position["coefs"]
@@ -418,11 +420,12 @@ def test_meads(self):
         inference_algorithm = blackjax.ghmc(logposterior_fn, **parameters)

         chain_keys = jax.random.split(inference_key, num_chains)
-        _, states, _ = jax.vmap(
+        _, states = jax.vmap(
             lambda key, state: run_inference_algorithm(
                 rng_key=key,
                 initial_state=state,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=100,
             )
         )(chain_keys, last_states)
@@ -465,11 +468,12 @@ def test_chees(self, jitter_generator):
         inference_algorithm = blackjax.dynamic_hmc(logposterior_fn, **parameters)

         chain_keys = jax.random.split(inference_key, num_chains)
-        _, states, _ = jax.vmap(
+        _, states = jax.vmap(
             lambda key, state: run_inference_algorithm(
                 rng_key=key,
                 initial_state=state,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=100,
             )
         )(chain_keys, last_states)
@@ -494,10 +498,11 @@ def test_barker(self):
         barker = blackjax.barker_proposal(logposterior_fn, 1e-1)
         state = barker.init({"coefs": 1.0, "log_scale": 1.0})

-        _, states, _ = run_inference_algorithm(
+        _, states = run_inference_algorithm(
             rng_key=inference_key,
             initial_state=state,
             inference_algorithm=barker,
+            transform=lambda state, info: state.position,
             num_steps=10_000,
         )

@@ -679,10 +684,11 @@ def test_latent_gaussian(self):

         initial_state = inference_algorithm.init(jnp.zeros((1,)))

-        _, states, _ = self.variant(
+        _, states = self.variant(
             functools.partial(
                 run_inference_algorithm,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=self.sampling_steps,
             ),
         )(rng_key=self.key, initial_state=initial_state)
@@ -724,10 +730,11 @@ def univariate_normal_test_case(
         **kwargs,
     ):
         inference_key, orbit_key = jax.random.split(rng_key)
-        _, states, _ = self.variant(
+        _, states = self.variant(
             functools.partial(
                 run_inference_algorithm,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=num_sampling_steps,
                 **kwargs,
             )
@@ -997,10 +1004,11 @@ def test_mcse(self, algorithm, parameters, is_mass_matrix_diagonal):
             functools.partial(
                 run_inference_algorithm,
                 inference_algorithm=inference_algorithm,
+                transform=lambda state, info: state.position,
                 num_steps=2_000,
             )
         )
-        _, states, _ = inference_loop_multiple_chains(
+        _, states = inference_loop_multiple_chains(
             rng_key=multi_chain_sample_key, initial_state=initial_states
         )
diff --git a/tests/test_benchmarks.py b/tests/test_benchmarks.py
index c2295e7e2..2d108a48d 100644
--- a/tests/test_benchmarks.py
+++ b/tests/test_benchmarks.py
@@ -48,7 +48,7 @@ def run_regression(algorithm, **parameters):
     )
     inference_algorithm = algorithm(logdensity_fn, **parameters)

-    _, states, _ = run_inference_algorithm(
+    _, (states, _) = run_inference_algorithm(
         rng_key=inference_key,
         initial_state=state,
         inference_algorithm=inference_algorithm,
diff --git a/tests/test_util.py b/tests/test_util.py
index 1f03498dd..ba1fd6cbf 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -4,7 +4,7 @@
 from absl.testing import absltest, parameterized

 import blackjax
-from blackjax.util import run_inference_algorithm
+from blackjax.util import run_inference_algorithm, store_only_expectation_values


 class RunInferenceAlgorithmTest(chex.TestCase):
@@ -41,37 +41,49 @@ def logdensity_fn(x):
             10,
         )

-        init_key, run_key = jax.random.split(self.key, 2)
-
+        init_key, state_key, run_key = jax.random.split(jax.random.PRNGKey(0), 3)
         initial_state = blackjax.mcmc.mclmc.init(
-            position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
+            position=initial_position, logdensity_fn=logdensity_fn, rng_key=state_key
+        )
+        L = 1.0
+        step_size = 0.1
+        num_steps = 4
+
+        sampling_alg = blackjax.mclmc(
+            logdensity_fn,
+            L=L,
+            step_size=step_size,
         )

-        alg = blackjax.mclmc(logdensity_fn=logdensity_fn, L=0.5, step_size=0.1)
+        state_transform = lambda x: x.position

-        _, states, info = run_inference_algorithm(
+        _, samples = run_inference_algorithm(
             rng_key=run_key,
             initial_state=initial_state,
-            inference_algorithm=alg,
-            num_steps=50,
-            progress_bar=False,
-            expectation=lambda x: x,
-            transform=lambda x: x.position,
-            return_state_history=True,
+            inference_algorithm=sampling_alg,
+            num_steps=num_steps,
+            transform=lambda state, info: state_transform(state),
+            progress_bar=True,
+        )
+
+        print("average of steps (slow way):", samples.mean(axis=0))
+
+        memory_efficient_sampling_alg, transform = store_only_expectation_values(
+            sampling_algorithm=sampling_alg, state_transform=state_transform
         )

-        average, _ = run_inference_algorithm(
+        initial_state = memory_efficient_sampling_alg.init(initial_state)
+
+        final_state, trace_at_every_step = run_inference_algorithm(
             rng_key=run_key,
             initial_state=initial_state,
-            inference_algorithm=alg,
-            num_steps=50,
-            progress_bar=False,
-            expectation=lambda x: x,
-            transform=lambda x: x.position,
-            return_state_history=False,
+            inference_algorithm=memory_efficient_sampling_alg,
+            num_steps=num_steps,
+            transform=transform,
+            progress_bar=True,
         )

-        assert jnp.allclose(states.mean(axis=0), average)
+        assert jnp.allclose(trace_at_every_step[0][-1], samples.mean(axis=0))

     @parameterized.parameters([True, False])
     def test_compatible_with_initial_pos(self, progress_bar):
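
For reference, a minimal end-to-end sketch of the two calling conventions this patch introduces. This is illustrative only, not part of the patch: the toy standard-normal logdensity_fn, the step counts, and the parameter values are assumptions; the entry points are the same ones exercised in the tests above.

.. code::

    # Illustrative sketch (assumes this patch is applied); the target density
    # and all numeric parameters below are made up for demonstration.
    import jax
    import jax.numpy as jnp

    import blackjax
    from blackjax.util import run_inference_algorithm, store_only_expectation_values

    # Toy target: an (unnormalized) standard normal in 2 dimensions.
    logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))

    init_key, run_key = jax.random.split(jax.random.PRNGKey(0))
    initial_state = blackjax.mcmc.mclmc.init(
        position=jnp.ones(2), logdensity_fn=logdensity_fn, rng_key=init_key
    )
    sampling_alg = blackjax.mclmc(logdensity_fn, L=1.0, step_size=0.1)

    # 1) Full trace: transform now receives (state, info), and the runner
    #    returns (final_state, history) rather than (final_state, states, infos).
    _, positions = run_inference_algorithm(
        rng_key=run_key,
        initial_state=initial_state,
        inference_algorithm=sampling_alg,
        num_steps=100,
        transform=lambda state, info: state.position,
    )

    # 2) Memory-efficient run: store only the running expectation of the position.
    mem_alg, transform = store_only_expectation_values(
        sampling_algorithm=sampling_alg,
        state_transform=lambda state: state.position,
    )
    _, (expectations, infos) = run_inference_algorithm(
        rng_key=run_key,
        initial_state=mem_alg.init(initial_state),
        inference_algorithm=mem_alg,
        num_steps=100,
        transform=transform,
    )
    # The final running average agrees with the mean of the full trace
    # (up to floating-point accumulation error), as in the test above.
    assert jnp.allclose(expectations[-1], positions.mean(axis=0), atol=1e-5)

Passing info through transform is what lets the runner return a single history pytree instead of the old three-tuple: callers that only need positions can drop both the state and the info at trace time, while the default transform preserves the old (states, infos) contents as a pair.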