Skip to content

Commit

Permalink
new integrator
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Mar 29, 2024
1 parent d711ac6 commit 3554d61
Show file tree
Hide file tree
Showing 9 changed files with 195 additions and 40 deletions.
24 changes: 22 additions & 2 deletions blackjax/adaptation/mclmc_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from blackjax.adaptation.step_size import DualAveragingAdaptationState, dual_averaging_adaptation

from blackjax.diagnostics import effective_sample_size
from blackjax.mcmc.integrators import isokinetic_leapfrog, isokinetic_mclachlan, isokinetic_omelyan, isokinetic_yoshida, mclachlan_coefficients, omelyan_coefficients, velocity_verlet_coefficients, yoshida_coefficients
from blackjax.mcmc.mhmclmc import rescale
from blackjax.util import pytree_size

Expand All @@ -41,6 +42,19 @@ class MCLMCAdaptationState(NamedTuple):
step_size: float
std_mat : float

def integrator_order(c):
if c==velocity_verlet_coefficients: return 2
if c==mclachlan_coefficients: return 2
if c==yoshida_coefficients: return 4
if c==omelyan_coefficients: return 4


else: raise Exception(c)



target_acceptance_rate_of_order = {2 : 0.65, 4: 0.8}

def streaming_average(O, x, streaming_avg, weight, zero_prevention):
"""streaming average of f(x)"""
total, average = streaming_avg
Expand Down Expand Up @@ -323,6 +337,7 @@ def mhmclmc_find_L_and_step_size(
num_steps,
state,
rng_key,
target,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
Expand All @@ -341,6 +356,8 @@ def mhmclmc_find_L_and_step_size(
The initial state of the MCMC algorithm.
rng_key
The random number generator key.
target
The target acceptance rate for the step size adaptation.
frac_tune1
The fraction of tuning for the first step of the adaptation.
frac_tune2
Expand Down Expand Up @@ -372,6 +389,7 @@ def mhmclmc_find_L_and_step_size(
dim=dim,
frac_tune1=frac_tune1,
frac_tune2=frac_tune2,
target=target
)(state, params, num_steps, part1_key)

if frac_tune3 != 0:
Expand All @@ -387,6 +405,7 @@ def mhmclmc_find_L_and_step_size(
dim=dim,
frac_tune1=frac_tune1,
frac_tune2=0,
target=target,
fix_L_first_da=True,
)(state, params, num_steps, part2_key2)

Expand All @@ -398,6 +417,7 @@ def mhmclmc_make_L_step_size_adaptation(
dim,
frac_tune1,
frac_tune2,
target,
fix_L_first_da=False,
):
"""Adapts the stepsize and L of the MCLMC kernel. Designed for the unadjusted MCLMC"""
Expand Down Expand Up @@ -538,7 +558,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):
# determine which steps to ignore in the streaming average
mask = 1 - jnp.concatenate((jnp.zeros(num_steps1), jnp.ones(num_steps2)))

initial_da, update_da, final_da = dual_averaging_adaptation(target=0.65)
initial_da, update_da, final_da = dual_averaging_adaptation(target=target)

((state, params, (dual_avg_state, step_size_max), (_, average)), info) = step_size_adaptation(mask, state, params, L_step_size_adaptation_keys_pass1, fix_L=fix_L_first_da, initial_da=initial_da, update_da=update_da)
params = params._replace(step_size=final_da(dual_avg_state)) # TODO: put back
Expand Down Expand Up @@ -570,7 +590,7 @@ def L_step_size_adaptation(state, params, num_steps, rng_key):


# jax.debug.print("{x} params before second round",x=(params))
initial_da, update_da, final_da = dual_averaging_adaptation(target=0.65)
initial_da, update_da, final_da = dual_averaging_adaptation(target=target)
((state, params, (dual_avg_state, step_size_max), (_, average)), info) = step_size_adaptation(mask, state, params, L_step_size_adaptation_keys_pass2, fix_L=True, update_da=update_da, initial_da=initial_da)
params = params._replace(step_size=final_da(dual_avg_state))
# jax.debug.print("{x}",x=("mean acceptance rate", jnp.mean(info.acceptance_rate,)))
Expand Down
3 changes: 2 additions & 1 deletion blackjax/adaptation/window_adaptation.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def window_adaptation(
initial_step_size: float = 1.0,
target_acceptance_rate: float = 0.80,
progress_bar: bool = False,
integrator = mcmc.integrators.velocity_verlet,
**extra_parameters,
) -> AdaptationAlgorithm:
"""Adapt the value of the inverse mass matrix and step size parameters of
Expand Down Expand Up @@ -289,7 +290,7 @@ def window_adaptation(
"""

mcmc_kernel = algorithm.build_kernel()
mcmc_kernel = algorithm.build_kernel(integrator)

adapt_init, adapt_step, adapt_final = base(
is_mass_matrix_diagonal,
Expand Down
55 changes: 40 additions & 15 deletions blackjax/benchmarks/mcmc/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict
from functools import partial
import math
import os
from statistics import mean, median
Expand All @@ -14,6 +16,7 @@

from blackjax.benchmarks.mcmc.sampling_algorithms import samplers
from blackjax.benchmarks.mcmc.inference_models import models
from blackjax.mcmc.integrators import generate_euclidean_integrator, generate_isokinetic_integrator, isokinetic_mclachlan, mclachlan_coefficients, omelyan_coefficients, velocity_verlet, velocity_verlet_coefficients, yoshida_coefficients



Expand All @@ -22,7 +25,7 @@ def get_num_latents(target):
# return int(sum(map(np.prod, list(jax.tree_flatten(target.event_shape)[0]))))


def err(f_true, var_f, contract = jnp.max):
def err(f_true, var_f, contract):
"""Computes the error b^2 = (f - f_true)^2 / var_f
Args:
f: E_sampler[f(x)], can be a vector
Expand Down Expand Up @@ -83,56 +86,68 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av
init_pos = jax.vmap(model.sample_init)(init_keys)

# samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys)
samples, params, grad_calls_per_traj = jax.vmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys)
samples, params, grad_calls_per_traj = jax.vmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys)
# avg_grad_calls_per_traj = jnp.mean(jnp.where(jnp.isnan(grad_calls_per_traj), 1, grad_calls_per_traj), axis=0)
avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0)
print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0))

try:
print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0))
except: pass
# print("grad calls", avg_grad_calls_per_traj)

full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr))
# err_t = jnp.mean(jax.vmap(full)(samples**2), axis=0)
err_t = jax.vmap(full)(samples**2)
# print(err_t)
err_t_median = jnp.median(err_t, axis=0)
# raise Exception
# print(err_t.shape)
# foo = jax.vmap(lambda x: calculate_ess(x, grad_evals_per_step=avg_grad_calls_per_traj))(err_t)
# print(foo.shape)
outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t]
# outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t]
# print(outs[:10])
esses = [i[0].item() for i in outs if not math.isnan(i[0].item())]
grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())]
# esses = [i[0].item() for i in outs if not math.isnan(i[0].item())]
# grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())]
# print(grad_calls)
# raise Exception

esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj)

# print(mean(esses), median(esses))
# print(mean(grad_calls), median(grad_calls))

# return grads_to_low_error(err_t, avg_grad_calls_per_traj)[0]
# ess_per_sample = calculate_ess(err_t, grad_evals_per_step=avg_grad_calls_per_traj)
return median(esses), median(grad_calls)
return esses, grad_calls
# , err_t[-1], params




def run_benchmarks():

results = defaultdict(tuple)

# for model, sampler in itertools.product(models, samplers):
for model, sampler in itertools.product(["Brownian Motion"], ["mhmclmc",]):
for variables in itertools.product(["Brownian Motion"], ["mhmclmc", "nuts", "mclmc", ], [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients]):

print(f"\nModel: {model}, Sampler: {sampler}\n")
model, sampler, coefficients = variables
print(f"\nModel: {model}, Sampler: {sampler}\n Coefficients: {coefficients}\n")
# sampler_to_integrator_type = {
# "mclmc": generate_isokinetic_integrator,
# "mhmclmc": generate_isokinetic_integrator,
# "nuts": generate_euclidean_integrator,
# }


results = []
Model = models[model][0]
key = jax.random.PRNGKey(2)
for i in range(1):
key1, key = jax.random.split(key)
result = benchmark_chains(Model, samplers[sampler],key1, n=models[model][1][sampler], batch=250)
# integrator = sampler_to_integrator_type[sampler](coefficients)
ess, grad_calls = benchmark_chains(Model, partial(samplers[sampler], coefficients=coefficients),key1, n=models[model][1][sampler], batch=2)
#print(f"ESS: {result.item()}")
print(f"grads to low bias: " + str(result[1:]))
results.append(result[1])
print(f"grads to low bias: {grad_calls}")
# results.append(result[1])
results[(model, sampler, tuple(coefficients))] = (ess, grad_calls)

# import matplotlib.pyplot as plt

Expand All @@ -145,6 +160,16 @@ def run_benchmarks():
# plt.title("Scatterplot of Second Element of Results")
# plt.savefig("scatterplot_mclmc.png") # Save the plot as scatterplot.png
# plt.show()

print(results)

import pandas as pd

df = pd.DataFrame(results)
df.to_csv("results.csv", index=False) # Save the DataFrame to a CSV file

if __name__ == "__main__":
run_benchmarks()

if __name__ == "__main__":
run_benchmarks()
Expand Down
10 changes: 7 additions & 3 deletions blackjax/benchmarks/mcmc/find_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import jax.numpy as jnp
from sampling_algorithms import run_mclmc, run_mhmclmc, samplers
from inference_models import Brownian, IllConditionedGaussian, models
from blackjax.adaptation.mclmc_adaptation import MCLMCAdaptationState, target_acceptance_rate_of_order, mhmclmc_integrator_order


def sampler_mhmclmc_with_tuning(step_size, L, frac_tune2, frac_tune3):

Expand All @@ -23,8 +25,9 @@ def s(logdensity_fn, num_steps, initial_position, transform, key):
initial_state = blackjax.mcmc.mhmclmc.init(
position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key
)
integrator = blackjax.mcmc.integrators.isokinetic_mclachlan
kernel = lambda rng_key, state, avg_num_integration_steps, step_size: blackjax.mcmc.mhmclmc.build_kernel(
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
integrator=integrator,
integration_steps_fn = lambda key : jnp.ceil(jax.random.uniform(key) * rescale(avg_num_integration_steps)),
# integration_steps_fn = lambda key: avg_num_integration_steps,
)(
Expand All @@ -42,6 +45,7 @@ def s(logdensity_fn, num_steps, initial_position, transform, key):
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
target=target_acceptance_rate_of_order[mhmclmc_integrator_order[integrator]],
frac_tune2=frac_tune2,
frac_tune3=frac_tune3,
params=MCLMCAdaptationState(L=L, step_size=step_size, std_mat=1.)
Expand Down Expand Up @@ -75,7 +79,7 @@ def s(logdensity_fn, num_steps, initial_position, transform, key):
print(info.acceptance_rate.mean(), "acceptance probability\n\n\n\n")
# print(out.var(axis=0), "acceptance probability")

return out, blackjax_mclmc_sampler_params, num_steps_per_traj
return out, blackjax_mclmc_sampler_params, num_steps_per_traj * calls_per_integrator_step(coefficients)

return s

Expand Down Expand Up @@ -108,7 +112,7 @@ def s(logdensity_fn, num_steps, initial_position, transform, key):
# print(info.acceptance_rate.mean(), "acceptance probability\n\n\n\n")
# print(out.var(axis=0), "acceptance probability")

return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=1.), num_steps_per_traj * calls_per_integrator_step[integrator]
return out, MCLMCAdaptationState(L=L, step_size=step_size, std_mat=1.), num_steps_per_traj * calls_per_integrator_step(coefficients)

return s

Expand Down
80 changes: 80 additions & 0 deletions blackjax/benchmarks/mcmc/results.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" Brownian Motion\n",
"0 mclmc\n",
"1 (0.08398315262876693, 0.2539785108410595, 0.68...\n",
"2 0.0\n",
"3 inf\n"
]
},
{
"ename": "ValueError",
"evalue": "Length mismatch: Expected axis has 1 elements, new values have 3 elements",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/Users/reubencohn-gordon/Desktop/blackjax/blackjax/benchmarks/mcmc/results.ipynb Cell 1\u001b[0m line \u001b[0;36m8\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/reubencohn-gordon/Desktop/blackjax/blackjax/benchmarks/mcmc/results.ipynb#W0sZmlsZQ%3D%3D?line=5'>6</a>\u001b[0m \u001b[39mprint\u001b[39m(df)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/reubencohn-gordon/Desktop/blackjax/blackjax/benchmarks/mcmc/results.ipynb#W0sZmlsZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39m# Name the columns\u001b[39;00m\n\u001b[0;32m----> <a href='vscode-notebook-cell:/Users/reubencohn-gordon/Desktop/blackjax/blackjax/benchmarks/mcmc/results.ipynb#W0sZmlsZQ%3D%3D?line=7'>8</a>\u001b[0m df\u001b[39m.\u001b[39mcolumns \u001b[39m=\u001b[39m [\u001b[39m\"\u001b[39m\u001b[39mColumn1\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mColumn2\u001b[39m\u001b[39m\"\u001b[39m, \u001b[39m\"\u001b[39m\u001b[39mColumn3\u001b[39m\u001b[39m\"\u001b[39m]\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/reubencohn-gordon/Desktop/blackjax/blackjax/benchmarks/mcmc/results.ipynb#W0sZmlsZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39m# Plot the data\u001b[39;00m\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/reubencohn-gordon/Desktop/blackjax/blackjax/benchmarks/mcmc/results.ipynb#W0sZmlsZQ%3D%3D?line=10'>11</a>\u001b[0m df\u001b[39m.\u001b[39mplot()\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/pandas/core/generic.py:6218\u001b[0m, in \u001b[0;36mNDFrame.__setattr__\u001b[0;34m(self, name, value)\u001b[0m\n\u001b[1;32m 6216\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[1;32m 6217\u001b[0m \u001b[39mobject\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__getattribute__\u001b[39m(\u001b[39mself\u001b[39m, name)\n\u001b[0;32m-> 6218\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mobject\u001b[39m\u001b[39m.\u001b[39m\u001b[39m__setattr__\u001b[39m(\u001b[39mself\u001b[39m, name, value)\n\u001b[1;32m 6219\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mAttributeError\u001b[39;00m:\n\u001b[1;32m 6220\u001b[0m \u001b[39mpass\u001b[39;00m\n",
"File \u001b[0;32mproperties.pyx:69\u001b[0m, in \u001b[0;36mpandas._libs.properties.AxisProperty.__set__\u001b[0;34m()\u001b[0m\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/pandas/core/generic.py:767\u001b[0m, in \u001b[0;36mNDFrame._set_axis\u001b[0;34m(self, axis, labels)\u001b[0m\n\u001b[1;32m 762\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 763\u001b[0m \u001b[39mThis is called from the cython code when we set the `index` attribute\u001b[39;00m\n\u001b[1;32m 764\u001b[0m \u001b[39mdirectly, e.g. `series.index = [1, 2, 3]`.\u001b[39;00m\n\u001b[1;32m 765\u001b[0m \u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 766\u001b[0m labels \u001b[39m=\u001b[39m ensure_index(labels)\n\u001b[0;32m--> 767\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_mgr\u001b[39m.\u001b[39mset_axis(axis, labels)\n\u001b[1;32m 768\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_clear_item_cache()\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/pandas/core/internals/managers.py:227\u001b[0m, in \u001b[0;36mBaseBlockManager.set_axis\u001b[0;34m(self, axis, new_labels)\u001b[0m\n\u001b[1;32m 225\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mset_axis\u001b[39m(\u001b[39mself\u001b[39m, axis: AxisInt, new_labels: Index) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 226\u001b[0m \u001b[39m# Caller is responsible for ensuring we have an Index object.\u001b[39;00m\n\u001b[0;32m--> 227\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m_validate_set_axis(axis, new_labels)\n\u001b[1;32m 228\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39maxes[axis] \u001b[39m=\u001b[39m new_labels\n",
"File \u001b[0;32m/opt/homebrew/Caskroom/miniconda/base/envs/mclmc/lib/python3.11/site-packages/pandas/core/internals/base.py:85\u001b[0m, in \u001b[0;36mDataManager._validate_set_axis\u001b[0;34m(self, axis, new_labels)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[39mpass\u001b[39;00m\n\u001b[1;32m 84\u001b[0m \u001b[39melif\u001b[39;00m new_len \u001b[39m!=\u001b[39m old_len:\n\u001b[0;32m---> 85\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 86\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mLength mismatch: Expected axis has \u001b[39m\u001b[39m{\u001b[39;00mold_len\u001b[39m}\u001b[39;00m\u001b[39m elements, new \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 87\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mvalues have \u001b[39m\u001b[39m{\u001b[39;00mnew_len\u001b[39m}\u001b[39;00m\u001b[39m elements\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 88\u001b[0m )\n",
"\u001b[0;31mValueError\u001b[0m: Length mismatch: Expected axis has 1 elements, new values have 3 elements"
]
}
],
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Load the CSV file\n",
"df = pd.read_csv(\"../../../results.csv\")\n",
"print(df)\n",
"# Name the columns\n",
"df.columns = [\"Column1\", \"Column2\", \"Column3\"]\n",
"\n",
"# Plot the data\n",
"df.plot()\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "mclmc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit 3554d61

Please sign in to comment.