Example of Optimized SMC inner kernel (#49)
* Example of Optimized SMC inner kernel

* adding to index

* code review updates and blackjax code updates
ciguaran authored Dec 7, 2023
1 parent 668d6b5 commit e1cf528
Showing 2 changed files with 313 additions and 0 deletions.
1 change: 1 addition & 0 deletions book/_toc.yml
@@ -8,6 +8,7 @@ parts:
      - file: algorithms/pathfinder.md
      - file: algorithms/PeriodicOrbitalMCMC.md
      - file: algorithms/TemperedSMC.md
      - file: algorithms/TemperedSMCWithOptimizedInnerKernel.md
  - caption: Models
    chapters:
      - file: models/change_of_variable_hmc.md
312 changes: 312 additions & 0 deletions book/algorithms/TemperedSMCWithOptimizedInnerKernel.md
@@ -0,0 +1,312 @@
---
jupyter:
  jupytext:
    formats: ipynb,md
    text_representation:
      extension: .md
      format_name: markdown
      format_version: '1.3'
    jupytext_version: 1.15.2
  kernelspec:
    display_name: Python 3 (ipykernel)
    language: python
    name: python3
---

<!-- #region tags=["remove_cell"] -->
# Tuning inner kernel parameters of SMC
<!-- #endregion -->

```python
import time
import arviz as az
import jax
import numpy as np
from jax import numpy as jnp
import matplotlib.pyplot as plt
import pandas as pd
import functools

```

This notebook is a continuation of `Use Tempered SMC to Improve Exploration of MCMC Methods`.
In that notebook, we tried sampling from a multimodal distribution using HMC, NUTS,
and SMC with an HMC inner kernel. Only the latter was able to draw samples from both modes of the distribution.
Recall that when setting the HMC parameters

```{python}
hmc_parameters = dict(
    step_size=1e-4, inverse_mass_matrix=inv_mass_matrix, num_integration_steps=1
)
```
these parameters were held fixed across all iterations of SMC. The efficiency of an SMC sampler can be
improved by informing the inner kernel parameters with the particle population: we can tune one or more
inner kernel parameters before mutating the particles in step $i$, using the particles output by step $i-1$.
This notebook illustrates such tuning using IRMH (Independent Rosenbluth Metropolis-Hastings) with a multivariate normal proposal distribution.

See design choice (c) in Section 2.1.3 of https://arxiv.org/abs/1808.07730.
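
To make the idea concrete before introducing the blackjax machinery, here is a minimal sketch (not the blackjax helpers used below) of what fitting a normal proposal to the particle population looks like: the proposal mean and scale are just per-dimension summary statistics of the particles produced by the previous step.

```python
# Minimal sketch, not the blackjax API: fit a diagonal-normal proposal to a
# stand-in particle population from the previous SMC step.
import jax
import jax.numpy as jnp

particles_prev = jax.random.normal(jax.random.PRNGKey(0), (1000, 3))  # stand-in particles
proposal_mean = jnp.mean(particles_prev, axis=0)  # fitted proposal mean, one entry per dimension
proposal_std = jnp.std(particles_prev, axis=0)    # fitted proposal scale (diagonal case)
```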


```python
n_particles = 4000
```

```python
from jax.scipy.stats import multivariate_normal


def V(x):
    return 5 * jnp.sum(jnp.square(x ** 2 - 1))


def prior_log_prob(x):
    d = x.shape[0]
    return multivariate_normal.logpdf(x, jnp.zeros((d,)), jnp.eye(d))


loglikelihood = lambda x: -V(x)


def density():
    linspace = jnp.linspace(-2, 2, 5000).reshape(-1, 1)
    lambdas = jnp.linspace(0.0, 1.0, 5)
    prior_logvals = jnp.vectorize(prior_log_prob, signature="(d)->()")(linspace)
    potential_vals = jnp.vectorize(V, signature="(d)->()")(linspace)
    log_res = prior_logvals.reshape(1, -1) - jnp.expand_dims(
        lambdas, 1
    ) * potential_vals.reshape(1, -1)

    density = jnp.exp(log_res)
    normalizing_factor = jnp.sum(density, axis=1, keepdims=True) * (
        linspace[1] - linspace[0]
    )
    density /= normalizing_factor
    return density
```

```python
def initial_particles_multivariate_normal(dimensions, key, n_samples):
    return jax.random.multivariate_normal(
        key, jnp.zeros(dimensions), jnp.eye(dimensions) * 2, (n_samples,)
    )
```

## IRMH without tuning


The proposal distribution is normal with fixed parameters across all iterations.

```python
from blackjax import adaptive_tempered_smc
from blackjax.smc import resampling, solver
from blackjax import irmh


def irmh_experiment(dimensions, target_ess, num_mcmc_steps):
    mean = jnp.zeros(dimensions)
    cov = jnp.diag(jnp.ones(dimensions)) * 2

    def irmh_proposal_distribution(rng_key):
        return jax.random.multivariate_normal(rng_key, mean, cov)

    def proposal_logdensity_fn(proposal, state):
        return jnp.log(
            jax.scipy.stats.multivariate_normal.pdf(state.position, mean=mean, cov=cov)
        )

    fixed_proposal_kernel = adaptive_tempered_smc(
        prior_log_prob,
        loglikelihood,
        irmh.build_kernel(),
        irmh.init,
        mcmc_parameters={
            "proposal_distribution": irmh_proposal_distribution,
            "proposal_logdensity_fn": proposal_logdensity_fn,
        },
        resampling_fn=resampling.systematic,
        target_ess=target_ess,
        root_solver=solver.dichotomy,
        num_mcmc_steps=num_mcmc_steps,
    )

    def inference_loop(kernel, rng_key, initial_state):
        def cond(carry):
            _, state, *_ = carry
            return state.lmbda < 1

        def body(carry):
            i, state, op_key, curr_loglikelihood = carry
            op_key, subkey = jax.random.split(op_key, 2)
            state, info = kernel(subkey, state)
            return i + 1, state, op_key, curr_loglikelihood + info.log_likelihood_increment

        total_iter, final_state, _, log_likelihood = jax.lax.while_loop(
            cond, body, (0, initial_state, rng_key, 0.0)
        )

        return total_iter, final_state.particles

    return fixed_proposal_kernel, inference_loop
```
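
The `proposal_logdensity_fn` above evaluates the proposal density at the current position. This is needed because, for an independence sampler, the proposal density does not cancel in the Metropolis-Hastings acceptance probability (a standard result, stated here for context rather than taken from the blackjax internals):

$$\alpha(x, x') = \min\left(1, \frac{\pi(x')\, q(x)}{\pi(x)\, q(x')}\right),$$

where $\pi$ is the (tempered) target and $q$ is the fixed normal proposal, so $q$ must be evaluated both at the proposed point $x'$ and at the current position $x$.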

## IRMH tuning the diagonal of the covariance matrix


Although the proposal distribution is always normal, its mean and the diagonal of its covariance matrix are fitted from
the particles produced by step $i$, and those fitted parameters are used to mutate the particles in step $i+1$.

```python
from blackjax.smc.inner_kernel_tuning import inner_kernel_tuning
from blackjax.smc.tuning.from_particles import (
    particles_covariance_matrix,
    particles_stds,
    particles_means,
)


def tuned_irmh_loop(kernel, rng_key, initial_state):
    def cond(carry):
        _, state, *_ = carry
        return state.sampler_state.lmbda < 1

    def body(carry):
        i, state, op_key = carry
        op_key, subkey = jax.random.split(op_key, 2)
        state, info = kernel(subkey, state)
        return i + 1, state, op_key

    def f(initial_state, key):
        total_iter, final_state, _ = jax.lax.while_loop(
            cond, body, (0, initial_state, key)
        )
        return total_iter, final_state

    total_iter, final_state = f(initial_state, rng_key)
    return total_iter, final_state.sampler_state.particles


def tuned_irmh_experiment(dimensions, target_ess, num_mcmc_steps):
    def kernel_factory(normal_proposal_parameters):
        means, stds = normal_proposal_parameters
        cov = jnp.square(jnp.diag(stds))
        proposal_distribution = lambda key: jax.random.multivariate_normal(key, means, cov)

        def proposal_logdensity_fn(proposal, state):
            return jnp.log(
                jax.scipy.stats.multivariate_normal.pdf(state.position, mean=means, cov=cov)
            )

        return functools.partial(
            irmh.build_kernel(),
            proposal_logdensity_fn=proposal_logdensity_fn,
            proposal_distribution=proposal_distribution,
        )

    kernel_tuned_proposal = inner_kernel_tuning(
        logprior_fn=prior_log_prob,
        loglikelihood_fn=loglikelihood,
        mcmc_factory=kernel_factory,
        mcmc_init_fn=irmh.init,
        resampling_fn=resampling.systematic,
        smc_algorithm=adaptive_tempered_smc,
        mcmc_parameters={},
        mcmc_parameter_update_fn=lambda state, info: (
            particles_means(state.particles),
            particles_stds(state.particles),
        ),
        initial_parameter_value=(jnp.zeros(dimensions), jnp.ones(dimensions) * 2),
        target_ess=target_ess,
        num_mcmc_steps=num_mcmc_steps,
    )

    return kernel_tuned_proposal, tuned_irmh_loop
```
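
A note on the contract between the pieces above (inferred from this example rather than stated as a general blackjax requirement): `mcmc_parameter_update_fn` must return a value with the same structure as `initial_parameter_value`, since `kernel_factory` unpacks it into the proposal means and per-dimension standard deviations. A quick illustration with stand-in particles:

```python
# Stand-in for state.particles at the end of a tempering step (hypothetical data).
fake_particles = jax.random.normal(jax.random.PRNGKey(1), (n_particles, 3))

# Same (means, stds) structure as initial_parameter_value above.
tuned_parameters = (particles_means(fake_particles), particles_stds(fake_particles))
```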

## IRMH tuning the covariance matrix


In this case, not only the diagonal but the full covariance matrix is fitted from the resulting particles.

```python
def irmh_full_cov_experiment(dimensions, target_ess, num_mcmc_steps):
    def factory(normal_proposal_parameters):
        means, cov = normal_proposal_parameters
        proposal_distribution = lambda key: jax.random.multivariate_normal(key, means, cov)

        def proposal_logdensity_fn(proposal, state):
            return jnp.log(
                jax.scipy.stats.multivariate_normal.pdf(state.position, mean=means, cov=cov)
            )

        return functools.partial(
            irmh.build_kernel(),
            proposal_distribution=proposal_distribution,
            proposal_logdensity_fn=proposal_logdensity_fn,
        )

    def mcmc_parameter_update_fn(state, info):
        covariance = jnp.atleast_2d(particles_covariance_matrix(state.particles))
        return particles_means(state.particles), covariance

    kernel_tuned_proposal = inner_kernel_tuning(
        logprior_fn=prior_log_prob,
        loglikelihood_fn=loglikelihood,
        mcmc_factory=factory,
        mcmc_init_fn=irmh.init,
        resampling_fn=resampling.systematic,
        smc_algorithm=adaptive_tempered_smc,
        mcmc_parameters={},
        mcmc_parameter_update_fn=mcmc_parameter_update_fn,
        initial_parameter_value=(jnp.zeros(dimensions), jnp.eye(dimensions) * 2),
        target_ess=target_ess,
        num_mcmc_steps=num_mcmc_steps,
    )

    return kernel_tuned_proposal, tuned_irmh_loop
```

```python
def smc_run_experiment(runnable, target_ess, num_mcmc_steps, dimen):
    key = jax.random.PRNGKey(124345)
    key, initial_particles_key, iterations_key = jax.random.split(key, 3)
    initial_particles = initial_particles_multivariate_normal(
        dimen, initial_particles_key, n_particles
    )
    kernel, inference_loop = runnable(dimen, target_ess, num_mcmc_steps)
    _, particles = inference_loop(kernel.step, iterations_key, kernel.init(initial_particles))
    return particles
```

```python
dimensions_to_try = [1, 10, 20, 30, 40, 50, 60]
```

```python
experiments = []
dimensions = []
particles = []
for dims in dimensions_to_try:
    for exp_id, experiment in (
        ("irmh", irmh_experiment),
        ("tune_diag", tuned_irmh_experiment),
        ("tune_full_cov", irmh_full_cov_experiment),
    ):
        experiment_particles = smc_run_experiment(experiment, 0.5, 50, dims)
        experiments.append(exp_id)
        dimensions.append(dims)
        particles.append(experiment_particles)
```

```python
results = pd.DataFrame(
    {"experiment": experiments, "dimensions": dimensions, "particles": particles}
)
```

```python
linspace = jnp.linspace(-2, 2, 5000).reshape(-1, 1).squeeze()


def plot(post, sampler, dimensions, ax):
    dimensions = post.shape[1]
    for dim in range(dimensions):
        az.plot_kde(post[:, dim], ax=ax)
    _ = ax.plot(linspace, density()[-1], c='red')
```

```python
rows = 7
cols = 3
samplers = ['irmh', 'tune_diag', 'tune_full_cov']
fig, axs = plt.subplots(rows, cols, figsize=(50, 30))

plt.rcParams.update({'font.size': 22})

for ax, lab in zip(axs[:, 0], dimensions_to_try):
    ax.set(ylabel=f"Dimensions = {lab}")

for ax, lab in zip(axs[0, :], samplers):
    ax.set(title=lab)

for col, experiment in enumerate(samplers):
    for row, dimension in enumerate(dimensions_to_try):
        particles = results[(results.experiment == experiment) & (results.dimensions == dimension)].iloc[0].particles
        plot(particles, experiment, dimension, axs[row, col])

fig.tight_layout()
fig.suptitle("""Sampler comparison for increasing number of posterior dimensions.
Each plot displays all dimensions from the posterior, overlayed. The red curve is the actual marginal distribution.""")
plt.show()
```

As the figure shows, performance degrades as the number of dimensions increases, but the more the inner kernel is tuned, the less the performance degrades: tuning the full covariance matrix holds up better than tuning only the diagonal, which in turn outperforms the fixed proposal.
