From 2ceb88dfb2c82002d300c6bd9b1143a644fbab1b Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Sun, 7 Jan 2024 09:23:04 +0000 Subject: [PATCH 1/6] Initial bare bones implementation of DCC --- examples/dcc.py | 80 +++++++++++++++++++++++ numpyro/contrib/dcc/dcc.py | 127 +++++++++++++++++++++++++++++++++++++ 2 files changed, 207 insertions(+) create mode 100644 examples/dcc.py create mode 100644 numpyro/contrib/dcc/dcc.py diff --git a/examples/dcc.py b/examples/dcc.py new file mode 100644 index 000000000..cdc072a5c --- /dev/null +++ b/examples/dcc.py @@ -0,0 +1,80 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import math + +import jax +import jax.numpy as jnp +from jax import random + +import numpyro +import numpyro.distributions as dist +from numpyro.contrib.dcc.dcc import DCC + +PRIOR_MEAN, PRIOR_STD = 0.0, 1.0 +LIKELIHOOD1_STD = 2.0 +LIKELIHOOD2_STD = 0.62177 + + +def model(y): + z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD)) + model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True}) + sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma), obs=y) + + +def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std): + """ + Calculate the marginal likelihood of a model with Normal likelihood, unknown mean, + and Normal prior. + + Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf. + """ + num_data = data.shape[0] + likelihood_var = jnp.square(likelihood_std) + prior_var = jnp.square(prior_std) + + first_term = ( + jnp.log(likelihood_std) + - num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std) + + 0.5 * jnp.log(num_data * prior_var + likelihood_var) + ) + second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - ( + jnp.square(prior_mean) / (2 * prior_var) + ) + third_term = ( + (prior_var * jnp.square(num_data) * jnp.square(jnp.mean(data)) / likelihood_var) + + (likelihood_var * jnp.square(prior_mean) / prior_var) + + 2 * num_data * jnp.mean(data) * prior_mean + ) / (2 * (num_data * prior_var + likelihood_var)) + return first_term + second_term + third_term + + +def main(): + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y_train = dist.Normal(0, 1).sample(subkey, (200,)) + + mcmc_kwargs = dict( + num_warmup=500, + num_samples=1000, + num_chains=2, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) + rng_key, subkey = random.split(rng_key) + dcc_result = dcc.run(subkey, y_train) + slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]]) + assert jnp.allclose(1.0, jnp.sum(slp_weights)) + + slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD) + slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD) + lmls = jnp.array([slp1_lml, slp2_lml]) + analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls)) + assert jnp.allclose(analytic_weights, slp_weights) + + +if __name__ == "__main__": + main() diff --git a/numpyro/contrib/dcc/dcc.py b/numpyro/contrib/dcc/dcc.py new file mode 100644 index 000000000..0a90fcb37 --- /dev/null +++ b/numpyro/contrib/dcc/dcc.py @@ -0,0 +1,127 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import OrderedDict, namedtuple + +import jax +import jax.numpy as jnp +from jax import random + +from numpyro.handlers import condition, seed, trace +from numpyro.infer import MCMC, NUTS +from numpyro.infer.autoguide import AutoNormal +from numpyro.infer.util import init_to_value, log_density + +DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"]) + + +class DCC: + """ + Implements the Divide, Conquer, and Combine (DCC) algorithm from [1]. + + **References:** + 1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*, + Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth + """ + + def __init__( + self, + model, + mcmc_kwargs, + num_slp_samples=1000, + max_slps=124, + ): + self.model = model + self.mcmc_kwargs = mcmc_kwargs + + self.num_slp_samples = num_slp_samples + self.max_slps = max_slps + + def _find_slps(self, rng_key, *args, **kwargs): + """ + Discover the straight-line programs (SLPs) in the model by sampling from the prior. + This implementation assumes that all branching is done via discrete sampling sites + that are annotated with `infer={"branching": True}`. + """ + branching_traces = {} + for _ in range(self.num_slp_samples): + rng_key, subkey = random.split(rng_key) + tr = trace(seed(self.model, subkey)).get_trace(*args, **kwargs) + btr = self._get_branching_trace(tr) + btr_str = ",".join(str(x) for x in btr.values()) + if btr_str not in branching_traces: + branching_traces[btr_str] = btr + if len(branching_traces) >= self.max_slps: + break + + return branching_traces + + def _get_branching_trace(self, tr): + """ + Extract the sites from the trace that are annotated with `infer={"branching": True}`. + """ + branching_trace = OrderedDict() + for site in tr.values(): + if site["type"] == "sample" and site["infer"].get("branching", False): + # TODO: Assert that this is a discrete sampling site and univariate distribution. + # It is essential that we convert the value to a Python int. If it remains + # a JAX Array, then during JIT compilation it will be treated as an AbstractArray + # and we are not able to branch based on this value. + branching_trace[site["name"]] = int(site["value"]) + return branching_trace + + def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs): + """ + Run MCMC on the model conditioned on the given branching trace. + """ + slp_model = condition(self.model, data=branching_trace) + kernel = NUTS(slp_model) + mcmc = MCMC(kernel, **self.mcmc_kwargs) + mcmc.run(rng_key, *args, **kwargs) + + return mcmc.get_samples() + + def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs): + """ + Weight each SLP proportional to its estimated normalization constant. + The normalization constants are estimated using importance sampling with + the proposal centered on the MCMC samples. + """ + + def log_weight(rng_key, i, slp_model, slp_samples): + trace = {k: v[i] for k, v in slp_samples.items()} + guide = AutoNormal( + slp_model, + init_loc_fn=init_to_value(values=trace), + init_scale=1.0, + ) + rng_key, subkey = random.split(rng_key) + guide_trace = seed(guide, subkey)(*args, **kwargs) + guide_log_density, _ = log_density(guide, args, kwargs, guide_trace) + model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace) + return model_log_density - guide_log_density + + log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None)) + + log_Zs = {} + for bt, slp_samples in samples.items(): + num_samples = slp_samples[next(iter(slp_samples))].shape[0] + slp_model = condition(self.model, data=branching_traces[bt]) + lws = log_weights(rng_key, jnp.arange(num_samples), slp_model, slp_samples) + log_Zs[bt] = jax.scipy.special.logsumexp(lws) - jnp.log(num_samples) + + normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values()))) + slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()} + return DCCResult(samples, slp_weights) + + def run(self, rng_key, *args, **kwargs): + rng_key, subkey = random.split(rng_key) + branching_traces = self._find_slps(subkey, *args, **kwargs) + + samples = dict() + for key, bt in branching_traces.items(): + rng_key, subkey = random.split(rng_key) + samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs) + + rng_key, subkey = random.split(rng_key) + return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs) From 045034088510e92443800cd4a376411b23004b2a Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 17 Jan 2024 09:58:15 +0000 Subject: [PATCH 2/6] Add tests and documentation --- docs/source/contrib.rst | 8 + examples/dcc.py | 80 --------- .../{dcc => stochastic_support}/dcc.py | 52 +++++- test/contrib/stochastic_support/test_dcc.py | 169 ++++++++++++++++++ 4 files changed, 225 insertions(+), 84 deletions(-) delete mode 100644 examples/dcc.py rename numpyro/contrib/{dcc => stochastic_support}/dcc.py (69%) create mode 100644 test/contrib/stochastic_support/test_dcc.py diff --git a/docs/source/contrib.rst b/docs/source/contrib.rst index a7b69db6a..7d225c850 100644 --- a/docs/source/contrib.rst +++ b/docs/source/contrib.rst @@ -74,3 +74,11 @@ SteinVI Kernels .. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel +Stochastic Support +~~~~~~~~~~~~~~~~~~ + +.. autoclass:: numpyro.contrib.stochastic_support.dcc.DCC + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource \ No newline at end of file diff --git a/examples/dcc.py b/examples/dcc.py deleted file mode 100644 index cdc072a5c..000000000 --- a/examples/dcc.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright Contributors to the Pyro project. -# SPDX-License-Identifier: Apache-2.0 - -import math - -import jax -import jax.numpy as jnp -from jax import random - -import numpyro -import numpyro.distributions as dist -from numpyro.contrib.dcc.dcc import DCC - -PRIOR_MEAN, PRIOR_STD = 0.0, 1.0 -LIKELIHOOD1_STD = 2.0 -LIKELIHOOD2_STD = 0.62177 - - -def model(y): - z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD)) - model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True}) - sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD - with numpyro.plate("data", y.shape[0]): - numpyro.sample("obs", dist.Normal(z, sigma), obs=y) - - -def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std): - """ - Calculate the marginal likelihood of a model with Normal likelihood, unknown mean, - and Normal prior. - - Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf. - """ - num_data = data.shape[0] - likelihood_var = jnp.square(likelihood_std) - prior_var = jnp.square(prior_std) - - first_term = ( - jnp.log(likelihood_std) - - num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std) - + 0.5 * jnp.log(num_data * prior_var + likelihood_var) - ) - second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - ( - jnp.square(prior_mean) / (2 * prior_var) - ) - third_term = ( - (prior_var * jnp.square(num_data) * jnp.square(jnp.mean(data)) / likelihood_var) - + (likelihood_var * jnp.square(prior_mean) / prior_var) - + 2 * num_data * jnp.mean(data) * prior_mean - ) / (2 * (num_data * prior_var + likelihood_var)) - return first_term + second_term + third_term - - -def main(): - rng_key = random.PRNGKey(0) - - rng_key, subkey = random.split(rng_key) - y_train = dist.Normal(0, 1).sample(subkey, (200,)) - - mcmc_kwargs = dict( - num_warmup=500, - num_samples=1000, - num_chains=2, - ) - - dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) - rng_key, subkey = random.split(rng_key) - dcc_result = dcc.run(subkey, y_train) - slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]]) - assert jnp.allclose(1.0, jnp.sum(slp_weights)) - - slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD) - slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD) - lmls = jnp.array([slp1_lml, slp2_lml]) - analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls)) - assert jnp.allclose(analytic_weights, slp_weights) - - -if __name__ == "__main__": - main() diff --git a/numpyro/contrib/dcc/dcc.py b/numpyro/contrib/stochastic_support/dcc.py similarity index 69% rename from numpyro/contrib/dcc/dcc.py rename to numpyro/contrib/stochastic_support/dcc.py index 0a90fcb37..e53b728df 100644 --- a/numpyro/contrib/dcc/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -7,6 +7,7 @@ import jax.numpy as jnp from jax import random +import numpyro.distributions as dist from numpyro.handlers import condition, seed, trace from numpyro.infer import MCMC, NUTS from numpyro.infer.autoguide import AutoNormal @@ -17,21 +18,50 @@ class DCC: """ - Implements the Divide, Conquer, and Combine (DCC) algorithm from [1]. + Implements the Divide, Conquer, and Combine (DCC) algorithm for models with + stochastic support from [1]. + + .. note:: This implementation assumes that all stochastic branching is done based on the + outcomes of discrete sampling sites that are annotated with `infer={"branching": True}`. + For example, + + .. code-block:: python + + def model(): + model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True}) + if model1 == 0: + mean = numpyro.sample("a1", dist.Normal(0.0, 1.0)) + else: + mean = numpyro.sample("a2", dist.Normal(1.0, 1.0)) + numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2) + + **References:** + 1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*, Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth + + :param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`. + :param dict mcmc_kwargs: Dictionary of arguments passed to :data:`~numpyro.infer.MCMC`. + :param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for + local inference. Defaults to :class:`~numpyro.infer.NUTS`. + :param int num_slp_samples: Number of samples to draw from the prior to discover the + straight-line programs (SLPs). + :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference + on more than `max_slps`. """ def __init__( self, model, mcmc_kwargs, + kernel_cls=NUTS, num_slp_samples=1000, max_slps=124, ): self.model = model + self.kernel_cls = kernel_cls self.mcmc_kwargs = mcmc_kwargs self.num_slp_samples = num_slp_samples @@ -63,10 +93,17 @@ def _get_branching_trace(self, tr): branching_trace = OrderedDict() for site in tr.values(): if site["type"] == "sample" and site["infer"].get("branching", False): - # TODO: Assert that this is a discrete sampling site and univariate distribution. + if ( + not isinstance(site["fn"], dist.Distribution) + or not site["fn"].support.is_discrete + ): + raise RuntimeError( + "Branching is only supported for discrete sampling sites." + ) # It is essential that we convert the value to a Python int. If it remains # a JAX Array, then during JIT compilation it will be treated as an AbstractArray - # and we are not able to branch based on this value. + # which means branching will raise in an error. + # Reference: (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit) branching_trace[site["name"]] = int(site["value"]) return branching_trace @@ -75,7 +112,7 @@ def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs): Run MCMC on the model conditioned on the given branching trace. """ slp_model = condition(self.model, data=branching_trace) - kernel = NUTS(slp_model) + kernel = self.kernel_cls(slp_model) mcmc = MCMC(kernel, **self.mcmc_kwargs) mcmc.run(rng_key, *args, **kwargs) @@ -115,6 +152,13 @@ def log_weight(rng_key, i, slp_model, slp_samples): return DCCResult(samples, slp_weights) def run(self, rng_key, *args, **kwargs): + """ + Run DCC and collect samples for all SLPs. + + :param jax.random.PRNGKey rng_key: Random number generator key. + :param args: Arguments to the model. + :param kwargs: Keyword arguments to the model. + """ rng_key, subkey = random.split(rng_key) branching_traces = self._find_slps(subkey, *args, **kwargs) diff --git a/test/contrib/stochastic_support/test_dcc.py b/test/contrib/stochastic_support/test_dcc.py new file mode 100644 index 000000000..2fc24da42 --- /dev/null +++ b/test/contrib/stochastic_support/test_dcc.py @@ -0,0 +1,169 @@ +import math + +import jax +import jax.numpy as jnp +import pytest +from jax import random +from numpy.testing import assert_allclose + +import numpyro +import numpyro.distributions as dist +from numpyro.contrib.stochastic_support.dcc import DCC +from numpyro.infer import HMC, NUTS, SA, BarkerMH + + +@pytest.mark.parametrize( + "branch_dist", + [dist.Normal(0, 1), dist.Gamma(1, 1)], +) +@pytest.mark.xfail(raises=RuntimeError) +def test_continuous_branching(branch_dist): + rng_key = random.PRNGKey(0) + + def model(): + model1 = numpyro.sample("model1", branch_dist, infer={"branching": True}) + mean = 1.0 if model1 == 0 else 2.0 + numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2) + + mcmc_kwargs = dict( + num_warmup=500, + num_samples=1000, + num_chains=1, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) + rng_key, subkey = random.split(rng_key) + dcc.run(subkey) + + +def test_different_address_path(): + rng_key = random.PRNGKey(0) + + def model(): + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + if model1 == 0: + numpyro.sample("a1", dist.Normal(9.0, 1.0)) + else: + numpyro.sample("a2", dist.Normal(9.0, 1.0)) + numpyro.sample("a3", dist.Normal(9.0, 1.0)) + mean = 1.0 if model1 == 0 else 2.0 + numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2) + + mcmc_kwargs = dict( + num_warmup=50, + num_samples=50, + num_chains=1, + progress_bar=False, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) + rng_key, subkey = random.split(rng_key) + dcc.run(subkey) + + +@pytest.mark.parametrize( + "chain_method", + ["sequential", "parallel", "vectorized"], +) +@pytest.mark.parametrize("kernel_cls", [NUTS, HMC, SA, BarkerMH]) +def test_kernels(chain_method, kernel_cls): + if chain_method == "vectorized" and kernel_cls in [SA, BarkerMH]: + # These methods do not support vectorized execution. + return + + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0)) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y_train = dist.Normal(0, 1).sample(subkey, (200,)) + + mcmc_kwargs = dict( + num_warmup=50, + num_samples=50, + num_chains=2, + chain_method=chain_method, + progress_bar=False, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, kernel_cls=kernel_cls) + rng_key, subkey = random.split(rng_key) + dcc.run(subkey, y_train) + + +def test_weight_convergence(): + PRIOR_MEAN, PRIOR_STD = 0.0, 1.0 + LIKELIHOOD1_STD = 2.0 + LIKELIHOOD2_STD = 0.62177 + + def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std): + """ + Calculate the marginal likelihood of a model with Normal likelihood, unknown mean, + and Normal prior. + + Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf. + """ + num_data = data.shape[0] + likelihood_var = jnp.square(likelihood_std) + prior_var = jnp.square(prior_std) + + first_term = ( + jnp.log(likelihood_std) + - num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std) + + 0.5 * jnp.log(num_data * prior_var + likelihood_var) + ) + second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - ( + jnp.square(prior_mean) / (2 * prior_var) + ) + third_term = ( + ( + prior_var + * jnp.square(num_data) + * jnp.square(jnp.mean(data)) + / likelihood_var + ) + + (likelihood_var * jnp.square(prior_mean) / prior_var) + + 2 * num_data * jnp.mean(data) * prior_mean + ) / (2 * (num_data * prior_var + likelihood_var)) + return first_term + second_term + third_term + + def model(y): + z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD)) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y_train = dist.Normal(0, 1).sample(subkey, (200,)) + + mcmc_kwargs = dict( + num_warmup=500, + num_samples=1000, + num_chains=1, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs) + rng_key, subkey = random.split(rng_key) + dcc_result = dcc.run(subkey, y_train) + slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]]) + assert_allclose(1.0, jnp.sum(slp_weights)) + + slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD) + slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD) + lmls = jnp.array([slp1_lml, slp2_lml]) + analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls)) + assert_allclose(analytic_weights, slp_weights, rtol=1e-5, atol=1e-8) From d1a435bc06151d4dcd6146fdcc38f2c8fec66843 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Thu, 1 Feb 2024 19:01:06 +0000 Subject: [PATCH 3/6] Make scale in Normal proposal configurable --- numpyro/contrib/stochastic_support/dcc.py | 13 ++++++++-- test/contrib/stochastic_support/test_dcc.py | 28 +++++++++++++++++++++ 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py index e53b728df..08e1382fc 100644 --- a/numpyro/contrib/stochastic_support/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -50,6 +50,8 @@ def model(): straight-line programs (SLPs). :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference on more than `max_slps`. + :param float proposal_scale: Scale parameter for the proposal distribution for + estimating the normalization constant of an SLP. """ def __init__( @@ -59,6 +61,7 @@ def __init__( kernel_cls=NUTS, num_slp_samples=1000, max_slps=124, + proposal_scale=1.0, ): self.model = model self.kernel_cls = kernel_cls @@ -66,6 +69,7 @@ def __init__( self.num_slp_samples = num_slp_samples self.max_slps = max_slps + self.proposal_scale = proposal_scale def _find_slps(self, rng_key, *args, **kwargs): """ @@ -122,7 +126,12 @@ def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs): """ Weight each SLP proportional to its estimated normalization constant. The normalization constants are estimated using importance sampling with - the proposal centered on the MCMC samples. + the proposal centred on the MCMC samples. This is a special case of the + layered adaptive importance sampling algorithm from [1]. + + **References:** + 1. *Layered adaptive importance sampling*, + Luca Martino, Victor Elvira, David Luengo, and Jukka Corander. """ def log_weight(rng_key, i, slp_model, slp_samples): @@ -130,7 +139,7 @@ def log_weight(rng_key, i, slp_model, slp_samples): guide = AutoNormal( slp_model, init_loc_fn=init_to_value(values=trace), - init_scale=1.0, + init_scale=self.proposal_scale, ) rng_key, subkey = random.split(rng_key) guide_trace = seed(guide, subkey)(*args, **kwargs) diff --git a/test/contrib/stochastic_support/test_dcc.py b/test/contrib/stochastic_support/test_dcc.py index 2fc24da42..da64640b1 100644 --- a/test/contrib/stochastic_support/test_dcc.py +++ b/test/contrib/stochastic_support/test_dcc.py @@ -63,6 +63,34 @@ def model(): dcc.run(subkey) +@pytest.mark.parametrize("proposal_scale", [0.1, 1.0, 10.0]) +def test_proposal_scale(proposal_scale): + def model(y): + z = numpyro.sample("z", dist.Normal(0.0, 1.0)) + model1 = numpyro.sample( + "model1", dist.Bernoulli(0.5), infer={"branching": True} + ) + sigma = 1.0 if model1 == 0 else 2.0 + with numpyro.plate("data", y.shape[0]): + numpyro.sample("obs", dist.Normal(z, sigma), obs=y) + + rng_key = random.PRNGKey(0) + + rng_key, subkey = random.split(rng_key) + y_train = dist.Normal(0, 1).sample(subkey, (200,)) + + mcmc_kwargs = dict( + num_warmup=50, + num_samples=50, + num_chains=2, + progress_bar=False, + ) + + dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, proposal_scale=proposal_scale) + rng_key, subkey = random.split(rng_key) + dcc.run(subkey, y_train) + + @pytest.mark.parametrize( "chain_method", ["sequential", "parallel", "vectorized"], From e60215f4840e2791b941169acff2935e9576d25e Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 16 Feb 2024 19:48:45 +0000 Subject: [PATCH 4/6] Run linter --- numpyro/contrib/stochastic_support/dcc.py | 2 +- test/contrib/stochastic_support/test_dcc.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/numpyro/contrib/stochastic_support/dcc.py b/numpyro/contrib/stochastic_support/dcc.py index 08e1382fc..7f7c68109 100644 --- a/numpyro/contrib/stochastic_support/dcc.py +++ b/numpyro/contrib/stochastic_support/dcc.py @@ -4,8 +4,8 @@ from collections import OrderedDict, namedtuple import jax -import jax.numpy as jnp from jax import random +import jax.numpy as jnp import numpyro.distributions as dist from numpyro.handlers import condition, seed, trace diff --git a/test/contrib/stochastic_support/test_dcc.py b/test/contrib/stochastic_support/test_dcc.py index da64640b1..d56e39b98 100644 --- a/test/contrib/stochastic_support/test_dcc.py +++ b/test/contrib/stochastic_support/test_dcc.py @@ -1,14 +1,18 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + import math -import jax -import jax.numpy as jnp +from numpy.testing import assert_allclose import pytest + +import jax from jax import random -from numpy.testing import assert_allclose +import jax.numpy as jnp import numpyro -import numpyro.distributions as dist from numpyro.contrib.stochastic_support.dcc import DCC +import numpyro.distributions as dist from numpyro.infer import HMC, NUTS, SA, BarkerMH From b09a7f7f9b4f881834f27388828e7ff2bbf18ee8 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Mon, 19 Feb 2024 20:18:05 +0000 Subject: [PATCH 5/6] Add __init__.py file and allow parallel inference in tests --- .github/workflows/ci.yml | 2 +- numpyro/contrib/stochastic_support/__init__.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) create mode 100644 numpyro/contrib/stochastic_support/__init__.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c4f99fa16..522c49da6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -102,7 +102,7 @@ jobs: run: | pytest -vs --durations=20 test/infer/test_mcmc.py pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py - pytest -vs --durations=20 test/contrib + XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs --durations=20 test/contrib - name: Test x64 run: | JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 diff --git a/numpyro/contrib/stochastic_support/__init__.py b/numpyro/contrib/stochastic_support/__init__.py new file mode 100644 index 000000000..6c1dc37f9 --- /dev/null +++ b/numpyro/contrib/stochastic_support/__init__.py @@ -0,0 +1,8 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +from numpyro.contrib.stochastic_support.dcc import DCC + +__all__ = [ + "DCC", +] From 00068d082778867768e9adffaeca2a7e7d808795 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Wed, 21 Feb 2024 17:18:56 +0000 Subject: [PATCH 6/6] Move DCC tests to 'test chains' group --- .github/workflows/ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 522c49da6..47fabe4e7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -102,7 +102,7 @@ jobs: run: | pytest -vs --durations=20 test/infer/test_mcmc.py pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs --durations=20 test/contrib + pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py - name: Test x64 run: | JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 @@ -110,6 +110,7 @@ jobs: run: | XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" - name: Test custom prng run: |