diff --git a/docs/source/mcmc.rst b/docs/source/mcmc.rst index edfda92d2..f89d7f04c 100644 --- a/docs/source/mcmc.rst +++ b/docs/source/mcmc.rst @@ -9,7 +9,9 @@ We provide a high-level overview of the MCMC algorithms in NumPyro: * `BarkerMH `_ is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables. * `HMCGibbs `_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user. * `DiscreteHMCGibbs `_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically. -* `SA `_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. +* `SA `_ is a gradient-free MCMC method. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast. +* `AIES `_ is a gradient-free ensemble MCMC method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities, and can be robust to likelihood-free models. AIES generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). +* `ESS `_ is a gradient-free ensemble MCMC method that shares information between chains to find good slice sampling directions. It tends to be more sample efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate and may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be twice as large as the number of latent parameters, (and ideally larger). Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions `_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example `_. @@ -101,6 +103,30 @@ SA :show-inheritance: :member-order: bysource +EnsembleSampler +^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.infer.ensemble.EnsembleSampler + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +AIES +^^^^ +.. autoclass:: numpyro.infer.ensemble.AIES + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + +ESS +^^^ +.. autoclass:: numpyro.infer.ensemble.ESS + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + .. autofunction:: numpyro.infer.hmc.hmc .. autofunction:: numpyro.infer.hmc.hmc.init_kernel @@ -117,6 +143,12 @@ SA .. autodata:: numpyro.infer.sa.SAState +.. autodata:: numpyro.infer.ensemble.EnsembleSamplerState + +.. 
autodata:: numpyro.infer.ensemble.AIESState + +.. autodata:: numpyro.infer.ensemble.ESSState + TensorFlow Kernels ------------------ diff --git a/numpyro/infer/__init__.py b/numpyro/infer/__init__.py index d9e0e337f..9abf96fa2 100644 --- a/numpyro/infer/__init__.py +++ b/numpyro/infer/__init__.py @@ -10,6 +10,7 @@ TraceGraph_ELBO, TraceMeanField_ELBO, ) +from numpyro.infer.ensemble import AIES, ESS from numpyro.infer.hmc import HMC, NUTS from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs from numpyro.infer.initialization import ( @@ -29,6 +30,7 @@ from . import autoguide, reparam __all__ = [ + "AIES", "autoguide", "init_to_feasible", "init_to_mean", @@ -41,6 +43,7 @@ "BarkerMH", "DiscreteHMCGibbs", "ELBO", + "ESS", "HMC", "HMCECS", "HMCGibbs", diff --git a/numpyro/infer/ensemble.py b/numpyro/infer/ensemble.py new file mode 100644 index 000000000..5d7a9c119 --- /dev/null +++ b/numpyro/infer/ensemble.py @@ -0,0 +1,790 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from collections import namedtuple +import warnings + +import jax +from jax import random, vmap +import jax.numpy as jnp +from jax.scipy.stats import gaussian_kde + +import numpyro.distributions as dist +from numpyro.infer.ensemble_util import batch_ravel_pytree, get_nondiagonal_indices +from numpyro.infer.initialization import init_to_uniform +from numpyro.infer.mcmc import MCMCKernel +from numpyro.infer.util import initialize_model +from numpyro.util import identity, is_prng_key + +EnsembleSamplerState = namedtuple( + "EnsembleSamplerState", ["z", "inner_state", "rng_key"] +) +""" +A :func:`~collections.namedtuple` consisting of the following fields: + + - **z** - Python collection representing values (unconstrained samples from + the posterior) at latent sites. + - **inner_state** - A namedtuple containing information needed to update half the ensemble. + - **rng_key** - random number generator seed used for generating proposals, etc. +""" + +AIESState = namedtuple("AIESState", ["i", "accept_prob", "mean_accept_prob", "rng_key"]) +""" +A :func:`~collections.namedtuple` consisting of the following fields. + + - **i** - iteration. + - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` + does not correspond to the proposal if it is rejected. + - **mean_accept_prob** - Mean acceptance probability until current iteration + during warmup adaptation or sampling (for diagnostics). + - **rng_key** - random number generator seed used for generating proposals, etc. +""" + +ESSState = namedtuple("ESSState", ["i", + "n_expansions", + "n_contractions", + "mu", + "rng_key" + ] + ) +""" +A :func:`~collections.namedtuple` used as an inner state for Ensemble Sampler. +This consists of the following fields: + + - **i** - iteration. + - **n_expansions** - number of expansions in the current batch. Used for tuning mu. + - **n_contractions** - number of contractions in the current batch. Used for tuning mu. + - **mu** - Scale factor. This is tuned if tune_mu=True. + - **rng_key** - random number generator seed used for generating proposals, etc. +""" + + +class EnsembleSampler(MCMCKernel, ABC): + """ + Abstract class for ensemble samplers. Each MCMC sample is divided into two sub-iterations + in which half of the ensemble is updated. + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. 
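# A minimal, self-contained sketch of the split-ensemble pattern described in the
# EnsembleSampler docstring above: one half of the chains ("active") is updated
# using proposals built from the other half (the complementary "inactive" set),
# then the roles are swapped. The helper names and the toy proposal here are
# illustrative only -- they are not part of the NumPyro API.
import jax.numpy as jnp
from jax import random


def toy_update(rng_key, active, inactive):
    # toy proposal: pull each active walker halfway toward a random complementary walker
    idx = random.randint(rng_key, (active.shape[0],), 0, inactive.shape[0])
    return 0.5 * (active + inactive[idx])


def one_ensemble_step(rng_key, z):
    # z has shape (n_chains, n_params); n_chains must be even
    half = z.shape[0] // 2
    key1, key2 = random.split(rng_key)
    first = toy_update(key1, z[:half], z[half:])  # update the first half from the second
    second = toy_update(key2, z[half:], first)    # then the second half from the updated first
    return jnp.concatenate([first, second])


z = random.normal(random.PRNGKey(0), (10, 3))  # 10 chains, 3 latent parameters
z_new = one_ensemble_step(random.PRNGKey(1), z)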
+ :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + :meth:`init` has the same type. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + """ + + def __init__(self, model=None, potential_fn=None, *, randomize_split, init_strategy): + if not (model is None) ^ (potential_fn is None): + raise ValueError("Only one of `model` or `potential_fn` must be specified.") + + self._model = model + self._potential_fn = potential_fn + self._batch_log_density = None + # unravel an (n_chains, n_params) Array into a pytree and + # evaluate the log density at each chain + + # --- other hyperparams go here + self._num_chains = None # must be an even number >= 2 + self._randomize_split = randomize_split + # --- + + self._init_strategy = init_strategy + self._postprocess_fn = None + + @property + def model(self): + return self._model + + @property + def sample_field(self): + return "z" + + @property + def is_ensemble_kernel(self): + return True + + @abstractmethod + def init_inner_state(self, rng_key): + """return inner_state""" + raise NotImplementedError + + @abstractmethod + def update_active_chains(self, active, inactive, inner_state): + """return (updated active set of chains, updated inner state)""" + raise NotImplementedError + + def _init_state(self, rng_key, model_args, model_kwargs, init_params): + if self._model is not None: + new_params_info, potential_fn_gen, self._postprocess_fn, _ = initialize_model( + rng_key, + self._model, + dynamic_args=True, + init_strategy=self._init_strategy, + model_args=model_args, + model_kwargs=model_kwargs, + validate_grad=False, + ) + new_init_params = new_params_info[0] + self._potential_fn = potential_fn_gen(*model_args, **model_kwargs) + + if init_params is None: + init_params = new_init_params + + flat_params, unravel_fn = batch_ravel_pytree(init_params) + self._batch_log_density = lambda z: -vmap(self._potential_fn)(unravel_fn(z)) + + if self._num_chains < 2 * flat_params.shape[1]: + warnings.warn("Setting n_chains to at least 2*n_params is strongly recommended.\n" + f"n_chains: {self._num_chains}, n_params: {flat_params.shape[1]}") + + return init_params + + def init( + self, rng_key, num_warmup, init_params=None, model_args=(), model_kwargs={} + ): + assert not is_prng_key( + rng_key + ), ("EnsembleSampler only supports chain_method='vectorized' with num_chains > 1.\n" + "If you want to run chains in parallel, please raise a github issue.") + + assert rng_key.shape[0] % 2 == 0, "Number of chains must be even." + + self._num_chains = rng_key.shape[0] + + if self._potential_fn and init_params is None: + raise ValueError( + "Valid value of `init_params` must be provided with `potential_fn`." 
+ ) + if init_params is not None: + assert all([param.shape[0] == self._num_chains + for param in jax.tree_util.tree_leaves(init_params)]), ( + "The batch dimension of each param must match n_chains") + + rng_key, rng_key_inner_state, rng_key_init_model = random.split(rng_key[0], 3) + rng_key_init_model = random.split(rng_key_init_model, self._num_chains) + + init_params = self._init_state( + rng_key_init_model, model_args, model_kwargs, init_params + ) + + self._num_warmup = num_warmup + + return EnsembleSamplerState( + init_params, self.init_inner_state(rng_key_inner_state), rng_key + ) + + def postprocess_fn(self, args, kwargs): + if self._postprocess_fn is None: + return identity + return self._postprocess_fn(*args, **kwargs) + + def sample(self, state, model_args, model_kwargs): + z, inner_state, rng_key = state + rng_key, _ = random.split(rng_key) + z_flat, unravel_fn = batch_ravel_pytree(z) + + if self._randomize_split: + z_flat = random.permutation(rng_key, z_flat, axis=0) + + split_ind = self._num_chains // 2 + + def body_fn(i, z_flat_inner_state): + z_flat, inner_state = z_flat_inner_state + + active, inactive = jax.lax.cond(i == 0, + lambda x: (x[:split_ind], x[split_ind:]), + lambda x: (x[split_ind:], x[split_ind:]), + z_flat) + + z_updates, inner_state = self.update_active_chains(active, inactive, inner_state) + + z_flat = jax.lax.cond(i == 0, + lambda x: x.at[:split_ind].set(z_updates), + lambda x: x.at[split_ind:].set(z_updates), + z_flat) + return (z_flat, inner_state) + + z_flat, inner_state = jax.lax.fori_loop(0, 2, body_fn, (z_flat, inner_state)) + + + return EnsembleSamplerState(unravel_fn(z_flat), inner_state, rng_key) + + +class AIES(EnsembleSampler): + """ + Affine-Invariant Ensemble Sampling: a gradient free method that informs Metropolis-Hastings + proposals by sharing information between chains. Suitable for low to moderate dimensional models. + Generally, `num_chains` should be at least twice the dimensionality of the model. + + .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` + in :class:`MCMC`. The number of chains must be divisible by 2. + + **References:** + + 1. *emcee: The MCMC Hammer* (https://iopscience.iop.org/article/10.1086/670067), + Daniel Foreman-Mackey, David W. Hogg, Dustin Lang, and Jonathan Goodman. + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + :meth:`init` has the same type. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + Defaults to False. + :param moves: a dictionary mapping moves to their respective probabilities of being selected. + Valid keys are `AIES.DEMove()` and `AIES.StretchMove()`. Both tend to work well in practice. + If the sum of probabilites exceeds 1, the probabilities will be normalized. Defaults to `{AIES.DEMove(): 1.0}`. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + + **Example** + + .. doctest:: + + >>> import jax + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.infer import MCMC, AIES + + >>> def model(): + ... x = numpyro.sample("x", dist.Normal().expand([10])) + ... 
numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) + >>> + >>> kernel = AIES(model, moves={AIES.DEMove() : 0.5, + ... AIES.StretchMove() : 0.5}) + >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') + >>> mcmc.run(jax.random.PRNGKey(0)) + """ + + def __init__(self, model=None, potential_fn=None, randomize_split=False, moves=None, init_strategy=init_to_uniform): + if not moves: + self._moves = [AIES.DEMove()] + self._weights = jnp.array([1.0]) + else: + self._moves = list(moves.keys()) + self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + + assert all([hasattr(move, '__call__') for move in self._moves]), ( + "Each move must be a callable (one of AIES.DEMove(), or AIES.StretchMove()).") + assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" + + super().__init__(model, + potential_fn, + randomize_split=randomize_split, + init_strategy=init_strategy) + + def get_diagnostics_str(self, state): + return "acc. prob={:.2f}".format(state.inner_state.mean_accept_prob) + + def init_inner_state(self, rng_key): + # XXX hack -- we don't know num_chains until we init the inner state + self._moves = [move(self._num_chains) if move.__name__ == 'make_de_move' + else move for move in self._moves] + + return AIESState(jnp.array(0.0), jnp.array(0.0), jnp.array(0.0), rng_key) + + def update_active_chains(self, active, inactive, inner_state): + i, _, mean_accept_prob, rng_key = inner_state + rng_key, move_key, proposal_key, accept_key = random.split(rng_key, 4) + + move_i = random.choice(move_key, len(self._moves), p=self._weights) + proposal, factors = jax.lax.switch( + move_i, self._moves, proposal_key, active, inactive + ) + + # --- evaluate the proposal --- + log_accept_prob = ( + factors + + self._batch_log_density(proposal) + - self._batch_log_density(active) + ) + + accepted = dist.Uniform().sample(accept_key, (active.shape[0],)) < jnp.exp( + log_accept_prob + ) + updated_active_chains = jnp.where(accepted[:, jnp.newaxis], proposal, active) + + accept_prob = jnp.count_nonzero(accepted) / accepted.shape[0] + itr = i + 0.5 + n = jnp.where(i < self._num_warmup, itr, itr - self._num_warmup) + mean_accept_prob = mean_accept_prob + (accept_prob - mean_accept_prob) / n + + return updated_active_chains, AIESState( + itr, accept_prob, mean_accept_prob, rng_key + ) + + @staticmethod + def DEMove(sigma=1.0e-5, g0=None): + """ + A proposal using differential evolution. + + This `Differential evolution proposal + `_ is + implemented following `Nelson et al. (2013) + `_. + + :param sigma: (optional) + The standard deviation of the Gaussian used to stretch the proposal vector. + Defaults to `1.0.e-5`. + :param g0 (optional): + The mean stretch factor for the proposal vector. By default, + it is `2.38 / sqrt(2*ndim)` as recommended by the two references. + """ + def make_de_move(n_chains): + PAIRS = get_nondiagonal_indices(n_chains // 2) + + def de_move(rng_key, active, inactive): + pairs_key, gamma_key = random.split(rng_key) + n_active_chains, n_params = inactive.shape + + # XXX: if we pass in n_params to parent scope we don't need to + # recompute this each time + g = 2.38 / jnp.sqrt(2.0 * n_params) if not g0 else g0 + + selected_pairs = random.choice(pairs_key, PAIRS, shape=(n_active_chains,)) + + # Compute diff vectors + diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1) + + # Sample a gamma value for each walker following Nelson et al. 
(2013) + gamma = dist.Normal(g, g * sigma).sample( + gamma_key, sample_shape=(n_active_chains, 1) + ) + + # In this way, sigma is the standard deviation of the distribution of gamma, + # instead of the standard deviation of the distribution of the proposal as proposed by Ter Braak (2006). + # Otherwise, sigma should be tuned for each dimension, which confronts the idea of affine-invariance. + proposal = active + gamma * diffs + + return proposal, jnp.zeros(n_active_chains) + + return de_move + + return make_de_move + + @staticmethod + def StretchMove(a=2.0): + """ + A `Goodman & Weare (2010) + `_ "stretch move" with + parallelization as described in `Foreman-Mackey et al. (2013) + `_. + + :param a: (optional) + The stretch scale parameter. (default: ``2.0``) + """ + def stretch_move(rng_key, active, inactive): + n_active_chains, n_params = active.shape + unif_key, idx_key = random.split(rng_key) + + zz = ( + (a - 1.0) * random.uniform(unif_key, shape=(n_active_chains,)) + 1 + ) ** 2.0 / a + factors = (n_params - 1.0) * jnp.log(zz) + r_idxs = random.randint( + idx_key, shape=(n_active_chains,), minval=0, maxval=n_active_chains + ) + + proposal = inactive[r_idxs] - (inactive[r_idxs] - active) * zz[:, jnp.newaxis] + + return proposal, factors + + return stretch_move + + +class ESS(EnsembleSampler): + """ + Ensemble Slice Sampling: a gradient free method that finds better slice sampling directions + by sharing information between chains. Suitable for low to moderate dimensional models. + Generally, `num_chains` should be at least twice the dimensionality of the model. + + .. note:: This kernel must be used with `num_chains` > 1 and `chain_method="vectorized` + in :class:`MCMC`. The number of chains must be divisible by 2. + + **References:** + + 1. *zeus: a PYTHON implementation of ensemble slice sampling for efficient Bayesian parameter inference* (https://academic.oup.com/mnras/article/508/3/3589/6381726), + Minas Karamanis, Florian Beutler, and John A. Peacock. + 2. *Ensemble slice sampling* (https://link.springer.com/article/10.1007/s11222-021-10038-2), + Minas Karamanis, Florian Beutler. + + :param model: Python callable containing Pyro :mod:`~numpyro.primitives`. + If model is provided, `potential_fn` will be inferred using the model. + :param potential_fn: Python callable that computes the potential energy + given input parameters. The input parameters to `potential_fn` can be + any python collection type, provided that `init_params` argument to + :meth:`init` has the same type. + :param bool randomize_split: whether or not to permute the chain order at each iteration. + Defaults to True. + :param moves: a dictionary mapping moves to their respective probabilities of being selected. + If the sum of probabilites exceeds 1, the probabilities will be normalized. Valid keys include: + `ESS.DifferentialMove()` -> default proposal, works well along a wide range of target distributions, + `ESS.GaussianMove()` -> for approximately normally distributed targets, + `ESS.KDEMove()` -> for multimodal posteriors - requires large `num_chains`, and they must be well initialized + `ESS.RandomMove()` -> no chain interaction, useful for debugging. + Defaults to `{ESS.DifferentialMove(): 1.0}`. + + :param int max_steps: number of maximum stepping-out steps per sample. Defaults to 10,000. + :param int max_iter: number of maximum expansions/contractions per sample. Defaults to 10,000. + :param float init_mu: initial scale factor. Defaults to 1.0. 
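# A standalone sketch of the Goodman & Weare stretch proposal implemented by
# AIES.StretchMove above. The scale z is drawn from g(z) proportional to 1/sqrt(z)
# on [1/a, a] via the inverse-CDF expression used in `stretch_move`, and the
# returned log factor (n_params - 1) * log(z) enters the Metropolis-Hastings
# acceptance ratio. The function name `stretch_proposal` is illustrative only.
import jax.numpy as jnp
from jax import random


def stretch_proposal(rng_key, active, inactive, a=2.0):
    n_active, n_params = active.shape
    u_key, idx_key = random.split(rng_key)
    z = ((a - 1.0) * random.uniform(u_key, (n_active,)) + 1.0) ** 2 / a
    # pair each active walker X_k with a random complementary walker X_j
    j = random.randint(idx_key, (n_active,), 0, inactive.shape[0])
    # proposal Y = X_j + z * (X_k - X_j), applied per walker
    proposal = inactive[j] + z[:, jnp.newaxis] * (active - inactive[j])
    log_factor = (n_params - 1.0) * jnp.log(z)
    return proposal, log_factor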
+ :param bool tune_mu: whether or not to tune the initial scale factor. Defaults to True. + :param callable init_strategy: a per-site initialization function. + See :ref:`init_strategy` section for available functions. + + **Example** + + .. doctest:: + + >>> import jax + >>> import jax.numpy as jnp + >>> import numpyro + >>> import numpyro.distributions as dist + >>> from numpyro.infer import MCMC, ESS + + >>> def model(): + ... x = numpyro.sample("x", dist.Normal().expand([10])) + ... numpyro.sample("obs", dist.Normal(x, 1.0), obs=jnp.ones(10)) + >>> + >>> kernel = ESS(model, moves={ESS.DifferentialMove() : 0.8, + ... ESS.RandomMove() : 0.2}) + >>> mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=20, chain_method='vectorized') + >>> mcmc.run(jax.random.PRNGKey(0)) + """ + def __init__( + self, + model=None, + potential_fn=None, + randomize_split=True, + moves=None, + max_steps=10_000, + max_iter=10_000, + init_mu=1.0, + tune_mu=True, + init_strategy=init_to_uniform, + ): + if not moves: + self._moves = [ESS.DifferentialMove()] + self._weights = jnp.array([1.0]) + else: + self._moves = list(moves.keys()) + self._weights = jnp.array([weight for weight in moves.values()]) / len(moves) + + assert all([hasattr(move, '__call__') for move in self._moves]), ( + "Each move must be a callable (one of `ESS.DifferentialMove()`, " + "`ESS.GaussianMove()`, `ESS.KDEMove()`, `ESS.RandomMove()`)") + + assert jnp.all(self._weights >= 0), "Each specified move must have probability >= 0" + assert init_mu > 0, "Scale factor should be strictly positive" + + self._max_steps = max_steps # max number of stepping out steps + self._max_iter = max_iter # max number of expansions/contractions + self._init_mu = init_mu + self._tune_mu = tune_mu + + super().__init__(model, + potential_fn, + randomize_split=randomize_split, + init_strategy=init_strategy) + + def init_inner_state(self, rng_key): + self.batch_log_density = lambda x: self._batch_log_density(x)[:, jnp.newaxis] + + # XXX hack -- we don't know num_chains until we init the inner state + self._moves = [move(self._num_chains) if move.__name__ == 'make_differential_move' + else move for move in self._moves] + + return ESSState(jnp.array(0.0), jnp.array(0), jnp.array(0), self._init_mu, rng_key) + + def update_active_chains(self, active, inactive, inner_state): + i, n_expansions, n_contractions, mu, rng_key = inner_state + (rng_key, + move_key, + dir_key, + height_key, + step_out_key, + shrink_key) = random.split(rng_key, 6) + + n_active_chains, n_params = active.shape + + move_i = random.choice(move_key, len(self._moves), p=self._weights) + directions = jax.lax.switch(move_i, self._moves, dir_key, inactive, mu) + + log_slice_height = self.batch_log_density(active) - dist.Exponential().sample( + height_key, sample_shape=(n_active_chains, 1) + ) + + curr_n_expansions, L, R = self._step_out( + step_out_key, log_slice_height, active, directions + ) + proposal, curr_n_contractions = self._shrink( + shrink_key, log_slice_height, L, R, active, directions + ) + + n_expansions += curr_n_expansions + n_contractions += curr_n_contractions + itr = i + 0.5 + + if self._tune_mu: + safe_n_expansions = jnp.max(jnp.array([1, n_expansions])) + + # only update tuning scale if a full iteration has passed + mu, n_expansions, n_contractions = jax.lax.cond(jnp.all(itr % 1 == 0), + lambda n_exp, n_con: (2.0 * n_exp / (n_exp + n_con), + jnp.array(0), + jnp.array(0) + ), + lambda _, __: (mu, + n_expansions, + n_contractions + ), + safe_n_expansions, n_contractions) + + 
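# Tuning note: `itr` advances by 0.5 per half-ensemble update, so `itr % 1 == 0`
# holds only after a full iteration. At that point mu is reset to
# 2 * n_expansions / (n_expansions + n_contractions) (with n_expansions clamped
# to at least 1) and both counters are zeroed; on the intermediate half-step the
# previous mu and counts are carried through unchanged.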
return proposal, ESSState(itr, n_expansions, n_contractions, mu, rng_key) + + + @staticmethod + def RandomMove(): + """ + The `Karamanis & Beutler (2020) `_ "Random Move" with parallelization. + When this move is used the walkers move along random directions. There is no communication between the + walkers and this Move corresponds to the vanilla Slice Sampling method. This Move should be used for + debugging purposes only. + """ + def random_move(rng_key, inactive, mu): + directions = dist.Normal(loc=0, scale=1).sample( + rng_key, sample_shape=inactive.shape + ) + directions /= jnp.linalg.norm(directions, axis=0) + + return 2.0 * mu * directions + return random_move + + @staticmethod + def KDEMove(bw_method=None): + """ + The `Karamanis & Beutler (2020) `_ "KDE Move" with parallelization. + When this Move is used the distribution of the walkers of the complementary ensemble is traced using + a Gaussian Kernel Density Estimation methods. The walkers then move along random direction vectos + sampled from this distribution. + """ + def kde_move(rng_key, inactive, mu): + n_active_chains, n_params = inactive.shape + + kde = gaussian_kde(inactive.T, bw_method=bw_method) + + vectors = kde.resample(rng_key, (2 * n_active_chains,)).T + directions = vectors[:n_active_chains] - vectors[n_active_chains:] + + return 2.0 * mu * directions + return kde_move + + @staticmethod + def GaussianMove(): + """ + The `Karamanis & Beutler (2020) `_ "Gaussian Move" with parallelization. + When this Move is used the walkers move along directions defined by random vectors sampled from the Gaussian + approximation of the walkers of the complementary ensemble. + """ + + # In high dimensional regimes with sufficiently small n_active_chains, + # it is more efficient to sample without computing the Cholesky + # decomposition of the covariance matrix: + + # eps = dist.Normal(0, 1).sample(rng_key, (n_active_chains, n_active_chains)) + # return 2.0 * mu * (eps @ (inactive - jnp.mean(inactive, axis=0)) / jnp.sqrt(n_active_chains)) + + def gaussian_move(rng_key, inactive, mu): + n_active_chains, n_params = inactive.shape + cov = jnp.cov(inactive, rowvar=False) + + return ( + 2.0 + * mu + * dist.MultivariateNormal(0, cov).sample( + rng_key, sample_shape=(n_active_chains,) + ) + ) + return gaussian_move + + @staticmethod + def DifferentialMove(): + """ + The `Karamanis & Beutler (2020) `_ "Differential Move" with parallelization. + When this Move is used the walkers move along directions defined by random pairs of walkers sampled (with no + replacement) from the complementary ensemble. This is the default choice and performs well along a wide range + of target distributions. 
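# A sketch of the direction vectors produced by the DifferentialMove factory
# above: each active walker gets 2 * mu times the difference of a randomly
# selected ordered pair of distinct walkers from the complementary ensemble.
# The function name `differential_directions` is illustrative only.
import jax.numpy as jnp
from jax import random

from numpyro.infer.ensemble_util import get_nondiagonal_indices


def differential_directions(rng_key, inactive, mu=1.0):
    n_active, _ = inactive.shape
    pairs = get_nondiagonal_indices(n_active)                    # all ordered pairs (j, k) with j != k
    chosen = random.choice(rng_key, pairs, shape=(n_active,))    # one pair per active walker
    diffs = jnp.diff(inactive[chosen], axis=1).squeeze(axis=1)   # X_k - X_j for each selected pair
    return 2.0 * mu * diffs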
+ """ + def make_differential_move(n_chains): + PAIRS = get_nondiagonal_indices(n_chains // 2) + + def differential_move(rng_key, inactive, mu): + n_active_chains, n_params = inactive.shape + + selected_pairs = random.choice(rng_key, PAIRS, shape=(n_active_chains,)) + diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze( + axis=1 + ) # get the pairwise difference of each vector + + return 2.0 * mu * diffs + return differential_move + + return make_differential_move + + + def _step_out(self, rng_key, log_slice_height, active, directions): + init_L_key, init_J_key = random.split(rng_key) + n_active_chains, n_params = active.shape + + iteration = 0 + n_expansions = 0 + # set initial interval boundaries + L = -dist.Uniform().sample(init_L_key, sample_shape=(n_active_chains, 1)) + R = L + 1.0 + + # stepping out + J = jnp.floor( + dist.Uniform(low=0, high=self._max_steps).sample( + init_J_key, sample_shape=(n_active_chains, 1) + ) + ) + K = (self._max_steps - 1) - J + + # left stepping-out initialisation + mask_J = jnp.full((n_active_chains, 1), True) + # right stepping-out initialisation + mask_K = jnp.full((n_active_chains, 1), True) + + init_values = (n_expansions, L, R, J, K, mask_J, mask_K, iteration) + + def cond_fn(args): + n_expansions, L, R, J, K, mask_J, mask_K, iteration = args + + return (jnp.count_nonzero(mask_J) + jnp.count_nonzero(mask_K) > 0) & ( + iteration < self._max_iter + ) + + def body_fn(args): + n_expansions, L, R, J, K, mask_J, mask_K, iteration = args + + log_prob_L = self.batch_log_density(directions * L + active) + log_prob_R = self.batch_log_density(directions * R + active) + + can_expand_L = log_prob_L > log_slice_height + L = jnp.where(can_expand_L, L - 1, L) + J = jnp.where(can_expand_L, J - 1, J) + mask_J = jnp.where(can_expand_L, mask_J, False) + + can_expand_R = log_prob_R > log_slice_height + R = jnp.where(can_expand_R, R + 1, R) + K = jnp.where(can_expand_R, K - 1, K) + mask_K = jnp.where(can_expand_R, mask_K, False) + + iteration += 1 + n_expansions += jnp.count_nonzero(can_expand_L) + jnp.count_nonzero( + can_expand_R + ) + + return (n_expansions, L, R, J, K, mask_J, mask_K, iteration) + + n_expansions, L, R, J, K, mask_J, mask_K, iteration = jax.lax.while_loop( + cond_fn, body_fn, init_values + ) + + return n_expansions, L, R + + def _shrink(self, rng_key, log_slice_height, L, R, active, directions): + n_active_chains, n_params = active.shape + + iteration = 0 + n_contractions = 0 + widths = jnp.zeros((n_active_chains, 1)) + proposed = jnp.zeros((n_active_chains, n_params)) + can_shrink = jnp.full((n_active_chains, 1), True) + + init_values = ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) + + def cond_fn(args): + ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) = args + + return (jnp.count_nonzero(can_shrink) > 0) & (iteration < self._max_iter) + + def body_fn(args): + ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) = args + + rng_key, _ = random.split(rng_key) + + widths = jnp.where( + can_shrink, dist.Uniform(low=L, high=R).sample(rng_key), widths + ) + + # compute new positions + proposed = jnp.where(can_shrink, directions * widths + active, proposed) + proposed_log_prob = self.batch_log_density(proposed) + + # shrink slices + can_shrink = proposed_log_prob < log_slice_height + + L_cond = can_shrink & (widths < 0.0) + L = jnp.where(L_cond, widths, L) + + R_cond = can_shrink & (widths > 0.0) + R = 
jnp.where(R_cond, widths, R) + + iteration += 1 + n_contractions += jnp.count_nonzero(L_cond) + jnp.count_nonzero(R_cond) + + return ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) + + ( + rng_key, + proposed, + n_contractions, + L, + R, + widths, + can_shrink, + iteration, + ) = jax.lax.while_loop(cond_fn, body_fn, init_values) + + return proposed, n_contractions diff --git a/numpyro/infer/ensemble_util.py b/numpyro/infer/ensemble_util.py new file mode 100644 index 000000000..028d694a5 --- /dev/null +++ b/numpyro/infer/ensemble_util.py @@ -0,0 +1,47 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np + +import jax +from jax.flatten_util import ravel_pytree +import jax.numpy as jnp +from jax.tree_util import tree_map + + +def get_nondiagonal_indices(n): + """ + From https://github.com/dfm/emcee/blob/main/src/emcee/moves/de.py: + + Get the indices of a square matrix with size n, excluding the diagonal. + """ + rows, cols = np.tril_indices(n, -1) # -1 to exclude diagonal + + # Combine rows-cols and cols-rows pairs + pairs = np.column_stack([np.concatenate([rows, cols]), + np.concatenate([cols, rows])]) + + return jnp.asarray(pairs) + + +def batch_ravel_pytree(pytree): + """ + Ravel (flatten) a pytree of arrays with leading batch dimension down to a (batch_size, 1D) array. + + Args: + pytree: a pytree of arrays and scalars to ravel. + Returns: + A pair where the first element is a (batch_size, 1D) array representing the flattened and + concatenated leaf values, with dtype determined by promoting the dtypes of + leaf values, and the second element is a callable for unflattening a (batch_size, 1D) + array of the same length back to a pytree of the same structure as the + input ``pytree``. If the input pytree is empty (i.e. has no leaves) then as + a convention a 1D empty array of dtype float32 is returned in the first + component of the output. + """ + flat = jax.vmap(lambda x: ravel_pytree(x)[0])(pytree) + unravel_fn = jax.vmap(ravel_pytree(tree_map(lambda z: z[0], pytree))[1]) + + return flat, unravel_fn + + diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index b4ac1ec4a..8ffd38871 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -139,6 +139,16 @@ def default_fields(self): """ return (self.sample_field,) + @property + def is_ensemble_kernel(self): + """ + Denotes whether the kernel is an ensemble kernel. If True, + diagnostics_str will be displayed during the MCMC run + (when :meth:`MCMC.run() ` is called) + if `chain_method` = "vectorized". + """ + return False + def get_diagnostics_str(self, state): """ Given the current `state`, returns the diagnostics string to @@ -424,7 +434,7 @@ def _single_chain_mcmc(self, init, args, kwargs, collect_fields): sample_fn, postprocess_fn = self._get_cached_fns() diagnostics = ( # noqa: E731 lambda x: self.sampler.get_diagnostics_str(x[0]) - if is_prng_key(rng_key) + if is_prng_key(rng_key) or self.sampler.is_ensemble_kernel else "" ) init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) diff --git a/test/infer/test_ensemble_mcmc.py b/test/infer/test_ensemble_mcmc.py new file mode 100644 index 000000000..7cc289d89 --- /dev/null +++ b/test/infer/test_ensemble_mcmc.py @@ -0,0 +1,61 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import pytest + +import jax.numpy as jnp +import jax.random as random + +import numpyro +import numpyro.distributions as dist +from numpyro.infer import AIES, ESS, MCMC + +numpyro.set_host_device_count(2) +# --- +# reused for all smoke-tests +N, dim = 3000, 3 + +data = random.normal(random.PRNGKey(0), (N, dim)) +true_coefs = jnp.arange(1.0, dim + 1.0) +logits = jnp.sum(true_coefs * data, axis=-1) +labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1)) + +def model(labels): + coefs = numpyro.sample("coefs", dist.Normal(jnp.zeros(dim), jnp.ones(dim))) + logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) + return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) +# --- + +@pytest.mark.parametrize("kernel_cls, n_chain, method", + [(AIES, 10, "sequential"), + (AIES, 1, "vectorized"), + (AIES, 2, "parallel"), + (ESS, 10, "sequential"), + (ESS, 1, "vectorized"), + (ESS, 2, "parallel")]) +def test_chain_smoke(kernel_cls, n_chain, method): + kernel = kernel_cls(model) + + mcmc = MCMC(kernel, num_warmup=10, num_samples=10, + progress_bar=False, num_chains=n_chain, chain_method=method) + + with pytest.raises(AssertionError, match="chain_method"): + mcmc.run(random.PRNGKey(2), labels) + +@pytest.mark.parametrize("kernel_cls", [AIES, ESS]) +def test_out_shape_smoke(kernel_cls): + n_chains = 10 + kernel = kernel_cls(model) + + mcmc = MCMC(kernel, num_warmup=10, num_samples=10, + progress_bar=False, num_chains=n_chains, chain_method='vectorized') + mcmc.run(random.PRNGKey(2), labels) + + assert (mcmc.get_samples(group_by_chain=True)['coefs'].shape[0] + == n_chains) + +@pytest.mark.parametrize("kernel_cls", [AIES, ESS]) +def test_invalid_moves(kernel_cls): + with pytest.raises(AssertionError, match="Each move"): + kernel_cls(model, moves={'invalid': 1.}) + diff --git a/test/infer/test_ensemble_util.py b/test/infer/test_ensemble_util.py new file mode 100644 index 000000000..5a066c69d --- /dev/null +++ b/test/infer/test_ensemble_util.py @@ -0,0 +1,35 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import jax +import jax.numpy as jnp + +from numpyro.infer.ensemble_util import batch_ravel_pytree, get_nondiagonal_indices + + +def test_nondiagonal_indices(): + truth = jnp.array( + [[1, 0], + [2, 0], + [2, 1], + [0, 1], + [0, 2], + [1, 2]], dtype=jnp.int32) + + assert jnp.all(get_nondiagonal_indices(3) == truth) + +def test_batch_ravel_pytree(): + arr1 = jnp.arange(10).reshape((5, 2)) + arr2 = jnp.arange(15).reshape((5, 3)) + arr3 = jnp.arange(20).reshape((5, 4)) + + tree = {'arr1': arr1, 'arr2': arr2, 'arr3': arr3} + + flattened, unravel_fn = batch_ravel_pytree(tree) + unflattened = unravel_fn(flattened) + + assert flattened.shape == (5, 2 + 3 + 4) + + for unflattened_leaf, original_leaf in zip(jax.tree_util.tree_leaves(unflattened), + jax.tree_util.tree_leaves(tree)): + assert jnp.all(unflattened_leaf == original_leaf) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 04a5bc64d..3f52c6c1e 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -17,7 +17,7 @@ import numpyro import numpyro.distributions as dist from numpyro.distributions.transforms import AffineTransform -from numpyro.infer import HMC, MCMC, NUTS, SA, BarkerMH +from numpyro.infer import AIES, ESS, HMC, MCMC, NUTS, SA, BarkerMH from numpyro.infer.hmc import hmc from numpyro.infer.reparam import TransformReparam from numpyro.infer.sa import _get_proposal_loc_and_scale, _numpy_delete @@ -25,7 +25,7 @@ from numpyro.util import fori_collect, is_prng_key -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) @pytest.mark.parametrize("dense_mass", [False, True]) def test_unnormalized_normal_x64(kernel_cls, dense_mass): true_mean, true_std = 1.0, 0.5 @@ -34,16 +34,30 @@ def test_unnormalized_normal_x64(kernel_cls, dense_mass): def potential_fn(z): return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2) - init_params = jnp.array(0.0) - if kernel_cls in [SA, BarkerMH]: + if kernel_cls in [AIES, ESS]: + num_chains = 10 + kernel = kernel_cls(potential_fn=potential_fn) + + init_params = random.normal(random.PRNGKey(1), (num_chains,)) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False, + num_chains=num_chains, chain_method='vectorized' + ) + elif kernel_cls in [SA, BarkerMH]: kernel = kernel_cls(potential_fn=potential_fn, dense_mass=dense_mass) + init_params = jnp.array(0.0) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) else: kernel = kernel_cls( potential_fn=potential_fn, trajectory_length=8, dense_mass=dense_mass ) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) + init_params = jnp.array(0.0) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) + mcmc.run(random.PRNGKey(0), init_params=init_params) mcmc.print_summary() hmc_states = mcmc.get_samples() @@ -83,15 +97,13 @@ def potential_fn(z): assert np.sum(np.abs(np.cov(samples.T) - true_cov)) / D**2 < 0.02 -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) def test_logistic_regression_x64(kernel_cls): + if kernel_cls in [AIES, ESS] and "CI" in os.environ: + pytest.skip("reduce time for CI.") + N, dim = 3000, 3 - if kernel_cls is SA: - num_warmup, num_samples = (100000, 100000) - elif kernel_cls is BarkerMH: - num_warmup, num_samples = (2000, 12000) - 
else: - num_warmup, num_samples = (1000, 8000) + data = random.normal(random.PRNGKey(0), (N, dim)) true_coefs = jnp.arange(1.0, dim + 1.0) logits = jnp.sum(true_coefs * data, axis=-1) @@ -102,17 +114,40 @@ def model(labels): logits = numpyro.deterministic("logits", jnp.sum(coefs * data, axis=-1)) return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels) - if kernel_cls is SA: + if kernel_cls in [AIES, ESS]: + if kernel_cls is AIES: + num_chains = 16 + else: + num_chains = 10 + samples_each_chain = 8000 + num_warmup, num_samples = (10_000, samples_each_chain * num_chains) + kernel = kernel_cls(model) + + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=samples_each_chain, + progress_bar=False, num_chains=num_chains, chain_method='vectorized' + ) + elif kernel_cls is SA: + num_warmup, num_samples = (100000, 100000) kernel = SA(model=model, adapt_state_size=9) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) elif kernel_cls is BarkerMH: + num_warmup, num_samples = (2000, 12000) kernel = BarkerMH(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) else: + num_warmup, num_samples = (1000, 8000) kernel = kernel_cls( model=model, trajectory_length=8, find_heuristic_step_size=True ) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) + mcmc.run(random.PRNGKey(2), labels) mcmc.print_summary() samples = mcmc.get_samples() @@ -185,10 +220,12 @@ def model(data): assert_allclose(jnp.mean(samples["loc"], 0), true_coef, atol=0.007) -@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH]) +@pytest.mark.parametrize("kernel_cls", [HMC, NUTS, SA, BarkerMH, AIES, ESS]) def test_beta_bernoulli_x64(kernel_cls): if kernel_cls is SA and "CI" in os.environ and "JAX_ENABLE_X64" in os.environ: pytest.skip("The test is flaky on CI x64.") + if kernel_cls is ESS and "CI" in os.environ: + pytest.skip("reduce time for CI.") num_warmup, num_samples = (100000, 100000) if kernel_cls is SA else (500, 20000) def model(data): @@ -200,15 +237,29 @@ def model(data): true_probs = jnp.array([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000,)) - if kernel_cls is SA: + + if kernel_cls in [AIES, ESS]: + num_chains = 10 + kernel = kernel_cls(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, + progress_bar=False, num_chains=num_chains, chain_method='vectorized' + ) + elif kernel_cls is SA: kernel = SA(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) elif kernel_cls is BarkerMH: kernel = BarkerMH(model=model) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) else: kernel = kernel_cls(model=model, trajectory_length=0.1) - mcmc = MCMC( - kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False - ) + mcmc = MCMC( + kernel, num_warmup=num_warmup, num_samples=num_samples, progress_bar=False + ) mcmc.run(random.PRNGKey(2), data) mcmc.print_summary() samples = mcmc.get_samples() diff --git a/test/test_distributions.py b/test/test_distributions.py index 714fe2871..f80f6bcf6 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1744,7 +1744,7 @@ def test_mean_var(jax_dist, sp_dist, params): sp_var = jnp.diag(d_sp.cov()) except TypeError: # mvn does not have .cov() 
method sp_var = jnp.diag(d_sp.cov) - except AttributeError: + except (AttributeError, ValueError): sp_var = d_sp.var() else: sp_var = d_sp.var()
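# A sketch of running AIES directly on a potential function, mirroring the
# test_unnormalized_normal_x64 changes above: `init_params` must carry a leading
# batch dimension equal to `num_chains`, and the kernel requires
# chain_method="vectorized". The target below (an unnormalized Normal(1.0, 0.5))
# and all settings are illustrative.
import jax.numpy as jnp
from jax import random

from numpyro.infer import AIES, MCMC


def potential_fn(z):
    return 0.5 * jnp.sum(((z - 1.0) / 0.5) ** 2)


num_chains = 10  # for a single scalar latent, this comfortably exceeds 2 * n_params
kernel = AIES(potential_fn=potential_fn)
init_params = random.normal(random.PRNGKey(1), (num_chains,))

mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000, num_chains=num_chains,
            chain_method="vectorized", progress_bar=False)
mcmc.run(random.PRNGKey(0), init_params=init_params)
samples = mcmc.get_samples()  # draws pooled across all chains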