diff --git a/blackjax/adaptation/adjusted_mclmc_adaptation.py b/blackjax/adaptation/adjusted_mclmc_adaptation.py index 5003aa523..e4a6f1d5b 100644 --- a/blackjax/adaptation/adjusted_mclmc_adaptation.py +++ b/blackjax/adaptation/adjusted_mclmc_adaptation.py @@ -131,7 +131,7 @@ def step(iteration_state, weight_and_key): previous_state, params, (adaptive_state, step_size_max), - streaming_avg, + previous_weight_and_average, ) = iteration_state avg_num_integration_steps = params.L / params.step_size @@ -174,9 +174,9 @@ def step(iteration_state, weight_and_key): x = ravel_pytree(state.position)[0] # update the running average of x, x^2 - streaming_avg = streaming_average_update( - expectation=jnp.array([x, jnp.square(x)]), - streaming_avg=streaming_avg, + previous_weight_and_average = streaming_average_update( + current_value=jnp.array([x, jnp.square(x)]), + previous_weight_and_average=previous_weight_and_average, weight=(1 - mask) * success * step_size, zero_prevention=mask, ) @@ -193,7 +193,7 @@ def step(iteration_state, weight_and_key): + (1 - mask) * params.L, ) - return (state, params, (adaptive_state, step_size_max), streaming_avg), ( + return (state, params, (adaptive_state, step_size_max), previous_weight_and_average), ( info, params, )