MOVE BENCHMARKS INSIDE BLACKJAX SRC AND RUN BENCHMARKS IN BENCHMARKS.PY

reubenharry committed Mar 8, 2024
2 parents cb03a18 + 3c338c2 commit d7f3ee4
Showing 13 changed files with 12,037 additions and 100 deletions.
341 changes: 299 additions & 42 deletions blackjax/adaptation/mclmc_adaptation.py

Large diffs are not rendered by default.

9,716 changes: 9,716 additions & 0 deletions blackjax/benchmarks/mcmc/benchmark.ipynb

Large diffs are not rendered by default.

201 changes: 201 additions & 0 deletions blackjax/benchmarks/mcmc/benchmark.py
@@ -0,0 +1,201 @@
import itertools

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

from inference_models import models
from sampling_algorithms import samplers


def get_num_latents(target):
    """Number of latent dimensions of the target model."""
    return target.ndims

def err(f_true, var_f, contract=jnp.max):
    """Computes the normalized squared bias b^2 = (f - f_true)^2 / var_f.
    Args:
        f_true: E_true[f(x)]
        var_f: Var_true[f(x)]
        contract: how to combine the vector b^2 into a single number, e.g. jnp.average or jnp.max
    Returns:
        a vmapped function that maps a time series of estimates f = E_sampler[f(x)]
        (each possibly a vector) to contract(b^2) at each time step
    """

    def _err(f):
        bsq = jnp.square(f - f_true) / var_f
        return contract(bsq)

    return jax.vmap(_err)
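# A minimal usage sketch (values are illustrative): the function returned by
# `err` maps a time series of running estimates to a time series of
# contracted errors. `cumulative_avg` is defined below.
#
#   f_true = jnp.zeros(2)                                    # E_true[f(x)]
#   var_f = jnp.ones(2)                                      # Var_true[f(x)]
#   x = jax.random.normal(jax.random.PRNGKey(1), (100, 2))
#   err_t = err(f_true, var_f, jnp.average)(cumulative_avg(x))  # shape (100,)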



def grads_to_low_error(err_t, low_error=0.01, grad_evals_per_step=1):
    """Number of gradient evaluations needed to bring the error below
    `low_error`, together with whether that cutoff was reached at all.
    With the convention b^2 = 1/neff, low_error = 0.01 corresponds to
    neff = 100."""

    cutoff_reached = err_t[-1] < low_error
    return find_crossing(err_t, low_error) * grad_evals_per_step, cutoff_reached



def ess(err_t, grad_evals_per_step, neff=100):
    """Effective sample size per gradient evaluation, with the convention
    b^2 = 1/neff: the error must fall below 1/neff and stay there."""

    low_error = 1. / neff
    cutoff_reached = err_t[-1] < low_error
    crossing = find_crossing(err_t, low_error)

    # If the error never settles below the cutoff, report 0.
    return (neff / (crossing * grad_evals_per_step)) * cutoff_reached
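# A worked instance of the convention above (numbers are illustrative): with
# neff = 100 the cutoff is b^2 = 1/100 = 0.01. If the error first stays below
# 0.01 after crossing * grad_evals_per_step = 500 gradient evaluations, then
# ess = 100 / 500 = 0.2 effective samples per gradient evaluation.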



def find_crossing(array, cutoff):
    """The smallest M such that array[m] <= cutoff for all m >= M."""

    b = array > cutoff
    indices = jnp.argwhere(b)
    if indices.shape[0] == 0:
        # The error never exceeds the cutoff; treat the first step as the crossing.
        print("No crossing found: the error is below the cutoff everywhere.", array, cutoff)
        return 1

    return jnp.max(indices) + 1
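# For example:
#   find_crossing(jnp.array([0.4, 0.2, 0.3, 0.4, 0.5, 0.2, 0.2]), 0.3)
# returns 5: index 4 holds the last entry above 0.3, so the series stays at
# or below the cutoff from index 5 onwards.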

def cumulative_avg(samples):
    """Running average over the leading (time) axis."""
    return jnp.cumsum(samples, axis=0) / jnp.arange(1, samples.shape[0] + 1)[:, None]
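# For example:
#   cumulative_avg(jnp.array([[1., 2.], [1., 2.]]).T)
# gives [[1., 1.], [1.5, 1.5]]: row t is the mean of rows 0..t of the input.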



# Single-chain variant of the benchmark, kept commented out for reference. It
# targets the inference_gym model API (sample_transformations,
# unnormalized_log_prob) rather than the inference_models API used below.
# def benchmark(model, sampler):
#     n = 10000
#     identity_fn = model.sample_transformations['identity']
#     logdensity_fn = model.unnormalized_log_prob
#     d = get_num_latents(model)
#     initial_position = jax.random.normal(jax.random.PRNGKey(0), (d,))
#     samples, num_steps_per_traj = sampler(logdensity_fn, n, initial_position, jax.random.PRNGKey(0))
#     favg, fvar = identity_fn.ground_truth_mean, identity_fn.ground_truth_standard_deviation**2
#     err_t = err(favg, fvar, jnp.average)(cumulative_avg(samples))
#     return ess(err_t, grad_evals_per_step=2)

def benchmark_chains(model, sampler, favg, fvar, n=10000, batch=None):
    """Runs `batch` chains of `sampler` on `model` for `n` steps and returns
    the ESS per gradient evaluation for E[x^2], the final error, and the
    sampler parameters."""

    logdensity_fn = lambda x: -model.nlogp(x)
    d = get_num_latents(model)
    if batch is None:
        # Scale the number of chains inversely with dimension (batch * d ~ 1000).
        batch = np.ceil(1000 / d).astype(int)
    key, init_key = jax.random.split(jax.random.PRNGKey(44), 2)
    keys = jax.random.split(key, batch)
    init_keys = jax.random.split(init_key, batch)
    init_pos = jax.vmap(model.sample)(init_keys)

    samples, params, avg_num_steps_per_traj = jax.vmap(
        lambda pos, key: sampler(logdensity_fn, n, pos, key))(init_pos, keys)
    avg_num_steps_per_traj = jnp.mean(avg_num_steps_per_traj, axis=0)

    # The benchmarked expectation is the second moment E[x^2] (hence
    # samples**2); the per-step error is averaged over chains.
    full = lambda arr: err(favg, fvar, jnp.average)(cumulative_avg(arr))
    err_t = jnp.mean(jax.vmap(full)(samples**2), axis=0)
    # 2 gradient evaluations per integrator step, times the average number of
    # steps per trajectory.
    ess_per_sample = ess(err_t, grad_evals_per_step=2 * avg_num_steps_per_traj)

    return ess_per_sample, err_t[-1], params
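# A usage sketch mirroring explore.py (StandardNormal comes from the local
# inference_models module, 'mclmc' is a key of sampling_algorithms.samplers):
#
#   m = StandardNormal(10)
#   ess_per_grad, final_err, params = benchmark_chains(
#       m, samplers['mclmc'], favg=m.E_x2, fvar=m.Var_x2, n=5000)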

def run_benchmarks(n):
    """Benchmarks every (model, sampler) pair, printing and collecting the
    ESS per gradient evaluation."""

    results = {}
    for model, sampler in itertools.product(models, samplers):

        print(f"\nModel: {model}, Sampler: {sampler}\n")

        # Use fewer chains for higher-dimensional models, but always at least one.
        result, bias, _ = benchmark_chains(
            models[model], samplers[sampler], n=n,
            batch=max(1, 100 // models[model].ndims),
            favg=models[model].E_x2, fvar=models[model].Var_x2)
        print(f"ESS: {result.item()}")
        results[(model, sampler)] = result.item()
    return results


def plot_results(results):
    """Bar chart of the benchmark results, one bar per (model, sampler) pair."""

    labels = [f"{model} - {sampler}" for model, sampler in results]
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.bar(range(len(results)), list(results.values()))
    ax.set_xticks(range(len(results)))
    ax.set_xticklabels(labels, rotation=90)
    ax.set_title("Benchmark Results")
    plt.tight_layout()
    plt.show()


if __name__ == "__main__":

    plot_results(run_benchmarks(5000))

48 changes: 33 additions & 15 deletions explore.py → blackjax/benchmarks/mcmc/explore.py
@@ -1,14 +1,19 @@
import jax

from datetime import date
from blackjax.benchmarks.mcmc.benchmark import benchmark_chains

from blackjax.benchmarks.mcmc.inference_models import IllConditionedGaussian

rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

import blackjax
import numpy as np
import jax.numpy as jnp
from sampling_algorithms import samplers
from inference_models import StandardNormal, models

def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, std_mat):
def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, std_mat, L, step_size):
init_key, tune_key, run_key = jax.random.split(key, 3)

# create an initial state for the sampler
@@ -20,8 +25,8 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, std_ma
# use the quick wrapper to build a new kernel with the tuned parameters
sampling_alg = blackjax.mclmc(
logdensity_fn,
L=4.0,
step_size=1.,
L=L,
step_size=step_size,
std_mat=std_mat,
)

@@ -35,7 +40,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, std_ma
progress_bar=True,
)

return samples
return samples, None, 1


def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, transform):
@@ -70,9 +75,9 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
# use the quick wrapper to build a new kernel with the tuned parameters
sampling_alg = blackjax.mclmc(
logdensity_fn,
L=4.0,
step_size=1.,
std_mat=std_mat,
L=blackjax_mclmc_sampler_params.L,
step_size=blackjax_mclmc_sampler_params.step_size,
std_mat=blackjax_mclmc_sampler_params.std_mat,
)

# run the sampler
@@ -103,12 +108,25 @@ def run_mclmc_with_tuning(logdensity_fn, num_steps, initial_position, key, trans
# print(samples.var(axis=0))

# den = lambda x: jax.scipy.stats.norm.logpdf(x, loc=0., scale=jnp.sqrt(sigma)).sum()
# print(IllConditionedGaussian(2, 2).E_x2)
# samples = run_mclmc_with_tuning(
# logdensity_fn=lambda x : - IllConditionedGaussian(2, 2).nlogp(x),
# num_steps=1000000,
# initial_position=jnp.ones((2,)),
# key=sample_key,
# transform=lambda x: x.position[:2],
# )
# print(samples.var(axis=0))
m = IllConditionedGaussian(2, 5)
# m = StandardNormal(10)
# sampler = lambda logdensity_fn, num_steps, initial_position, key: run_mclmc(logdensity_fn=logdensity_fn, num_steps=num_steps, initial_position=initial_position, key=key, transform=lambda x:x.position,
# std_mat=jnp.ones((2,))
# # std_mat=m.E_x2
# , L=1.7296296, step_size=0.814842)
print(m.E_x2)

sampler = 'mclmc'

result, bias, _ = benchmark_chains(m, samplers[sampler], n=5000, batch=1000//m.ndims,favg=m.E_x2, fvar=m.Var_x2)

samples = run_mclmc_with_tuning(
logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
num_steps=10000,
initial_position=jnp.ones((2,)),
key=sample_key,
transform=lambda x: x.position[:2],
)
print(samples.var())
print(result)
