Skip to content

Commit

Permalink
RENAME STREAMING_AVERAGE_UPDATE ARGS IN ADJUSTED MCLMC ADAPTATION
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Jun 5, 2024
1 parent ef40045 commit 90be1be
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions blackjax/adaptation/adjusted_mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def step(iteration_state, weight_and_key):
previous_state,
params,
(adaptive_state, step_size_max),
streaming_avg,
previous_weight_and_average,
) = iteration_state

avg_num_integration_steps = params.L / params.step_size
Expand Down Expand Up @@ -174,9 +174,9 @@ def step(iteration_state, weight_and_key):

x = ravel_pytree(state.position)[0]
# update the running average of x, x^2
streaming_avg = streaming_average_update(
expectation=jnp.array([x, jnp.square(x)]),
streaming_avg=streaming_avg,
previous_weight_and_average = streaming_average_update(
current_value=jnp.array([x, jnp.square(x)]),
previous_weight_and_average=previous_weight_and_average,
weight=(1 - mask) * success * step_size,
zero_prevention=mask,
)
Expand All @@ -193,7 +193,7 @@ def step(iteration_state, weight_and_key):
+ (1 - mask) * params.L,
)

return (state, params, (adaptive_state, step_size_max), streaming_avg), (
return (state, params, (adaptive_state, step_size_max), previous_weight_and_average), (
info,
params,
)
Expand Down

0 comments on commit 90be1be

Please sign in to comment.