Skip to content

Commit

Permalink
FIX PRECONDITIONING
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Mar 13, 2024
2 parents 142425c + 30d031c commit 4bc4b54
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 34 deletions.
46 changes: 25 additions & 21 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def mclmc_find_L_and_step_size(
)(state, params, num_steps, part1_key)

if frac_tune3 != 0:
state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
state, params = make_adaptation_L(mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4)(
state, params, num_steps, part2_key
)

Expand All @@ -129,14 +129,14 @@ def make_L_step_size_adaptation(

decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)

def predictor(previous_state, params, adaptive_state, rng_key):
def predictor(previous_state, params, adaptive_state, rng_key, std_mat):
"""does one step with the dynamics and updates the prediction for the optimal stepsize
Designed for the unadjusted MCHMC"""

time, x_average, step_size_max = adaptive_state

# dynamics
next_state, info = kernel(
next_state, info = kernel(std_mat)(
rng_key=rng_key,
state=previous_state,
L=params.L,
Expand Down Expand Up @@ -176,26 +176,30 @@ def predictor(previous_state, params, adaptive_state, rng_key):
return state, params_new, adaptive_state, success


def step(iteration_state, weight_and_key):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
def make_step(std_mat):

def step(iteration_state, weight_and_key):
"""does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""

mask, rng_key = weight_and_key
state, params, adaptive_state, streaming_avg = iteration_state
mask, rng_key = weight_and_key
state, params, adaptive_state, streaming_avg = iteration_state

state, params, adaptive_state, success = predictor(
state, params, adaptive_state, rng_key
)
state, params, adaptive_state, success = predictor(
state, params, adaptive_state, rng_key, std_mat
)

# update the running average of x, x^2
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
streaming_avg=streaming_avg,
weight=(1-mask)*success*params.step_size,
zero_prevention=mask,
)
# update the running average of x, x^2
streaming_avg = streaming_average(
O=lambda x: jnp.array([x, jnp.square(x)]),
x=ravel_pytree(state.position)[0],
streaming_avg=streaming_avg,
weight=(1-mask)*success*params.step_size,
zero_prevention=mask,
)

return (state, params, adaptive_state, streaming_avg), None
return (state, params, adaptive_state, streaming_avg), None

return step

def L_step_size_adaptation(state, params, num_steps, rng_key):
num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
Expand All @@ -211,7 +215,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):

# run the steps
state, params, _, (_, average) = jax.lax.scan(
step,
make_step(1.),
init=(
state,
params,
Expand Down Expand Up @@ -240,7 +244,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
steps = num_steps2 // 3 #we do some small number of steps
keys = jax.random.split(final_key, steps)
state, params, _, (_, average) = jax.lax.scan(
step,
make_step(std_mat),
init=(
state,
params,
Expand Down
31 changes: 22 additions & 9 deletions blackjax/benchmarks/mcmc/explore.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,17 +116,30 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
# key=sample_key,
# transform=lambda x: x.position[:2],
# )
# print(samples.var(axis=0))
m = IllConditionedGaussian(2, 5)
# # print(samples.var(axis=0))
# m = IllConditionedGaussian(10, 5)
# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
# # std_mat=jnp.ones((10,))
# std_mat=jnp.sqrt(m.E_x2)
# , L=2.6576319, step_size=3.40299)
# print(m.E_x2, "var")

# # sampler = 'mclmc'
# # samplers[sampler]
# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)

# print(result)


# m = StandardNormal(10)
# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
# std_mat=jnp.ones((2,))
# # std_mat=m.E_x2
# , L=1.7296296, step_size=0.814842)
print(m.E_x2)
# std_mat=jnp.ones((10,))
# , L=2.6576319, step_size=3.40299)
# # print(m.E_x2, "var")

sampler = 'mclmc'
# # sampler = 'mclmc'
# # samplers[sampler]
# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)

result, bias, _ = benchmark_chains(m, samplers[sampler], n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
# print(result)

print(result)
8 changes: 4 additions & 4 deletions blackjax/benchmarks/mcmc/sampling_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, transform, key):
)

integrator = blackjax.mcmc.integrators.isokinetic_mclachlan

kernel = blackjax.mcmc.mclmc.build_kernel(
kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=integrator,
std_mat=jnp.ones((initial_position.shape[0],)),
# std_mat=jnp.ones((initial_position.shape[0],)),
std_mat=std_mat,
)

(
Expand All @@ -62,7 +62,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, transform, key):
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
diagonal_preconditioning=False
diagonal_preconditioning=True
)

# jax.debug.print("params {x}", x=blackjax_mclmc_sampler_params)
Expand Down

0 comments on commit 4bc4b54

Please sign in to comment.