Implementation of DCC inference algorithm #1715

Merged · 7 commits · Feb 22, 2024
Changes from 3 commits
8 changes: 8 additions & 0 deletions docs/source/contrib.rst
@@ -74,3 +74,11 @@ SteinVI Kernels
.. autoclass:: numpyro.contrib.einstein.stein_kernels.ProbabilityProductKernel


Stochastic Support
~~~~~~~~~~~~~~~~~~

.. autoclass:: numpyro.contrib.stochastic_support.dcc.DCC
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
180 changes: 180 additions & 0 deletions numpyro/contrib/stochastic_support/dcc.py
@@ -0,0 +1,180 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import OrderedDict, namedtuple

import jax
import jax.numpy as jnp
from jax import random

import numpyro.distributions as dist
from numpyro.handlers import condition, seed, trace
from numpyro.infer import MCMC, NUTS
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.util import init_to_value, log_density

DCCResult = namedtuple("DCCResult", ["samples", "slp_weights"])


class DCC:
"""
Implements the Divide, Conquer, and Combine (DCC) algorithm for models with
stochastic support from [1].

.. note:: This implementation assumes that all stochastic branching is done based on the
outcomes of discrete sampling sites that are annotated with `infer={"branching": True}`.
For example,

.. code-block:: python

def model():
model1 = numpyro.sample("model1", dist.Bernoulli(0.5), infer={"branching": True})
if model1 == 0:
mean = numpyro.sample("a1", dist.Normal(0.0, 1.0))
else:
mean = numpyro.sample("a2", dist.Normal(1.0, 1.0))
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)



**References:**

1. *Divide, Conquer, and Combine: a New Inference Strategy for Probabilistic Programs with Stochastic Support*,
Yuan Zhou, Hongseok Yang, Yee Whye Teh, Tom Rainforth

:param model: Python callable containing Pyro primitives :mod:`~numpyro.primitives`.
    :param dict mcmc_kwargs: Dictionary of arguments passed to :class:`~numpyro.infer.MCMC`.
:param numpyro.infer.mcmc.MCMCKernel kernel_cls: MCMC kernel class that is used for
local inference. Defaults to :class:`~numpyro.infer.NUTS`.
:param int num_slp_samples: Number of samples to draw from the prior to discover the
straight-line programs (SLPs).
    :param int max_slps: Maximum number of SLPs to discover. DCC will not run inference
        on more than `max_slps` SLPs.
:param float proposal_scale: Scale parameter for the proposal distribution for
estimating the normalization constant of an SLP.
"""

def __init__(
self,
model,
mcmc_kwargs,
kernel_cls=NUTS,
num_slp_samples=1000,
max_slps=124,
proposal_scale=1.0,
):
self.model = model
self.kernel_cls = kernel_cls
self.mcmc_kwargs = mcmc_kwargs

self.num_slp_samples = num_slp_samples
self.max_slps = max_slps
self.proposal_scale = proposal_scale

def _find_slps(self, rng_key, *args, **kwargs):
"""
Discover the straight-line programs (SLPs) in the model by sampling from the prior.
This implementation assumes that all branching is done via discrete sampling sites
that are annotated with `infer={"branching": True}`.
"""
branching_traces = {}
for _ in range(self.num_slp_samples):
rng_key, subkey = random.split(rng_key)
tr = trace(seed(self.model, subkey)).get_trace(*args, **kwargs)
btr = self._get_branching_trace(tr)
btr_str = ",".join(str(x) for x in btr.values())
if btr_str not in branching_traces:
branching_traces[btr_str] = btr
if len(branching_traces) >= self.max_slps:
break

return branching_traces

def _get_branching_trace(self, tr):
"""
Extract the sites from the trace that are annotated with `infer={"branching": True}`.
"""
branching_trace = OrderedDict()
for site in tr.values():
if site["type"] == "sample" and site["infer"].get("branching", False):
if (
not isinstance(site["fn"], dist.Distribution)
or not site["fn"].support.is_discrete
):
raise RuntimeError(
"Branching is only supported for discrete sampling sites."
)
# It is essential that we convert the value to a Python int. If it remains
# a JAX Array, then during JIT compilation it will be treated as an AbstractArray
                # which means branching on it will raise an error.
# Reference: (https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-jit)
branching_trace[site["name"]] = int(site["value"])
return branching_trace

