diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index c4f99fa16..47fabe4e7 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -102,7 +102,7 @@ jobs:
       run: |
         pytest -vs --durations=20 test/infer/test_mcmc.py
         pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py
-        pytest -vs --durations=20 test/contrib
+        pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
     - name: Test x64
       run: |
         JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64
@@ -110,6 +110,7 @@ jobs:
       run: |
         XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
         XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
+        XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py
         XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
     - name: Test custom prng
       run: |
diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst
index a7b69db6a..7d225c850 100644
--- a/docs/source/contrib.rst
+++ b/docs/source/contrib.rst
@@ -74,3 +74,11 @@ SteinVI Kernels
 
 .. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel
 
+Stochastic Support
+~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: numpyro.contrib.stochastic_support.dcc.DCC
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
\ No newline at end of file
diff --git a/numpyro/contrib/stochastic_support/__init__.py b/numpyro/contrib/stochastic_support/__init__.py
new file mode 100644
index 000000000..6c1dc37f9
--- /dev/null
+++ b/numpyro/contrib/stochastic_support/__init__.py
@@ -0,0 +1,8 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from numpyro.contrib.stochastic_support.dcc import DCC
+
+__all__ = [
+    "DCC",
+]
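Before the implementation below, a note on the annotation contract it relies on: DCC keys each straight-line program (SLP) by the values taken at the sample sites marked with infer={"branching": True}, joined into a comma-separated string. A minimal sketch of a model with two annotated sites (hypothetical example, not part of this diff), whose discovered SLP keys would be "0,0", "0,1", "1,0", and "1,1":

    import numpyro
    import numpyro.distributions as dist

    def model():
        # Both sites are discrete and annotated with "branching", so DCC
        # treats each combination of their values as a separate SLP.
        a = numpyro.sample("a", dist.Bernoulli(0.5), infer={"branching": True})
        b = numpyro.sample("b", dist.Bernoulli(0.5), infer={"branching": True})
        loc = 1.0 * a + 2.0 * b
        numpyro.sample("obs", dist.Normal(loc, 1.0), obs=0.2)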
diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py
new file mode 100644
index 000000000..7f7c68109
--- /dev/null
+++ b/numpyro/contrib/stochastic_support/dcc.py
@@ -0,0 +1,180 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+from collections import OrderedDict, namedtuple
+
+import jax
+from jax import random
+import jax.numpy as jnp
+
+import numpyro.distributions as dist
+from numpyro.handlers import condition, seed, trace
+from numpyro.infer import MCMC, NUTS
+from numpyro.infer.autoguide import AutoNormal
+from numpyro.infer.util import init_to_value, log_density
+
+DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])
+
+
+class DCC:
+    """
+    Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
+    stochastic support from [1].
+
+    .. note:: This implementation assumes that all stochastic branching is done based
+        on the outcomes of discrete sampling sites that are annotated with
+        ``infer={"branching": True}``. For example,
+
+        .. code-block:: python
+
+            def model():
+                model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
+                if model1 == 0:
+                    mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
+                else:
+                    mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
+                numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
+
+    **References:**
+
+    1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic
+       Programs with Stochastic Support*,
+       Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth
+
+    :param model: Python callable containing Pyro :mod:`~numpyro.primitives`.
+    :param dict mcmc_kwargs: Dictionary of arguments passed to :class:`~numpyro.infer.MCMC`.
+    :param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for
+        local inference. Defaults to :class:`~numpyro.infer.NUTS`.
+    :param int num_slp_samples: Number of samples to draw from the prior to discover the
+        straight-line programs (SLPs).
+    :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference
+        on more than ``max_slps`` SLPs.
+    :param float proposal_scale: Scale parameter for the proposal distribution used when
+        estimating the normalization constant of an SLP.
+    """
+
+    def __init__(
+        self,
+        model,
+        mcmc_kwargs,
+        kernel_cls=NUTS,
+        num_slp_samples=1000,
+        max_slps=124,
+        proposal_scale=1.0,
+    ):
+        self.model = model
+        self.kernel_cls = kernel_cls
+        self.mcmc_kwargs = mcmc_kwargs
+
+        self.num_slp_samples = num_slp_samples
+        self.max_slps = max_slps
+        self.proposal_scale = proposal_scale
+
+    def _find_slps(self, rng_key, *args, **kwargs):
+        """
+        Discover the straight-line programs (SLPs) in the model by sampling from
+        the prior. This implementation assumes that all branching is done via
+        discrete sampling sites that are annotated with ``infer={"branching": True}``.
+        """
+        branching_traces = {}
+        for _ in range(self.num_slp_samples):
+            rng_key, subkey = random.split(rng_key)
+            tr = trace(seed(self.model, subkey)).get_trace(*args, **kwargs)
+            btr = self._get_branching_trace(tr)
+            btr_str = ",".join(str(x) for x in btr.values())
+            if btr_str not in branching_traces:
+                branching_traces[btr_str] = btr
+                if len(branching_traces) >= self.max_slps:
+                    break
+
+        return branching_traces
+
+    def _get_branching_trace(self, tr):
+        """
+        Extract the sites from the trace that are annotated with ``infer={"branching": True}``.
+        """
+        branching_trace = OrderedDict()
+        for site in tr.values():
+            if site["type"] == "sample" and site["infer"].get("branching", False):
+                if (
+                    not isinstance(site["fn"], dist.Distribution)
+                    or not site["fn"].support.is_discrete
+                ):
+                    raise RuntimeError(
+                        "Branching is only supported for discrete sampling sites."
+                    )
+                # It is essential that we convert the value to a Python int. If it
+                # remains a JAX array, then during JIT compilation it is treated as
+                # an abstract value, which means branching on it will raise an error.
+                # Reference: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit
+                branching_trace[site["name"]] = int(site["value"])
+        return branching_trace
+    def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs):
+        """
+        Run MCMC on the model conditioned on the given branching trace.
+        """
+        slp_model = condition(self.model, data=branching_trace)
+        kernel = self.kernel_cls(slp_model)
+        mcmc = MCMC(kernel, **self.mcmc_kwargs)
+        mcmc.run(rng_key, *args, **kwargs)
+
+        return mcmc.get_samples()
+
+    def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs):
+        """
+        Weight each SLP proportionally to its estimated normalization constant.
+        The normalization constants are estimated using importance sampling with
+        the proposal centered on the MCMC samples. This is a special case of the
+        layered adaptive importance sampling algorithm from [1].
+
+        **References:**
+
+        1. *Layered adaptive importance sampling*,
+           Luca Martino, Victor Elvira, David Luengo, and Jukka Corander.
+        """
+
+        def log_weight(rng_key, i, slp_model, slp_samples):
+            trace = {k: v[i] for k, v in slp_samples.items()}
+            guide = AutoNormal(
+                slp_model,
+                init_loc_fn=init_to_value(values=trace),
+                init_scale=self.proposal_scale,
+            )
+            rng_key, subkey = random.split(rng_key)
+            guide_trace = seed(guide, subkey)(*args, **kwargs)
+            guide_log_density, _ = log_density(guide, args, kwargs, guide_trace)
+            model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace)
+            return model_log_density - guide_log_density
+
+        log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None))
+
+        log_Zs = {}
+        for bt, slp_samples in samples.items():
+            num_samples = slp_samples[next(iter(slp_samples))].shape[0]
+            slp_model = condition(self.model, data=branching_traces[bt])
+            lws = log_weights(rng_key, jnp.arange(num_samples), slp_model, slp_samples)
+            log_Zs[bt] = jax.scipy.special.logsumexp(lws) - jnp.log(num_samples)
+
+        normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values())))
+        slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()}
+        return DCCResult(samples, slp_weights)
+
+    def run(self, rng_key, *args, **kwargs):
+        """
+        Run DCC and collect samples for all SLPs.
+
+        :param jax.random.PRNGKey rng_key: Random number generator key.
+        :param args: Arguments to the model.
+        :param kwargs: Keyword arguments to the model.
+        """
+        rng_key, subkey = random.split(rng_key)
+        branching_traces = self._find_slps(subkey, *args, **kwargs)
+
+        samples = dict()
+        for key, bt in branching_traces.items():
+            rng_key, subkey = random.split(rng_key)
+            samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs)
+
+        rng_key, subkey = random.split(rng_key)
+        return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs)
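The _combine_samples method that closes the file above is plain self-normalized importance sampling: for each SLP it draws x_i from an AutoNormal proposal q, estimates log Z ≈ logsumexp_i [log p(x_i) − log q(x_i)] − log N, and then normalizes the per-SLP estimates against each other. A self-contained sketch of the same estimator on a toy one-dimensional target with a known normalization constant (illustrative only, not part of this diff):

    import jax.numpy as jnp
    from jax import random
    from jax.scipy.special import logsumexp
    from jax.scipy.stats import norm

    def log_Z_estimate(rng_key, log_p, q_loc, q_scale, num_samples=1000):
        # Draw from a Normal proposal q and average importance weights in log space.
        xs = q_loc + q_scale * random.normal(rng_key, (num_samples,))
        log_w = log_p(xs) - norm.logpdf(xs, q_loc, q_scale)
        return logsumexp(log_w) - jnp.log(num_samples)

    key0, key1 = random.split(random.PRNGKey(0))
    # Two toy "SLPs": unnormalized Gaussian densities with true Z = 2.0 and Z = 1.0.
    log_Z0 = log_Z_estimate(key0, lambda x: norm.logpdf(x, 0.0, 1.0) + jnp.log(2.0), 0.0, 2.0)
    log_Z1 = log_Z_estimate(key1, lambda x: norm.logpdf(x, 1.0, 1.0), 1.0, 2.0)
    log_Zs = jnp.array([log_Z0, log_Z1])
    weights = jnp.exp(log_Zs - logsumexp(log_Zs))  # approximately [2/3, 1/3]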
diff --git a/test/contrib/stochastic_support/test_dcc.py b/test/contrib/stochastic_support/test_dcc.py
new file mode 100644
index 000000000..d56e39b98
--- /dev/null
+++ b/test/contrib/stochastic_support/test_dcc.py
@@ -0,0 +1,201 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import math
+
+from numpy.testing import assert_allclose
+import pytest
+
+import jax
+from jax import random
+import jax.numpy as jnp
+
+import numpyro
+from numpyro.contrib.stochastic_support.dcc import DCC
+import numpyro.distributions as dist
+from numpyro.infer import HMC, NUTS, SA, BarkerMH
+
+
+@pytest.mark.parametrize(
+    "branch_dist",
+    [dist.Normal(0, 1), dist.Gamma(1, 1)],
+)
+@pytest.mark.xfail(raises=RuntimeError)
+def test_continuous_branching(branch_dist):
+    rng_key = random.PRNGKey(0)
+
+    def model():
+        model1 = numpyro.sample("model1", branch_dist, infer={"branching": True})
+        mean = 1.0 if model1 == 0 else 2.0
+        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
+
+    mcmc_kwargs = dict(
+        num_warmup=500,
+        num_samples=1000,
+        num_chains=1,
+    )
+
+    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
+    rng_key, subkey = random.split(rng_key)
+    dcc.run(subkey)
+
+
+def test_different_address_path():
+    rng_key = random.PRNGKey(0)
+
+    def model():
+        model1 = numpyro.sample(
+            "model1", dist.Bernoulli(0.5), infer={"branching": True}
+        )
+        if model1 == 0:
+            numpyro.sample("a1", dist.Normal(9.0, 1.0))
+        else:
+            numpyro.sample("a2", dist.Normal(9.0, 1.0))
+            numpyro.sample("a3", dist.Normal(9.0, 1.0))
+        mean = 1.0 if model1 == 0 else 2.0
+        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)
+
+    mcmc_kwargs = dict(
+        num_warmup=50,
+        num_samples=50,
+        num_chains=1,
+        progress_bar=False,
+    )
+
+    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
+    rng_key, subkey = random.split(rng_key)
+    dcc.run(subkey)
+
+
+@pytest.mark.parametrize("proposal_scale", [0.1, 1.0, 10.0])
+def test_proposal_scale(proposal_scale):
+    def model(y):
+        z = numpyro.sample("z", dist.Normal(0.0, 1.0))
+        model1 = numpyro.sample(
+            "model1", dist.Bernoulli(0.5), infer={"branching": True}
+        )
+        sigma = 1.0 if model1 == 0 else 2.0
+        with numpyro.plate("data", y.shape[0]):
+            numpyro.sample("obs", dist.Normal(z, sigma), obs=y)
+
+    rng_key = random.PRNGKey(0)
+
+    rng_key, subkey = random.split(rng_key)
+    y_train = dist.Normal(0, 1).sample(subkey, (200,))
+
+    mcmc_kwargs = dict(
+        num_warmup=50,
+        num_samples=50,
+        num_chains=2,
+        progress_bar=False,
+    )
+
+    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, proposal_scale=proposal_scale)
+    rng_key, subkey = random.split(rng_key)
+    dcc.run(subkey, y_train)
+
+
+@pytest.mark.parametrize(
+    "chain_method",
+    ["sequential", "parallel", "vectorized"],
+)
+@pytest.mark.parametrize("kernel_cls", [NUTS, HMC, SA, BarkerMH])
+def test_kernels(chain_method, kernel_cls):
+    if chain_method == "vectorized" and kernel_cls in [SA, BarkerMH]:
+        # These methods do not support vectorized execution.
+ return + + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0)) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y_train = dist.Normal(0, 1).sample(subkey, (200,)) + + mcmc_kwargs = dict( + num_warmup=50, + num_samples=50, + num_chains=2, + chain_method=chain_method, + progress_bar=False, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, kernel_cls=kernel_cls) + rng_key, subkey = random.split(rng_key) + dcc.run(subkey, y_train) + + +def test_weight_convergence(): + PRIOR_MEAN, PRIOR_STD = 0.0, 1.0 + LIKELIHOOD1_STD = 2.0 + LIKELIHOOD2_STD = 0.62177 + + def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std): + """ + Calculate the marginal likelihood of a model with Normal likelihood, unknown mean, + and Normal prior. + + Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf. + """ + num_data = data.shape[0] + likelihood_var = jnp.square(likelihood_std) + prior_var = jnp.square(prior_std) + + first_term = ( + jnp.log(likelihood_std) + - num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std) + + 0.5 * jnp.log(num_data * prior_var + likelihood_var) + ) + second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - ( + jnp.square(prior_mean) / (2 * prior_var) + ) + third_term = ( + ( + prior_var + * jnp.square(num_data) + * jnp.square(jnp.mean(data)) + / likelihood_var + ) + + (likelihood_var * jnp.square(prior_mean) / prior_var) + + 2 * num_data * jnp.mean(data) * prior_mean + ) / (2 * (num_data * prior_var + likelihood_var)) + return first_term + second_term + third_term + + def model(y): + z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD)) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y_train = dist.Normal(0, 1).sample(subkey, (200,)) + + mcmc_kwargs = dict( + num_warmup=500, + num_samples=1000, + num_chains=1, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) + rng_key, subkey = random.split(rng_key) + dcc_result = dcc.run(subkey, y_train) + slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]]) + assert_allclose(1.0, jnp.sum(slp_weights)) + + slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD) + slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD) + lmls = jnp.array([slp1_lml, slp2_lml]) + analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls)) + assert_allclose(analytic_weights, slp_weights, rtol=1e-5, atol=1e-8)
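For reference, an end-to-end usage sketch of the new API, mirroring the docstring example and the tests above (the model is the two-branch toy from the DCC docstring):

    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro.contrib.stochastic_support import DCC

    def model():
        model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
        if model1 == 0:
            mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
        else:
            mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
        numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

    mcmc_kwargs = dict(num_warmup=500, num_samples=1000, num_chains=1)
    dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
    result = dcc.run(random.PRNGKey(0))

    # result is a DCCResult namedtuple: result.samples maps each SLP key
    # ("0" or "1" here) to its MCMC samples, and result.slp_weights maps
    # the same keys to estimated posterior SLP weights that sum to one.
    for key, weight in result.slp_weights.items():
        print(key, float(weight))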