From 2381759da93d4aaf1ad38bdacc539707b98a988e Mon Sep 17 00:00:00 2001
From: =
Date: Tue, 13 Aug 2024 14:36:34 -0400
Subject: [PATCH 1/3] update run_inference_algorithm

---
 book/algorithms/mclmc.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/book/algorithms/mclmc.md b/book/algorithms/mclmc.md
index 499e3c49..65e8bee8 100644
--- a/book/algorithms/mclmc.md
+++ b/book/algorithms/mclmc.md
@@ -101,7 +101,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform):
    )

    # run the sampler
-    _, samples, _ = blackjax.util.run_inference_algorithm(
+    _, samples = blackjax.util.run_inference_algorithm(
        rng_key=run_key,
        initial_state=blackjax_state_after_tuning,
        inference_algorithm=sampling_alg,
@@ -122,7 +122,7 @@ samples = run_mclmc(
    num_steps=1000,
    initial_position=jnp.ones((1000,)),
    key=sample_key,
-    transform=lambda x: x.position[:2],
+    transform=lambda state, info: state.position[:2],
)
samples.mean()
```

From dedd675a27cd4d86abb089d740fa062bece13547 Mon Sep 17 00:00:00 2001
From: =
Date: Fri, 20 Sep 2024 14:50:36 -0400
Subject: [PATCH 2/3] fix stoch vol example

---
 book/algorithms/mclmc.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/book/algorithms/mclmc.md b/book/algorithms/mclmc.md
index 65e8bee8..1e40af6e 100644
--- a/book/algorithms/mclmc.md
+++ b/book/algorithms/mclmc.md
@@ -241,7 +241,7 @@ samples = run_mclmc(
    num_steps=10000,
    initial_position=prior_draw(key1),
    key=key2,
-    transform=lambda x: x,
+    transform=lambda state, info: state,
)

samples = transform(samples.position)

From f5f0f330277365b2ec82bbb242af1eef557dd253 Mon Sep 17 00:00:00 2001
From: =
Date: Mon, 23 Sep 2024 15:12:58 -0400
Subject: [PATCH 3/3] add tuning example

---
 book/algorithms/mclmc.md | 316 ++++++++++++++++++++++++---------------
 1 file changed, 195 insertions(+), 121 deletions(-)

diff --git a/book/algorithms/mclmc.md b/book/algorithms/mclmc.md
index 1e40af6e..10fcdc53 100644
--- a/book/algorithms/mclmc.md
+++ b/book/algorithms/mclmc.md
@@ -17,7 +17,7 @@ This is an algorithm based on https://arxiv.org/abs/2212.08549 ({cite:p}`robnik2022microcanonical`)

-The original derivation comes from thinking about the microcanonical ensemble (a concept from statistical mechanics), but the upshot is that we integrate the following SDE:
+The idea is that we have a distribution $p(x)$ from which we want to sample. We numerically integrate the following SDE; the samples we obtain converge (in the limit of many steps and small step size) to samples from the target distribution.

$$
\frac{d}{dt}\begin{bmatrix}
x \\
u \\
\end{bmatrix}
=
\begin{bmatrix} u \\ -P(u)\,\nabla S(x)/(n-1) + \eta\, P(u)\, dW/dt \end{bmatrix}
$$

-where $u$ is an auxilliary variable, $S(x)$ is the negative log PDF of the distribution from which we are sampling and the last term describes spherically symmetric noise. After $u$ is marginalized out, this converges to the target PDF, $p(x) \propto e^{-S(x)}$.
+Here $x \in \mathbb{R}^n$ is the variable of interest (i.e. the variable of the target distribution $p$), $u \in \mathbb{S}^{n-1}$ is the momentum (i.e. $u$ lives in $\mathbb{R}^n$ but is constrained to have fixed norm), $S(x)$ is the negative log PDF of the distribution from which we are sampling, and $P(u) = I - uu^\top$ is the projection onto the subspace orthogonal to $u$. The term $\eta P(u)dW$ describes spherically symmetric noise on the $n-1$ sphere $\mathbb{S}^{n-1}$. After $u$ is marginalized out, this converges to the target PDF, $p(x) \propto e^{-S(x)}$.
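To make the dynamics concrete, here is a minimal sketch of a single naive Euler–Maruyama step of this SDE. It is for intuition only: `naive_mclmc_step`, `grad_S`, and `eta` are illustrative names rather than BlackJax API, and BlackJax itself uses a more accurate integrator (see `run_mclmc` below) with tuned parameters.

```{code-cell} ipython3
import jax
import jax.numpy as jnp


def naive_mclmc_step(rng_key, x, u, grad_S, step_size, eta):
    """One naive Euler-Maruyama step of the SDE above (illustration only)."""
    n = x.shape[0]
    # P(u)v = v - (u . v) u projects v onto the subspace orthogonal to u (|u| = 1)
    proj = lambda v: v - u * jnp.dot(u, v)
    # momentum update: projected gradient force plus projected spherical noise
    noise = jax.random.normal(rng_key, (n,))
    u_new = u - step_size * proj(grad_S(x)) / (n - 1) + eta * jnp.sqrt(step_size) * proj(noise)
    u_new = u_new / jnp.linalg.norm(u_new)  # renormalize: the exact flow keeps u on the sphere
    # position update along the new direction
    return x + step_size * u_new, u_new


# one step on a standard Gaussian target, where S(x) = |x|^2 / 2
grad_S = jax.grad(lambda x: 0.5 * jnp.sum(jnp.square(x)))
x, u = jnp.ones(10), jnp.eye(10)[0]
x, u = naive_mclmc_step(jax.random.key(0), x, u, grad_S, step_size=0.1, eta=0.1)
```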
## How to run MCLMC in BlackJax

-It is very important to use the tuning algorithm provided, which controls the step size of the integrator and also $L$, a parameter related to $\eta$ above.
+MCLMC has two parameters:

-An example is given below, of a 1000 dim Gaussian (of which 2 dimensions are plotted).
+* The typical momentum decoherence scale $L$. This adds some noise to the direction of the velocity after every step: $L = \infty$ means no noise, while $L = 0$ means a full refreshment after every step.
+* The step size $\epsilon$ of the discretization of the dynamics. While the continuous dynamics converge exactly to the target distribution, the discrete dynamics inject bias into the resulting distribution. We therefore want the best tradeoff: $\epsilon$ small enough that the bias is negligible, but large enough for computational efficiency.
+
+MCLMC in BlackJax comes with a tuning algorithm that attempts to find optimal values for both of these parameters; using it is essential for good performance.
+
+An example of tuning and then running a chain is given below, for a 1000-dimensional Gaussian target (a 2-dimensional marginal of the samples is plotted):

```{code-cell} ipython3
:tags: [hide-cell]

import matplotlib.pyplot as plt

plt.rcParams["axes.spines.right"] = False
plt.rcParams["axes.spines.top"] = False
plt.rcParams["font.size"] = 19
-```
-
-```{code-cell} ipython3
-:tags: [remove-output]

import jax
-
-from datetime import date
-
-rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
-```
-
-```{code-cell} ipython3
import blackjax
import numpy as np
import jax.numpy as jnp
+from datetime import date
+import numpyro
+import numpyro.distributions as dist
+
+from numpyro.infer.util import initialize_model
+
+rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

```{code-cell} ipython3
-def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform):
+def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance=5e-4):
    init_key, tune_key, run_key = jax.random.split(key, 3)

    # create an initial state for the sampler
        state=initial_state,
        rng_key=tune_key,
        diagonal_preconditioning=False,
+        desired_energy_var=desired_energy_variance,
    )

    # use the quick wrapper to build a new kernel with the tuned parameters
        progress_bar=True,
    )

-    return samples
+    return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
```

```{code-cell} ipython3
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
+logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
+num_steps = 10000
+transform = lambda state, info: state.position[:2]
+
sample_key, rng_key = jax.random.split(rng_key)
-samples = run_mclmc(
-    logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
-    num_steps=1000,
+samples, initial_state, params, chain_key = run_mclmc(
+    logdensity_fn=logdensity_fn,
+    num_steps=num_steps,
    initial_position=jnp.ones((1000,)),
    key=sample_key,
-    transform=lambda state, info: state.position[:2],
+    transform=transform,
)
samples.mean()
```

```{code-cell} ipython3
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

-# Second example: Stochastic Volatility
+# How to analyze the results of your MCLMC run

-This is ported from Jakob Robnik's [example 
notebook](https://github.com/JakobRobnik/MicroCanonicalHMC/blob/master/notebooks/tutorials/advanced_tutorial.ipynb)

## Validate the choice of $\epsilon$

A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, let's first inspect the marginal along the first dimension:

```{code-cell} ipython3
def visualize_results_gauss(samples, label, color):
    x1 = samples[:, 0]
    plt.hist(x1, bins=30, density=True, histtype='step', lw=4, color=color, label=label)


def ground_truth_gauss():
    # ground truth: exact marginal density of a standard Gaussian
    t = np.linspace(-4, 4, 200)
    plt.plot(t, np.exp(-0.5 * np.square(t)) / np.sqrt(2 * np.pi), color='black', label='exact')
    plt.xlabel(r'$x_1$')
    plt.ylabel(r'$p(x_1)$')
    plt.legend()
    plt.show()


visualize_results_gauss(samples, 'MCLMC', 'teal')
ground_truth_gauss()
```

```{code-cell} ipython3
new_params = params._replace(step_size=params.step_size / 2)
new_num_steps = num_steps * 2
```

```{code-cell} ipython3
sampling_alg = blackjax.mclmc(
    logdensity_fn,
    L=new_params.L,
    step_size=new_params.step_size,
)

# run the sampler again, with the halved step size and twice as many steps
_, new_samples = blackjax.util.run_inference_algorithm(
    rng_key=chain_key,
    initial_state=initial_state,
    inference_algorithm=sampling_alg,
    num_steps=new_num_steps,
    transform=transform,
    progress_bar=True,
)

visualize_results_gauss(new_samples, 'MCLMC (stepsize/2)', 'red')
visualize_results_gauss(samples, 'MCLMC', 'teal')
ground_truth_gauss()
```

So here the original $\epsilon$ seems fine: halving it leaves the marginal essentially unchanged.

## An example where $\epsilon$ is too large

+++

We now consider a more complex model of stock volatility.

The returns $r_n$ are modeled by a Student's-t distribution whose scale (volatility) $R_n$ is time-varying and unknown. The prior for $\log R_n$ is a Gaussian random walk, with an exponential distribution of the random-walk step size $\sigma$. An exponential prior is also taken for the Student's-t degrees of freedom $\nu$. The generative process of the data is:

\begin{align}
    &r_n / R_n \sim \text{Student's-t}(\nu) \qquad
    &&\nu \sim \text{Exp}(\lambda = 1/10) \\ \nonumber
    &\log R_n \sim \mathcal{N}(\log R_{n-1}, \sigma) \qquad
    &&\sigma \sim \text{Exp}(\lambda = 1/0.02).
\end{align}

Our task is to find the posterior of the parameters $\{R_n\}_{n=1}^N$, $\sigma$ and $\nu$, given the observed data $\{r_n\}_{n=1}^N$. 
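As a quick intuition check before fitting, one can simulate synthetic returns directly from this generative process. The following is a hedged sketch — `simulate_prior_returns` is a hypothetical helper for illustration, not part of the pipeline below:

```{code-cell} ipython3
def simulate_prior_returns(key, N=1000, sigma_mean=0.02, nu_mean=10.0):
    # one prior draw of (sigma, nu, log R_1..N), then returns r_n = R_n * t_nu
    k_sigma, k_nu, k_walk, k_t = jax.random.split(key, 4)
    sigma = sigma_mean * jax.random.exponential(k_sigma)  # sigma ~ Exp(1/0.02)
    nu = nu_mean * jax.random.exponential(k_nu)           # nu ~ Exp(1/10)
    log_R = jnp.cumsum(sigma * jax.random.normal(k_walk, (N,)))  # Gaussian random walk
    return jnp.exp(log_R) * jax.random.t(k_t, df=nu, shape=(N,))


r_synthetic = simulate_prior_returns(jax.random.key(0))
```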
First, we get the data, define a model using NumPyro, and draw samples:

```{code-cell} ipython3
import matplotlib.dates as mdates

from numpyro.examples.datasets import SP500, load_dataset
from numpyro.distributions import StudentT

# get the data
_, fetch = load_dataset(SP500, shuffle=False)
SP500_dates, SP500_returns = fetch()
+dates = mdates.num2date(mdates.datestr2num(SP500_dates))

-# figure setup
-_, ax = plt.subplots(figsize=(12, 5))
-ax.spines["right"].set_visible(False)  # remove the upper and the right axis lines
-ax.spines["top"].set_visible(False)
-ax.xaxis.set_major_locator(mdates.YearLocator())  # dates on the xaxis
-ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
-ax.xaxis.set_minor_locator(mdates.MonthLocator())
-
-# plot data
-dates = mdates.num2date(mdates.datestr2num(SP500_dates))
-ax.plot(dates, SP500_returns, ".", markersize=3, color="steelblue")
-ax.set_xlabel("time")
-ax.set_ylabel("S&P500 returns")
+def setup():
+    # figure setup
+    plt.figure(figsize=(12, 5))
+    ax = plt.subplot()
+    ax.spines['right'].set_visible(False)  # remove the upper and the right axis lines
+    ax.spines['top'].set_visible(False)
+
+    ax.xaxis.set_major_locator(mdates.YearLocator())  # dates on the xaxis
+    ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
+    ax.xaxis.set_minor_locator(mdates.MonthLocator())
+
+    # plot data
+    plt.plot(dates, SP500_returns, '.', markersize=3, color='steelblue', label='data')
+    plt.xlabel('time')
+    plt.ylabel('S&P500 returns')
+
+
+setup()
```

```{code-cell} ipython3
-dim = 2429
-
-lambda_sigma, lambda_nu = 50, 0.1
-
-
-def logp_volatility(x):
-    """log p of the target distribution"""
-
-    sigma = (
-        jnp.exp(x[-2]) / lambda_sigma
-    )  # we used log-transformation to make x unconstrained
-    nu = jnp.exp(x[-1]) / lambda_nu
-    prior2 = (jnp.exp(x[-2]) - x[-2]) + (
-        jnp.exp(x[-1]) - x[-1]
-    )  # - log prior(sigma, nu)
-    prior1 = (dim - 2) * jnp.log(sigma) + 0.5 * (
-        jnp.square(x[0]) + jnp.sum(jnp.square(x[1:-2] - x[:-3]))
-    ) / jnp.square(
-        sigma
-    )  # - log prior(R)
-    lik = -jnp.sum(
-        StudentT(df=nu, scale=jnp.exp(x[:-2])).log_prob(SP500_returns)
-    )  # - log likelihood
-
-    return -(lik + prior1 + prior2)
+def from_numpyro(model, rng_key, model_args):
+    init_params, potential_fn_gen, *_ = initialize_model(
+        rng_key,
+        model,
+        model_args=model_args,
+        dynamic_args=True,
+    )
+    logdensity_fn = lambda position: -potential_fn_gen(*model_args)(position)
+    initial_position = init_params.z
+
+    return logdensity_fn, initial_position
+
+
+def stochastic_volatility(sigma_mean, nu_mean):
+    """numpyro model"""
+    sigma = numpyro.sample("sigma", dist.Exponential(1.0 / sigma_mean))
+    nu = numpyro.sample("nu", dist.Exponential(1.0 / nu_mean))
+    s = numpyro.sample("s", dist.GaussianRandomWalk(scale=sigma, num_steps=jnp.shape(SP500_returns)[0]))  # s = log R
+    numpyro.sample("r", dist.StudentT(df=nu, loc=0.0, scale=jnp.exp(s)), obs=SP500_returns)
+
+
+model_args = (0.02, 10.)
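+# Note (inserted as comments since we are mid-cell): model_args packs
+# (sigma_mean, nu_mean) = (0.02, 10.), the means of the exponential priors
+# Exp(lambda = 1/0.02) and Exp(lambda = 1/10) used in the model above.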
-def transform(x):
-    """transform x back to the parameters R, sigma and nu (taking the exponent)"""
-
-    Rn = jnp.exp(x[:-2])
-    sigma = jnp.exp(x[-2]) / lambda_sigma
-    nu = jnp.exp(x[-1]) / lambda_nu
-
-    return jnp.concatenate((Rn, jnp.array([sigma, nu])))
-
-
-def prior_draw(key):
-    """draws x from the prior"""
-
-    key_walk, key_exp1, key_exp2 = jax.random.split(key, 3)
-
-    sigma = (
-        jax.random.exponential(key_exp1) / lambda_sigma
-    )  # sigma is drawn from the exponential distribution
-
-    def step(track, useless):  # one step of the gaussian random walk
-        randkey, subkey = jax.random.split(track[1])
-        x = (
-            jax.random.normal(subkey, shape=track[0].shape, dtype=track[0].dtype)
-            + track[0]
-        )
-        return (x, randkey), x
-
-    x = jnp.empty(dim)
-    x = x.at[:-2].set(
-        jax.lax.scan(step, init=(0.0, key_walk), xs=None, length=dim - 2)[1] * sigma
-    )  # = log R_n are drawn as a Gaussian random walk realization
-    x = x.at[-2].set(
-        jnp.log(sigma * lambda_sigma)
-    )  # sigma ~ exponential distribution(lambda_sigma)
-    x = x.at[-1].set(
-        jnp.log(jax.random.exponential(key_exp2))
-    )  # nu ~ exponential distribution(lambda_nu)
-
-    return x
+rng_key = jax.random.key(42)
+
+logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
```

```{code-cell} ipython3
-key1, key2, rng_key = jax.random.split(rng_key, 3)
-samples = run_mclmc(
-    logdensity_fn=logp_volatility,
-    num_steps=10000,
-    initial_position=prior_draw(key1),
-    key=key2,
-    transform=lambda state, info: state,
-)
-
-samples = transform(samples.position)
+num_steps = 20000
+
+sample_key, rng_key = jax.random.split(rng_key)  # fresh key: sample_key was already consumed above
+samples, initial_state, params, chain_key = run_mclmc(
+    logdensity_fn=logp_sv,
+    num_steps=num_steps,
+    initial_position=x_init,
+    key=sample_key,
+    transform=lambda state, info: state.position,
+)
```

```{code-cell} ipython3
-R = np.array(samples)[:, :-2]  # remove sigma and nu parameters
-R = np.sort(R, axis=0)  # sort samples for each R_n
-num_samples = len(R)
-lower_quartile, median, upper_quartile = (
-    R[num_samples // 4, :],
-    R[num_samples // 2, :],
-    R[3 * num_samples // 4, :],
-)
-
-# figure setup
-_, ax = plt.subplots(figsize=(12, 5))
-ax.spines["right"].set_visible(False)  # remove the upper and the right axis lines
-ax.spines["top"].set_visible(False)
-
-ax.xaxis.set_major_locator(mdates.YearLocator())  # dates on the xaxis
-ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y"))
-ax.xaxis.set_minor_locator(mdates.MonthLocator())
-
-# plot data
-ax.plot(dates, SP500_returns, ".", markersize=3, color="steelblue")
-ax.plot(
-    [], [], ".", markersize=10, color="steelblue", alpha=0.5, label="data"
-)  # larger markersize for the legend
-ax.set_xlabel("time")
-ax.set_ylabel("S&P500 returns")
-
-# plot posterior
-ax.plot(dates, median, color="navy", label="volatility posterior")
-ax.fill_between(dates, lower_quartile, upper_quartile, color="navy", alpha=0.5)
-
-ax.legend()
+def visualize_results_sv(samples, color, label):
+    R = np.exp(np.array(samples['s']))  # exponentiate s = log R to recover the volatility R
+    lower_quantile, median, upper_quantile = np.quantile(R, [0.25, 0.5, 0.75], axis=0)
+
+    # plot posterior
+    plt.plot(dates, median, color=color, label=label)
+    plt.fill_between(dates, lower_quantile, upper_quantile, color=color, alpha=0.5)
+
+
+setup()
+visualize_results_sv(samples, color='navy', label='volatility posterior')
+plt.legend()
+plt.show()
```

```{code-cell} ipython3
new_params = params._replace(step_size=params.step_size / 2)
new_num_steps = num_steps * 2

sampling_alg = blackjax.mclmc(
    logp_sv,
    L=new_params.L,
    step_size=new_params.step_size,
)

# run the sampler with the halved step size
_, new_samples = blackjax.util.run_inference_algorithm(
    rng_key=chain_key,
    initial_state=initial_state,
    inference_algorithm=sampling_alg,
    num_steps=new_num_steps,
    transform=lambda state, info: state.position,
    progress_bar=True,
)
```

```{code-cell} ipython3
setup()
visualize_results_sv(samples, 'teal', 'MCLMC')
visualize_results_sv(new_samples, 'red', 'MCLMC (stepsize/2)')

plt.legend()
plt.show()
```

This looks OK, but if we 
inspect the marginal of the hierarchical parameter $\sigma$ on its own, we see a noticeable difference.

```{code-cell} ipython3
def visualize_results_sv_marginal(samples, color, label):
    plt.hist(samples['sigma'], bins=20, histtype='step', lw=4, density=True, color=color, label=label)
    plt.xlabel(r'$\sigma$')
    plt.ylabel(r'$p(\sigma \vert \mathrm{data})$')


plt.figure(figsize=(10, 4))
visualize_results_sv_marginal(samples, color='teal', label='MCLMC')
visualize_results_sv_marginal(new_samples, color='red', label='MCLMC (stepsize/2)')
plt.legend()
plt.show()
```

In this case, we should reduce the step size further, until the difference between the two estimates disappears.

+++

```{bibliography}
:filter: docname in docnames
```