Skip to content

Commit

Permalink
UPDATE EXAMPLE
Browse files Browse the repository at this point in the history
  • Loading branch information
reubenharry committed May 24, 2024
1 parent 9fd1824 commit 420468c
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 834 deletions.
827 changes: 0 additions & 827 deletions book/algorithms/mclmc.ipynb

This file was deleted.

132 changes: 125 additions & 7 deletions book/algorithms/mclmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.15.2
jupytext_version: 1.16.0
kernelspec:
display_name: mclmc
language: python
Expand Down Expand Up @@ -59,6 +59,7 @@ rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))

```{code-cell} ipython3
import blackjax
from blackjax.mcmc.adjusted_mclmc import rescale
import numpy as np
import jax.numpy as jnp
```
Expand All @@ -73,9 +74,10 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform):
)
# build the kernel
kernel = blackjax.mcmc.mclmc.build_kernel(
kernel = lambda sqrt_diag_cov_mat : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
)
# find values for L and step_size
Expand All @@ -87,6 +89,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform):
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
diagonal_preconditioning=False,
)
# use the quick wrapper to build a new kernel with the tuned parameters
Expand Down Expand Up @@ -133,7 +136,6 @@ plt.title("Scatter Plot of Samples")

This is ported from Jakob Robnik's [example notebook](https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/notebooks/tutorials/advanced_tutorial.ipynb)


```{code-cell} ipython3
import matplotlib.dates as mdates
Expand Down Expand Up @@ -167,7 +169,7 @@ dim = 2429
lambda_sigma, lambda_nu = 50, 0.1
def logp(x):
def logp_volatility(x):
"""log p of the target distribution"""
sigma = (
Expand Down Expand Up @@ -234,7 +236,7 @@ def prior_draw(key):
```{code-cell} ipython3
key1, key2, rng_key = jax.random.split(rng_key, 3)
samples = run_mclmc(
logdensity_fn=logp,
logdensity_fn=logp_volatility,
num_steps=10000,
initial_position=prior_draw(key1),
key=key2,
Expand Down Expand Up @@ -278,18 +280,134 @@ ax.fill_between(dates, lower_quartile, upper_quartile, color="navy", alpha=0.5)
ax.legend()
```


## Adjusted MCLMC

Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.

```{code-cell} ipython3
def run_adjusted_mclmc(logdensity_fn, num_steps, initial_position, key, transform):
init_key, tune_key, run_key = jax.random.split(key, 3)
# create an initial state for the sampler
initial_state = blackjax.mcmc.adjusted_mclmc.init(
position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key
)
kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov_mat: blackjax.mcmc.adjusted_mclmc.build_kernel(
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
integration_steps_fn = lambda k : jnp.ceil(jax.random.uniform(k) * rescale(avg_num_integration_steps)),
sqrt_diag_cov_mat=sqrt_diag_cov_mat,
)(
rng_key=rng_key,
state=state,
step_size=step_size,
logdensity_fn=logdensity_fn)
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
params_history,
final_da
) = blackjax.adaptation.mclmc_adaptation.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
state=initial_state,
rng_key=tune_key,
target=0.9,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.1,
diagonal_preconditioning=False,
)
step_size = blackjax_mclmc_sampler_params.step_size
L = blackjax_mclmc_sampler_params.L
alg = blackjax.adjusted_mclmc(
logdensity_fn=logdensity_fn,
step_size=step_size,
integration_steps_fn = lambda key: jnp.ceil(jax.random.uniform(key) * rescale(L/step_size)) ,
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
sqrt_diag_cov_mat=blackjax_mclmc_sampler_params.sqrt_diag_cov_mat,
)
_, samples, info = blackjax.util.run_inference_algorithm(
rng_key=run_key,
initial_state=blackjax_state_after_tuning,
inference_algorithm=alg,
num_steps=num_steps,
transform=lambda x: x.position,
progress_bar=True)
return samples
```

```{code-cell} ipython3
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
num_steps=1000,
initial_position=jnp.ones((1000,)),
key=sample_key,
transform=lambda x: x.position[:2],
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{code-cell} ipython3
key1, key2, rng_key = jax.random.split(rng_key, 3)
samples = run_adjusted_mclmc(
logdensity_fn=logp_volatility,
num_steps=10000,
initial_position=prior_draw(key1),
key=key2,
transform=lambda x: x,
)
R = np.array(samples)[:, :-2] # remove sigma and nu parameters
R = np.sort(R, axis=0) # sort samples for each R_n
num_samples = len(R)
lower_quartile, median, upper_quartile = (
R[num_samples // 4, :],
R[num_samples // 2, :],
R[3 * num_samples // 4, :],
)
# figure setup
_, ax = plt.subplots(figsize=(12, 5))
ax.spines["right"].set_visible(False) # remove the upper and the right axis lines
ax.spines["top"].set_visible(False)
ax.xaxis.set_major_locator(mdates.YearLocator()) # dates on the xaxis
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
ax.xaxis.set_minor_locator(mdates.MonthLocator())
# plot data
ax.plot(dates, SP500_returns, ".", markersize=3, color="steelblue")
ax.plot(
[], [], ".", markersize=10, color="steelblue", alpha=0.5, label="data"
) # larger markersize for the legend
ax.set_xlabel("time")
ax.set_ylabel("S&P500 returns")
# plot posterior
ax.plot(dates, median, color="navy", label="volatility posterior")
ax.fill_between(dates, lower_quartile, upper_quartile, color="navy", alpha=0.5)
ax.legend()
```

```{bibliography}
:filter: docname in docnames
```


```
```

0 comments on commit 420468c

Please sign in to comment.