From 5ee0b8a188e131643316249fdaa4c967d35bb31e Mon Sep 17 00:00:00 2001 From: Reuben Date: Thu, 27 Feb 2025 07:03:09 -0500 Subject: [PATCH] Update blackjax/adaptation/step_size.py Co-authored-by: Junpeng Lao --- blackjax/adaptation/step_size.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/blackjax/adaptation/step_size.py b/blackjax/adaptation/step_size.py index 94c634ce3..61ceed592 100644 --- a/blackjax/adaptation/step_size.py +++ b/blackjax/adaptation/step_size.py @@ -270,12 +270,14 @@ def update(state, exp_x, acc_rate_new): x = jnp.log(exp_x) def on_true(bounds): - bounds0 = jnp.max(jnp.array([bounds[0], x])) - return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift + lower, upper = bounds + lower = jnp.max(jnp.array([lower, x])) + return jnp.array([lower, upper]), lower + reduce_shift def on_false(bounds): - bounds1 = jnp.min(jnp.array([bounds[1], x])) - return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift + lower, upper = bounds + upper = jnp.min(jnp.array([upper, x])) + return jnp.array([lower, upper]), upper - reduce_shift bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds)