From 61ec765f66e7aeb8c0bbce10e1c7ff106ac4194b Mon Sep 17 00:00:00 2001
From: =
Date: Mon, 11 Mar 2024 17:19:06 +0100
Subject: [PATCH 1/2] MERGE

---
 blackjax/benchmarks/mcmc/explore.py | 33 ++++++++++++++++++++---------
 1 file changed, 23 insertions(+), 10 deletions(-)

diff --git a/blackjax/benchmarks/mcmc/explore.py b/blackjax/benchmarks/mcmc/explore.py
index 94c89ba45..47f33a7bb 100644
--- a/blackjax/benchmarks/mcmc/explore.py
+++ b/blackjax/benchmarks/mcmc/explore.py
@@ -117,16 +117,29 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
 #     transform=lambda x: x.position[:2],
 # )
 # print(samples.var(axis=0))
-m = IllConditionedGaussian(2, 5)
-# m = StandardNormal(10)
-# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
-#     std_mat=jnp.ones((2,))
-#     # std_mat=m.E_x2
-#     , L=1.7296296, step_size=0.814842)
-print(m.E_x2)
+m = IllConditionedGaussian(10, 5)
+sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
+    # std_mat=jnp.ones((10,))
+    std_mat=jnp.sqrt(m.E_x2)
+    , L=2.6576319, step_size=3.40299)
+print(m.E_x2, "var")
 
-sampler = 'mclmc'
+# sampler = 'mclmc'
+# samplers[sampler]
+result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
 
-result, bias, _ = benchmark_chains(m, samplers[sampler], n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
+print(result)
+
+
+m = StandardNormal(10)
+sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
+    std_mat=jnp.ones((10,))
+    , L=2.6576319, step_size=3.40299)
+# print(m.E_x2, "var")
+
+# sampler = 'mclmc'
+# samplers[sampler]
+result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
+
+print(result)
 
-print(result)
\ No newline at end of file

From 30d031c7e21640f5acc65b6414a14709483f57a7 Mon Sep 17 00:00:00 2001
From: =
Date: Wed, 13 Mar 2024 22:28:29 +0100
Subject: [PATCH 2/2] FIX SOME ISSUES WITH PRECONDITIONING

---
 blackjax/adaptation/mclmc_adaptation.py        | 46 ++++++++++---------
 blackjax/benchmarks/mcmc/explore.py            | 40 ++++++++--------
 .../benchmarks/mcmc/sampling_algorithms.py     |  5 +-
 3 files changed, 48 insertions(+), 43 deletions(-)

diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index 4e15dd454..e7285f0b0 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -132,7 +132,7 @@ def kernel(x):
     )(state, params, num_steps, part1_key)
 
     if frac_tune3 != 0:
-        state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
+        state, params = make_adaptation_L(mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4)(
             state, params, num_steps, part2_key
         )
 
@@ -153,14 +153,14 @@ def make_L_step_size_adaptation(
 
     decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)
 
-    def predictor(previous_state, params, adaptive_state, rng_key):
+    def predictor(previous_state, params, adaptive_state, rng_key, std_mat):
         """does one step with the dynamics and updates the prediction for the optimal stepsize
         Designed for the unadjusted MCHMC"""
 
         time, x_average, step_size_max = adaptive_state
 
         # dynamics
-        next_state, info = kernel(
+        next_state, info = kernel(std_mat)(
             rng_key=rng_key,
             state=previous_state,
             L=params.L,
@@ -200,26 +200,30 @@ def predictor(previous_state, params, adaptive_state, rng_key):
 
         return state, params_new, adaptive_state, success
 
-    def step(iteration_state, weight_and_key):
-        """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
+    def make_step(std_mat):
+
+        def step(iteration_state, weight_and_key):
+            """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
 
-        mask, rng_key = weight_and_key
-        state, params, adaptive_state, streaming_avg = iteration_state
+            mask, rng_key = weight_and_key
+            state, params, adaptive_state, streaming_avg = iteration_state
 
-        state, params, adaptive_state, success = predictor(
-            state, params, adaptive_state, rng_key
-        )
+            state, params, adaptive_state, success = predictor(
+                state, params, adaptive_state, rng_key, std_mat
+            )
 
-        # update the running average of x, x^2
-        streaming_avg = streaming_average(
-            O=lambda x: jnp.array([x, jnp.square(x)]),
-            x=ravel_pytree(state.position)[0],
-            streaming_avg=streaming_avg,
-            weight=(1-mask)*success*params.step_size,
-            zero_prevention=mask,
-        )
+            # update the running average of x, x^2
+            streaming_avg = streaming_average(
+                O=lambda x: jnp.array([x, jnp.square(x)]),
+                x=ravel_pytree(state.position)[0],
+                streaming_avg=streaming_avg,
+                weight=(1-mask)*success*params.step_size,
+                zero_prevention=mask,
+            )
 
-        return (state, params, adaptive_state, streaming_avg), None
+            return (state, params, adaptive_state, streaming_avg), None
+
+        return step
 
     def L_step_size_adaptation(state, params, num_steps, rng_key):
         num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
@@ -235,7 +239,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
 
         # run the steps
         state, params, _, (_, average) = jax.lax.scan(
-            step,
+            make_step(1.),
             init=(
                 state,
                 params,
@@ -264,7 +268,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
         steps = num_steps2 // 3 #we do some small number of steps
         keys = jax.random.split(final_key, steps)
         state, params, _, (_, average) = jax.lax.scan(
-            step,
+            make_step(std_mat),
            init=(
                 state,
                 params,
diff --git a/blackjax/benchmarks/mcmc/explore.py b/blackjax/benchmarks/mcmc/explore.py
index 47f33a7bb..09be2ba7b 100644
--- a/blackjax/benchmarks/mcmc/explore.py
+++ b/blackjax/benchmarks/mcmc/explore.py
@@ -116,30 +116,30 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
 #     key=sample_key,
 #     transform=lambda x: x.position[:2],
 # )
-# print(samples.var(axis=0))
-m = IllConditionedGaussian(10, 5)
-sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
-    # std_mat=jnp.ones((10,))
-    std_mat=jnp.sqrt(m.E_x2)
-    , L=2.6576319, step_size=3.40299)
-print(m.E_x2, "var")
+# # print(samples.var(axis=0))
+# m = IllConditionedGaussian(10, 5)
+# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
+#     # std_mat=jnp.ones((10,))
+#     std_mat=jnp.sqrt(m.E_x2)
+#     , L=2.6576319, step_size=3.40299)
+# print(m.E_x2, "var")
 
-# sampler = 'mclmc'
-# samplers[sampler]
-result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
+# # sampler = 'mclmc'
+# # samplers[sampler]
+# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
 
-print(result)
+# print(result)
 
 
-m = StandardNormal(10)
-sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
-    std_mat=jnp.ones((10,))
-    , L=2.6576319, step_size=3.40299)
-# print(m.E_x2, "var")
+# m = StandardNormal(10)
+# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
+#     std_mat=jnp.ones((10,))
+#     , L=2.6576319, step_size=3.40299)
+# # print(m.E_x2, "var")
 
-# sampler = 'mclmc'
-# samplers[sampler]
-result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
+# # sampler = 'mclmc'
+# # samplers[sampler]
+# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)
 
-print(result)
+# print(result)
 
diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py
index a62cf8cb6..dcc2faea1 100644
--- a/blackjax/benchmarks/mcmc/sampling_algorithms.py
+++ b/blackjax/benchmarks/mcmc/sampling_algorithms.py
@@ -44,10 +44,11 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key):
         position=initial_position, logdensity_fn=logdensity_fn, rng_key=init_key
     )
 
-    kernel = blackjax.mcmc.mclmc.build_kernel(
+    kernel = lambda std_mat : blackjax.mcmc.mclmc.build_kernel(
         logdensity_fn=logdensity_fn,
         integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
-        std_mat=jnp.ones((initial_position.shape[0],)),
+        # std_mat=jnp.ones((initial_position.shape[0],)),
+        std_mat=std_mat,
     )
 
     (
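
Note on the pattern in these two patches: std_mat is a diagonal preconditioner (a per-coordinate scale), and the fix is to pass it through a kernel factory (kernel = lambda std_mat: build_kernel(..., std_mat=std_mat)) so the tuning loop can first run with the identity scale, estimate E[x^2] with a running average, and then rebuild the kernel with std_mat = sqrt(E[x^2]). The sketch below illustrates only that flow; it is not the blackjax API. The names make_kernel and tune_std_mat, and the toy random-walk dynamics standing in for the isokinetic integrator, are assumptions made for the example.

import jax
import jax.numpy as jnp


def make_kernel(std_mat):
    # Kernel factory: closes over the preconditioner, mirroring the patch's
    # kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(..., std_mat=std_mat).
    def kernel(rng_key, position, step_size):
        # Toy dynamics (hypothetical): a preconditioned Gaussian move.
        return position + step_size * std_mat * jax.random.normal(rng_key, position.shape)

    return kernel


def tune_std_mat(rng_key, position, num_steps, step_size=0.5):
    # Phase 1: run with the identity preconditioner while accumulating E[x^2],
    # analogous to the adaptation scanning with make_step(1.).
    kernel = make_kernel(jnp.ones_like(position))

    def step(carry, key):
        pos, x2_sum = carry
        pos = kernel(key, pos, step_size)
        return (pos, x2_sum + jnp.square(pos)), None

    keys = jax.random.split(rng_key, num_steps)
    (position, x2_sum), _ = jax.lax.scan(step, (position, jnp.zeros_like(position)), keys)

    # Phase 2: rebuild the kernel with std_mat = sqrt(E[x^2]), the same choice
    # the benchmark makes with std_mat=jnp.sqrt(m.E_x2).
    return make_kernel(jnp.sqrt(x2_sum / num_steps)), position


# Usage sketch:
# kernel, position = tune_std_mat(jax.random.PRNGKey(0), jnp.zeros(10), num_steps=1000)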