Skip to content

Commit

Permalink
Fix several issues with preconditioning (thread std_mat through the MCLMC adaptation kernels)
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Mar 13, 2024
1 parent 61ec765 commit 30d031c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 43 deletions.
46 changes: 25 additions & 21 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def kernel(x):
)(state, params, num_steps, part1_key)

if frac_tune3 != 0:
state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
state, params = make_adaptation_L(mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4)(
state, params, num_steps, part2_key
)

Expand All @@ -153,14 +153,14 @@ def make_L_step_size_adaptation(

decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

def predictor(previous_state, params, adaptive_state, rng_key):
def predictor(previous_state, params, adaptive_state, rng_key, std_mat):
"""does one step with the dynamics and updates the prediction for the optimal stepsize
Designed for the unadjusted MCHMC"""

time, x_average, step_size_max = adaptive_state

# dynamics
next_state, info = kernel(
next_state, info = kernel(std_mat)(
rng_key=rng_key,
state=previous_state,
L=params.L,
Expand Down Expand Up @@ -200,26 +200,30 @@ def predictor(previous_state, params, adaptive_state, rng_key):
return state, params_new, adaptive_state, success


def step(iteration_state, weight_and_key):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
def make_step(std_mat):

def step(iteration_state, weight_and_key):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""

mask, rng_key = weight_and_key
state, params, adaptive_state, streaming_avg = iteration_state
mask, rng_key = weight_and_key
state, params, adaptive_state, streaming_avg = iteration_state

state, params, adaptive_state, success = predictor(
state, params, adaptive_state, rng_key
)
state, params, adaptive_state, success = predictor(
state, params, adaptive_state, rng_key, std_mat
)

# update the running average of x, x^2
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
streaming_avg=streaming_avg,
weight=(1-mask)*success*params.step_size,
zero_prevention=mask,
)
# update the running average of x, x^2
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
streaming_avg=streaming_avg,
weight=(1-mask)*success*params.step_size,
zero_prevention=mask,
)

return (state, params, adaptive_state, streaming_avg), None
return (state, params, adaptive_state, streaming_avg), None

return step

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
Expand All @@ -235,7 +239,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

# run the steps
state, params, _, (_, average) = jax.lax.scan(
step,
make_step(1.),
init=(
state,
params,
Expand Down Expand Up @@ -264,7 +268,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
steps = num_steps2 // 3 #we do some small number of steps
keys = jax.random.split(final_key, steps)
state, params, _, (_, average) = jax.lax.scan(
step,
make_step(std_mat),
init=(
state,
params,
Expand Down
40 changes: 20 additions & 20 deletions blackjax/benchmarks/mcmc/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,30 +116,30 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
# key=sample_key,
# transform=lambda x: x.position[:2],
# )
# print(samples.var(axis=0))
m = IllConditionedGaussian(10, 5)
sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
# std_mat=jnp.ones((10,))
std_mat=jnp.sqrt(m.E_x2)
, L=2.6576319, step_size=3.40299)
print(m.E_x2, "var")
# # print(samples.var(axis=0))
# m = IllConditionedGaussian(10, 5)
# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
# # std_mat=jnp.ones((10,))
# std_mat=jnp.sqrt(m.E_x2)
# , L=2.6576319, step_size=3.40299)
# print(m.E_x2, "var")

# sampler = 'mclmc'
# samplers[sampler]
result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
# # sampler = 'mclmc'
# # samplers[sampler]
# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)

print(result)
# print(result)


m = StandardNormal(10)
sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
std_mat=jnp.ones((10,))
, L=2.6576319, step_size=3.40299)
# print(m.E_x2, "var")
# m = StandardNormal(10)
# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
# std_mat=jnp.ones((10,))
# , L=2.6576319, step_size=3.40299)
# # print(m.E_x2, "var")

# sampler = 'mclmc'
# samplers[sampler]
result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
# # sampler = 'mclmc'
# # samplers[sampler]
# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)

print(result)
# print(result)

5 changes: 3 additions & 2 deletions blackjax/benchmarks/mcmc/sampling_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,11 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key):
position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
)

kernel = blackjax.mcmc.mclmc.build_kernel(
kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
std_mat=jnp.ones((initial_position.shape[0],)),
# std_mat=jnp.ones((initial_position.shape[0],)),
std_mat=std_mat,
)

(
Expand Down

0 comments on commit 30d031c

Please sign in to comment.