Skip to content

Commit

Permalink
some bug
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed Mar 31, 2024
1 parent 21d07c3 commit 4962d5a
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 109 deletions.
69 changes: 21 additions & 48 deletions blackjax/benchmarks/mcmc/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,96 +87,69 @@ def benchmark_chains(model, sampler, key, n=10000, batch=None, contract = jnp.av

# samples, params, avg_num_steps_per_traj = jax.pmap(lambda pos, key: sampler(model.logdensity_fn, n, pos, model.transform, key))(init_pos, keys)
samples, params, grad_calls_per_traj = jax.vmap(lambda pos, key: sampler(logdensity_fn=model.logdensity_fn, num_steps=n, initial_position= pos,transform= model.transform, key=key))(init_pos, keys)
# avg_grad_calls_per_traj = jnp.mean(jnp.where(jnp.isnan(grad_calls_per_traj), 1, grad_calls_per_traj), axis=0)
avg_grad_calls_per_traj = jnp.nanmean(grad_calls_per_traj, axis=0)
try:
print(jnp.nanmean(params.step_size,axis=0), jnp.nanmean(params.L,axis=0))
except: pass
# print("grad calls", avg_grad_calls_per_traj)

full = lambda arr : err(model.E_x2, model.Var_x2, contract)(cumulative_avg(arr))
# err_t = jnp.mean(jax.vmap(full)(samples**2), axis=0)
err_t = jax.vmap(full)(samples**2)
err_t_median = jnp.median(err_t, axis=0)
# raise Exception
# print(err_t.shape)
# foo = jax.vmap(lambda x: calculate_ess(x, grad_evals_per_step=avg_grad_calls_per_traj))(err_t)
# print(foo.shape)


# outs = [calculate_ess(b, grad_evals_per_step=avg_grad_calls_per_traj) for b in err_t]
# print(outs[:10])
# # print(outs[:10])
# esses = [i[0].item() for i in outs if not math.isnan(i[0].item())]
# grad_calls = [i[1].item() for i in outs if not math.isnan(i[1].item())]
# print(grad_calls)
# raise Exception
# return(mean(esses), mean(grad_calls))

esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj)

# print(mean(esses), median(esses))
# print(mean(grad_calls), median(grad_calls))

# return grads_to_low_error(err_t, avg_grad_calls_per_traj)[0]
# ess_per_sample = calculate_ess(err_t, grad_evals_per_step=avg_grad_calls_per_traj)
err_t_median = jnp.median(err_t, axis=0)
esses, grad_calls, _ = calculate_ess(err_t_median, grad_evals_per_step=avg_grad_calls_per_traj)
return esses, grad_calls
# , err_t[-1], params




def run_benchmarks(batch_size):

results = defaultdict(tuple)
for variables in itertools.product(
# ["mhmclmc", "nuts", "mclmc", ],
["mhmclmc",],
[StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(10000), 10)).astype(int)],
# [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients],
[mclachlan_coefficients],
):

sampling_algs = ["mhmclmc", "nuts", "mclmc", ]
coeffs = [velocity_verlet_coefficients, mclachlan_coefficients, yoshida_coefficients, omelyan_coefficients]
models = [StandardNormal(d) for d in np.ceil(np.logspace(np.log10(10), np.log10(100000), 5)).astype(int)]

# for model, sampler in itertools.product(models, sampling_algs):
for variables in itertools.product(models, sampling_algs, coeffs):
sampler, model, coefficients = variables
num_chains = 1 + batch_size//model.ndims

model, sampler, coefficients = variables
print(f"\nModel: {model.name}, Sampler: {sampler}\n Coefficients: {coefficients}\n")
# sampler_to_integrator_type = {
# "mclmc": generate_isokinetic_integrator,
# "mhmclmc": generate_isokinetic_integrator,
# "nuts": generate_euclidean_integrator,
# }
print(f"\nModel: {model.name,model.ndims}, Sampler: {sampler}\n Coefficients: {coefficients}\nNumber of chains {num_chains}",)

# Model = models[model][0]
key = jax.random.PRNGKey(2)
for i in range(1):
key1, key = jax.random.split(key)
# integrator = sampler_to_integrator_type[sampler](coefficients)
ess, grad_calls = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients),key1, n=1000, batch=1 + 1000//model.ndims, contract=jnp.average)
#print(f"ESS: {result.item()}")
ess, grad_calls = benchmark_chains(model, partial(samplers[sampler], coefficients=coefficients),key1, n=500, batch=num_chains, contract=jnp.average)

print(f"grads to low bias: {grad_calls}")
# results.append(result[1])
results[((model.name, model.ndims), sampler, name_integrator(coefficients))] = (ess, grad_calls)

# import matplotlib.pyplot as plt
results[((model.name, model.ndims), sampler, name_integrator(coefficients))] = (ess, grad_calls)

# # ... existing code ...

