diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py index b1b012c70..76a016242 100644 --- a/blackjax/adaptation/mclmc_adaptation.py +++ b/blackjax/adaptation/mclmc_adaptation.py @@ -200,8 +200,8 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 streaming_avg = streaming_average_update( - expectation=jnp.array([x, jnp.square(x)]), - streaming_avg=streaming_avg, + current_value=jnp.array([x, jnp.square(x)]), + previous_weight_and_average=streaming_avg, weight=(1 - mask) * success * params.step_size, zero_prevention=mask, ) diff --git a/blackjax/util.py b/blackjax/util.py index b9b7a250c..cdb9f4c91 100644 --- a/blackjax/util.py +++ b/blackjax/util.py @@ -240,31 +240,30 @@ def one_step(average_and_state, xs, return_state): def streaming_average_update( - expectation, streaming_avg, weight=1.0, zero_prevention=0.0 + current_value, previous_weight_and_average, weight=1.0, zero_prevention=0.0 ): """Compute the streaming average of a function O(x) using a weight. Parameters: ---------- - expectation - the value of the expectation at the current timestep - streaming_avg - tuple of (total, average) where total is the sum of weights and average is - the current average + current_value + the current value of the function that we want to take average of + previous_weight_and_average + tuple of (previous_weight, previous_average) where previous_weight is the + sum of weights and average is the current estimated average weight weight of the current state zero_prevention small value to prevent division by zero Returns: ---------- - new streaming average + new total weight and streaming average """ - - flat_expectation, unravel_fn = ravel_pytree(expectation) - total, average = streaming_avg - flat_average, _ = ravel_pytree(average) - average = (total * flat_average + weight * flat_expectation) / ( - total + weight + zero_prevention + previous_weight, previous_average = previous_weight_and_average + current_weight = previous_weight + weight + current_average = jax.tree.map( + lambda x, avg: (previous_weight * avg + weight * x) + / (current_weight + zero_prevention), + current_value, + previous_average, ) - total += weight - streaming_avg = (total, unravel_fn(average)) - return streaming_avg + return current_weight, current_average