def _run_mcmc(self, rng_key, branching_trace, *args, **kwargs):
"""
Run MCMC on the model conditioned on the given branching trace.
"""
slp_model = condition(self.model, data=branching_trace)
kernel = self.kernel_cls(slp_model)
mcmc = MCMC(kernel, **self.mcmc_kwargs)
mcmc.run(rng_key, *args, **kwargs)

return mcmc.get_samples()

def _combine_samples(self, rng_key, samples, branching_traces, *args, **kwargs):
"""
        Weight each SLP in proportion to its estimated normalization constant.
The normalization constants are estimated using importance sampling with
the proposal centred on the MCMC samples. This is a special case of the
layered adaptive importance sampling algorithm from [1].

**References:**
1. *Layered adaptive importance sampling*,
Luca Martino, Victor Elvira, David Luengo, and Jukka Corander.
"""

def log_weight(rng_key, i, slp_model, slp_samples):
trace = {k: v[i] for k, v in slp_samples.items()}
guide = AutoNormal(
slp_model,
init_loc_fn=init_to_value(values=trace),
init_scale=self.proposal_scale,
)
rng_key, subkey = random.split(rng_key)
guide_trace = seed(guide, subkey)(*args, **kwargs)
guide_log_density, _ = log_density(guide, args, kwargs, guide_trace)
model_log_density, _ = log_density(slp_model, args, kwargs, guide_trace)
return model_log_density - guide_log_density

log_weights = jax.vmap(log_weight, in_axes=(None, 0, None, None))

log_Zs = {}
for bt, slp_samples in samples.items():
num_samples = slp_samples[next(iter(slp_samples))].shape[0]
slp_model = condition(self.model, data=branching_traces[bt])
lws = log_weights(rng_key, jnp.arange(num_samples), slp_model, slp_samples)
log_Zs[bt] = jax.scipy.special.logsumexp(lws) - jnp.log(num_samples)

normalizer = jax.scipy.special.logsumexp(jnp.array(list(log_Zs.values())))
slp_weights = {k: jnp.exp(v - normalizer) for k, v in log_Zs.items()}
return DCCResult(samples, slp_weights)

def run(self, rng_key, *args, **kwargs):
"""
Run DCC and collect samples for all SLPs.

:param jax.random.PRNGKey rng_key: Random number generator key.
:param args: Arguments to the model.
:param kwargs: Keyword arguments to the model.
"""
rng_key, subkey = random.split(rng_key)
branching_traces = self._find_slps(subkey, *args, **kwargs)

samples = dict()
for key, bt in branching_traces.items():
rng_key, subkey = random.split(rng_key)
samples[key] = self._run_mcmc(subkey, bt, *args, **kwargs)

rng_key, subkey = random.split(rng_key)
return self._combine_samples(subkey, samples, branching_traces, *args, **kwargs)
197 changes: 197 additions & 0 deletions test/contrib/stochastic_support/test_dcc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import math

import jax
import jax.numpy as jnp
import pytest
from jax import random
from numpy.testing import assert_allclose

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.stochastic_support.dcc import DCC
from numpyro.infer import HMC, NUTS, SA, BarkerMH


@pytest.mark.parametrize(
"branch_dist",
[dist.Normal(0, 1), dist.Gamma(1, 1)],
)
@pytest.mark.xfail(raises=RuntimeError)
def test_continuous_branching(branch_dist):
rng_key = random.PRNGKey(0)

def model():
model1 = numpyro.sample("model1", branch_dist, infer={"branching": True})
mean = 1.0 if model1 == 0 else 2.0
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

mcmc_kwargs = dict(
num_warmup=500,
num_samples=1000,
num_chains=1,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey)


def test_different_address_path():
rng_key = random.PRNGKey(0)

def model():
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
if model1 == 0:
numpyro.sample("a1", dist.Normal(9.0, 1.0))
else:
numpyro.sample("a2", dist.Normal(9.0, 1.0))
numpyro.sample("a3", dist.Normal(9.0, 1.0))
mean = 1.0 if model1 == 0 else 2.0
numpyro.sample("obs", dist.Normal(mean, 1.0), obs=0.2)

mcmc_kwargs = dict(
num_warmup=50,
num_samples=50,
num_chains=1,
progress_bar=False,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey)


@pytest.mark.parametrize("proposal_scale", [0.1, 1.0, 10.0])
def test_proposal_scale(proposal_scale):
def model(y):
z = numpyro.sample("z", dist.Normal(0.0, 1.0))
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
sigma = 1.0 if model1 == 0 else 2.0
with numpyro.plate("data", y.shape[0]):
numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

