Skip to content

Commit

Permalink
Merge branch 'emaus' of github.com:reubenharry/blackjax into emaus
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Feb 27, 2025
2 parents 2b33625 + 5ee0b8a commit 4aec6e9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions blackjax/adaptation/step_size.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,14 @@ def update(state, exp_x, acc_rate_new):
x = jnp.log(exp_x)

def on_true(bounds):
bounds0 = jnp.max(jnp.array([bounds[0], x]))
return jnp.array([bounds0, bounds[1]]), bounds0 + reduce_shift
lower, upper = bounds
lower = jnp.max(jnp.array([lower, x]))
return jnp.array([lower, upper]), lower + reduce_shift

def on_false(bounds):
bounds1 = jnp.min(jnp.array([bounds[1], x]))
return jnp.array([bounds[0], bounds1]), bounds1 - reduce_shift
lower, upper = bounds
upper = jnp.min(jnp.array([upper, x]))
return jnp.array([lower, upper]), upper - reduce_shift

bounds_new, x_new = jax.lax.cond(acc_high, on_true, on_false, bounds)

Expand Down

0 comments on commit 4aec6e9

Please sign in to comment.