diff --git a/blackjax/adaptation/mclmc_adaptation.py b/blackjax/adaptation/mclmc_adaptation.py
index b0e3e4811..084a750ff 100644
--- a/blackjax/adaptation/mclmc_adaptation.py
+++ b/blackjax/adaptation/mclmc_adaptation.py
@@ -108,7 +108,7 @@ def mclmc_find_L_and_step_size(
     )(state, params, num_steps, part1_key)
 
     if frac_tune3 != 0:
-        state, params = make_adaptation_L(mclmc_kernel, frac=frac_tune3, Lfactor=0.4)(
+        state, params = make_adaptation_L(mclmc_kernel(params.std_mat), frac=frac_tune3, Lfactor=0.4)(
             state, params, num_steps, part2_key
         )
 
@@ -129,14 +129,14 @@ def make_L_step_size_adaptation(
 
     decay_rate = (num_effective_samples - 1.0) / (num_effective_samples + 1.0)
 
-    def predictor(previous_state, params, adaptive_state, rng_key):
+    def predictor(previous_state, params, adaptive_state, rng_key, std_mat):
         """does one step with the dynamics and updates the prediction for the optimal stepsize
         Designed for the unadjusted MCHMC"""
 
        time, x_average, step_size_max = adaptive_state
 
        # dynamics
-        next_state, info = kernel(
+        next_state, info = kernel(std_mat)(
            rng_key=rng_key,
            state=previous_state,
            L=params.L,
@@ -176,26 +176,30 @@ def predictor(previous_state, params, adaptive_state, rng_key):
         return state, params_new, adaptive_state, success
 
-    def step(iteration_state, weight_and_key):
-        """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
+    def make_step(std_mat):
+
+        def step(iteration_state, weight_and_key):
+            """does one step of the dynamics and updates the estimate of the posterior size and optimal stepsize"""
 
-        mask, rng_key = weight_and_key
-        state, params, adaptive_state, streaming_avg = iteration_state
+            mask, rng_key = weight_and_key
+            state, params, adaptive_state, streaming_avg = iteration_state
 
-        state, params, adaptive_state, success = predictor(
-            state, params, adaptive_state, rng_key
-        )
+            state, params, adaptive_state, success = predictor(
+                state, params, adaptive_state, rng_key, std_mat
+            )
 
-        # update the running average of x, x^2
-        streaming_avg = streaming_average(
-            O=lambda x: jnp.array([x, jnp.square(x)]),
-            x=ravel_pytree(state.position)[0],
-            streaming_avg=streaming_avg,
-            weight=(1 - mask) * success * params.step_size,
-            zero_prevention=mask,
-        )
+            # update the running average of x, x^2
+            streaming_avg = streaming_average(
+                O=lambda x: jnp.array([x, jnp.square(x)]),
+                x=ravel_pytree(state.position)[0],
+                streaming_avg=streaming_avg,
+                weight=(1 - mask) * success * params.step_size,
+                zero_prevention=mask,
+            )
 
-        return (state, params, adaptive_state, streaming_avg), None
+            return (state, params, adaptive_state, streaming_avg), None
+
+        return step
 
     def L_step_size_adaptation(state, params, num_steps, rng_key):
         num_steps1, num_steps2 = int(num_steps * frac_tune1), int(
@@ -211,7 +215,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
 
         # run the steps
         state, params, _, (_, average) = jax.lax.scan(
-            step,
+            make_step(1.0),
             init=(
                 state,
                 params,
@@ -240,7 +244,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
         steps = num_steps2 // 3  # we do some small number of steps
         keys = jax.random.split(final_key, steps)
         state, params, _, (_, average) = jax.lax.scan(
-            step,
+            make_step(std_mat),
             init=(
                 state,
                 params,
diff --git a/blackjax/benchmarks/mcmc/explore.py b/blackjax/benchmarks/mcmc/explore.py
index 94c89ba45..09be2ba7b 100644
--- a/blackjax/benchmarks/mcmc/explore.py
+++ b/blackjax/benchmarks/mcmc/explore.py
@@ -116,17 +116,30 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
 #     key=sample_key,
 #     transform=lambda x: x.position[:2],
 # )
-# print(samples.var(axis=0))
-m = IllConditionedGaussian(2, 5)
+# # print(samples.var(axis=0))
+# m = IllConditionedGaussian(10, 5)
+# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x: x.position,
+#     # std_mat=jnp.ones((10,))
+#     std_mat=jnp.sqrt(m.E_x2),
+#     L=2.6576319, step_size=3.40299)
+# print(m.E_x2, "var")
+
+# # sampler = 'mclmc'
+# # samplers[sampler]
+# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000 // m.ndims, favg=m.E_x2, fvar=m.Var_x2)
+
+# print(result)
+
+
 # m = StandardNormal(10)
 # sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x: x.position,
-#     std_mat=jnp.ones((2,))
-#     # std_mat=m.E_x2
-#     , L=1.7296296, step_size=0.814842)
-print(m.E_x2)
+#     std_mat=jnp.ones((10,)),
+#     L=2.6576319, step_size=3.40299)
+# # print(m.E_x2, "var")
 
-sampler = 'mclmc'
+# # sampler = 'mclmc'
+# # samplers[sampler]
+# result, bias, _ = benchmark_chains(m, sampler, n=5000, batch=1000 // m.ndims, favg=m.E_x2, fvar=m.Var_x2)
 
-result, bias, _ = benchmark_chains(m, samplers[sampler], n=5000, batch=1000 // m.ndims, favg=m.E_x2, fvar=m.Var_x2)
+# print(result)
 
-print(result)
\ No newline at end of file
diff --git a/blackjax/benchmarks/mcmc/sampling_algorithms.py b/blackjax/benchmarks/mcmc/sampling_algorithms.py
index 677fd10d4..03130ced5 100644
--- a/blackjax/benchmarks/mcmc/sampling_algorithms.py
+++ b/blackjax/benchmarks/mcmc/sampling_algorithms.py
@@ -47,11 +47,11 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, transform, key):
     )
 
     integrator = blackjax.mcmc.integrators.isokinetic_mclachlan
-
-    kernel = blackjax.mcmc.mclmc.build_kernel(
+    kernel = lambda std_mat: blackjax.mcmc.mclmc.build_kernel(
         logdensity_fn=logdensity_fn,
         integrator=integrator,
-        std_mat=jnp.ones((initial_position.shape[0],)),
+        # std_mat=jnp.ones((initial_position.shape[0],)),
+        std_mat=std_mat,
     )
 
     (
@@ -62,7 +62,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, transform, key):
         num_steps=num_steps,
         state=initial_state,
         rng_key=tune_key,
-        diagonal_preconditioning=False
+        diagonal_preconditioning=True
     )
 
     # jax.debug.print("params {x}", x=blackjax_mclmc_sampler_params)
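Note on the refactor above: the kernel is now a function of std_mat, but jax.lax.scan only passes (carry, element) to its body, so the preconditioner has to be closed over by a factory, which is why the patch replaces a fixed step with make_step(std_mat). A minimal, self-contained sketch of that closure-over-scan pattern (not part of the patch; hypothetical random-walk dynamics stand in for the MCLMC kernel):

import jax
import jax.numpy as jnp

def make_step(std_mat):
    def step(x, rng_key):
        # hypothetical dynamics standing in for the MCLMC kernel:
        # a diagonally preconditioned Gaussian random-walk move
        x = x + std_mat * jax.random.normal(rng_key, x.shape)
        return x, None
    return step

keys = jax.random.split(jax.random.PRNGKey(0), 100)
x0 = jnp.zeros(10)
# phase 1 uses an identity preconditioner (make_step(1.0) in the patch);
# phase 2 would pass the estimated std_mat instead
x_final, _ = jax.lax.scan(make_step(jnp.ones(10)), x0, keys)

Closing over std_mat, rather than threading it through the scan carry, also keeps it a static constant under jit, which is presumably why the patch wraps step in a factory instead of widening iteration_state.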
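For the second tuning phase, make_step(std_mat) receives a diagonal preconditioner derived from the streamed moments of the position: streaming_average in the patch accumulates weighted averages of [x, x**2], from which per-coordinate standard deviations follow as sqrt(E[x^2] - E[x]^2). A hedged sketch of that derivation, with a hypothetical update_moments helper standing in for blackjax's streaming_average:

import jax.numpy as jnp

def update_moments(moments, weight_sum, x, weight):
    # moments[0] ~ E[x], moments[1] ~ E[x^2], both weighted running means
    new_sum = weight_sum + weight
    new_moments = (weight_sum * moments + weight * jnp.array([x, jnp.square(x)])) / new_sum
    return new_moments, new_sum

x_samples = jnp.array([[1.0, 2.0], [3.0, -1.0], [0.5, 0.0]])
moments, weight_sum = jnp.zeros((2, 2)), 0.0
for x in x_samples:
    moments, weight_sum = update_moments(moments, weight_sum, x, 1.0)

# per-coordinate standard deviations: sqrt(E[x^2] - E[x]^2)
std_mat = jnp.sqrt(moments[1] - jnp.square(moments[0]))

This mirrors the commented benchmark in explore.py, which passes std_mat=jnp.sqrt(m.E_x2) for a centered target; the exact quantity the adaptation feeds into make_step(std_mat) is whatever the streamed average yields at the end of the first phase.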