rng_key = random.PRNGKey(0)

rng_key, subkey = random.split(rng_key)
y_train = dist.Normal(0, 1).sample(subkey, (200,))

mcmc_kwargs = dict(
num_warmup=50,
num_samples=50,
num_chains=2,
progress_bar=False,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, proposal_scale=proposal_scale)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey, y_train)


@pytest.mark.parametrize(
"chain_method",
["sequential", "parallel", "vectorized"],
)
@pytest.mark.parametrize("kernel_cls", [NUTS, HMC, SA, BarkerMH])
def test_kernels(chain_method, kernel_cls):
    if chain_method == "vectorized" and kernel_cls in [SA, BarkerMH]:
        # These kernels do not support vectorized execution.
        pytest.skip("vectorized chain method is not supported for this kernel")

def model(y):
z = numpyro.sample("z", dist.Normal(0.0, 1.0))
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
sigma = 1.0 if model1 == 0 else 2.0
with numpyro.plate("data", y.shape[0]):
numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

rng_key = random.PRNGKey(0)

rng_key, subkey = random.split(rng_key)
y_train = dist.Normal(0, 1).sample(subkey, (200,))

mcmc_kwargs = dict(
num_warmup=50,
num_samples=50,
num_chains=2,
chain_method=chain_method,
progress_bar=False,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs, kernel_cls=kernel_cls)
rng_key, subkey = random.split(rng_key)
dcc.run(subkey, y_train)


def test_weight_convergence():
PRIOR_MEAN, PRIOR_STD = 0.0, 1.0
LIKELIHOOD1_STD = 2.0
LIKELIHOOD2_STD = 0.62177

def log_marginal_likelihood(data, likelihood_std, prior_mean, prior_std):
"""
        Calculate the log marginal likelihood of a model with a Normal likelihood
        (known variance), unknown mean, and a Normal prior on the mean.

Taken from Section 2.5 at https://www.cs.ubc.ca/~murphyk/Papers/bayesGauss.pdf.
"""
num_data = data.shape[0]
likelihood_var = jnp.square(likelihood_std)
prior_var = jnp.square(prior_std)

first_term = (
jnp.log(likelihood_std)
- num_data * jnp.log(jnp.sqrt(2 * math.pi) * likelihood_std)
            - 0.5 * jnp.log(num_data * prior_var + likelihood_var)
)
second_term = -(jnp.sum(jnp.square(data)) / (2 * likelihood_var)) - (
jnp.square(prior_mean) / (2 * prior_var)
)
third_term = (
(
prior_var
* jnp.square(num_data)
* jnp.square(jnp.mean(data))
/ likelihood_var
)
+ (likelihood_var * jnp.square(prior_mean) / prior_var)
+ 2 * num_data * jnp.mean(data) * prior_mean
) / (2 * (num_data * prior_var + likelihood_var))
return first_term + second_term + third_term

def model(y):
z = numpyro.sample("z", dist.Normal(PRIOR_MEAN, PRIOR_STD))
model1 = numpyro.sample(
"model1", dist.Bernoulli(0.5), infer={"branching": True}
)
sigma = LIKELIHOOD1_STD if model1 == 0 else LIKELIHOOD2_STD
with numpyro.plate("data", y.shape[0]):
numpyro.sample("obs", dist.Normal(z, sigma), obs=y)

rng_key = random.PRNGKey(0)

rng_key, subkey = random.split(rng_key)
y_train = dist.Normal(0, 1).sample(subkey, (200,))

mcmc_kwargs = dict(
num_warmup=500,
num_samples=1000,
num_chains=1,
)

dcc = DCC(model, mcmc_kwargs=mcmc_kwargs)
rng_key, subkey = random.split(rng_key)
dcc_result = dcc.run(subkey, y_train)
slp_weights = jnp.array([dcc_result.slp_weights["0"], dcc_result.slp_weights["1"]])
assert_allclose(1.0, jnp.sum(slp_weights))

slp1_lml = log_marginal_likelihood(y_train, LIKELIHOOD1_STD, PRIOR_MEAN, PRIOR_STD)
slp2_lml = log_marginal_likelihood(y_train, LIKELIHOOD2_STD, PRIOR_MEAN, PRIOR_STD)
lmls = jnp.array([slp1_lml, slp2_lml])
analytic_weights = jnp.exp(lmls - jax.scipy.special.logsumexp(lmls))
assert_allclose(analytic_weights, slp_weights, rtol=1e-5, atol=1e-8)