# # Plot the second_elements in a scatterplot
# plt.scatter([0.5]*len(results), results)
# plt.xlabel("Iteration")
# plt.ylabel("Second Element of Results")
# plt.title("Scatterplot of Second Element of Results")
# plt.savefig("scatterplot_mclmc.png") # Save the plot as scatterplot.png
# plt.show()

print(results)

import pandas as pd

df = pd.Series(results).reset_index()
df.columns = ["model", "sampler", "coeffs", "result"]
df.result = df.result.apply(lambda x: x[1].item())
df.result = df.result.apply(lambda x: x[0].item())
df.model = df.model.apply(lambda x: x[1])
df.to_csv("results.csv", index=False)

return results

if __name__ == "__main__":
run_benchmarks(batch_size=100)
run_benchmarks(batch_size=10000)


2 changes: 1 addition & 1 deletion blackjax/benchmarks/mcmc/sampling_algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def run_mclmc(coefficients, logdensity_fn, num_steps, initial_position, transfor
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
diagonal_preconditioning=True
diagonal_preconditioning=False
)

# jax.debug.print("params {x}", x=(blackjax_mclmc_sampler_params.L, blackjax_mclmc_sampler_params.step_size))
Expand Down
70 changes: 10 additions & 60 deletions results.csv
Original file line number Diff line number Diff line change
@@ -1,61 +1,11 @@
model,sampler,coeffs,result
10,mhmclmc,leapfrog,1313.87060546875
10,mhmclmc,mclachlan,4735.89697265625
10,mhmclmc,yoshida,4282.994140625
10,mhmclmc,omelyan,6871.47314453125
10,nuts,leapfrog,1509.970947265625
10,nuts,mclachlan,1601.101318359375
10,nuts,yoshida,2225.653076171875
10,nuts,omelyan,inf
10,mclmc,leapfrog,285.0
10,mclmc,mclachlan,240.0
10,mclmc,yoshida,357.0
10,mclmc,omelyan,525.0
100,mhmclmc,leapfrog,2180.30078125
100,mhmclmc,mclachlan,5657.84912109375
100,mhmclmc,yoshida,5520.23046875
100,mhmclmc,omelyan,8336.8203125
100,nuts,leapfrog,2479.009033203125
100,nuts,mclachlan,3409.22216796875
100,nuts,yoshida,2718.000244140625
100,nuts,omelyan,15432.2177734375
100,mclmc,leapfrog,347.0
100,mclmc,mclachlan,260.0
100,mclmc,yoshida,669.0
100,mclmc,omelyan,5250.0
1000,mhmclmc,leapfrog,2822.58984375
1000,mhmclmc,mclachlan,2185.3720703125
1000,mhmclmc,yoshida,6953.826171875
1000,mhmclmc,omelyan,74472.296875
1000,nuts,leapfrog,3750.000244140625
1000,nuts,mclachlan,11058.0009765625
1000,nuts,yoshida,2493.000244140625
1000,nuts,omelyan,4380.00048828125
1000,mclmc,leapfrog,343.0
1000,mclmc,mclachlan,270.0
1000,mclmc,yoshida,357.0
1000,mclmc,omelyan,inf
10000,mhmclmc,leapfrog,2111.00048828125
10000,mhmclmc,mclachlan,3084.526611328125
10000,mhmclmc,yoshida,4508.38330078125
10000,mhmclmc,omelyan,28683.00390625
10000,nuts,leapfrog,10230.0009765625
10000,nuts,mclachlan,3934.000244140625
10000,nuts,yoshida,3024.000244140625
10000,nuts,omelyan,4725.00048828125
10000,mclmc,leapfrog,343.0
10000,mclmc,mclachlan,244.0
10000,mclmc,yoshida,627.0
10000,mclmc,omelyan,850.0
100000,mhmclmc,leapfrog,3089.21923828125
100000,mhmclmc,mclachlan,5451.79541015625
100000,mhmclmc,yoshida,5016.7548828125
100000,mhmclmc,omelyan,8191.017578125
100000,nuts,leapfrog,17766.001953125
100000,nuts,mclachlan,4256.00048828125
100000,nuts,yoshida,12870.0009765625
100000,nuts,omelyan,10150.0009765625
100000,mclmc,leapfrog,396.0
100000,mclmc,mclachlan,248.0
100000,mclmc,yoshida,420.0
100000,mclmc,omelyan,1000.0
10,mhmclmc,mclachlan,0.025606706738471985
22,mhmclmc,mclachlan,0.027042334899306297
47,mhmclmc,mclachlan,0.02515588514506817
100,mhmclmc,mclachlan,0.024291083216667175
216,mhmclmc,mclachlan,0.027843158692121506
465,mhmclmc,mclachlan,0.02349252626299858
1000,mhmclmc,mclachlan,0.023028362542390823
2155,mhmclmc,mclachlan,0.021098610013723373
4642,mhmclmc,mclachlan,0.03155219554901123
10000,mhmclmc,mclachlan,0.022297026589512825

0 comments on commit 4962d5a

Please sign in to comment.