Skip to content

Commit

Permalink
Add AutoLowRankMultivariateNormal.
Browse files Browse the repository at this point in the history
  • Loading branch information
tillahoffmann committed Feb 17, 2024
1 parent ff3fbf7 commit 45b95f8
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 6 deletions.
51 changes: 51 additions & 0 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from numpyro.util import find_stack_level, not_jax_tracer

__all__ = [
"AutoBatchedLowRankMultivariateNormal",
"AutoBatchedMultivariateNormal",
"AutoContinuous",
"AutoGuide",
Expand Down Expand Up @@ -1989,6 +1990,56 @@ def quantiles(self, params, quantiles):
return self._unpack_and_constrain(latent, params)


class AutoBatchedLowRankMultivariateNormal(AutoBatchedMixin, AutoContinuous):
"""
This implementation of :class:`AutoContinuous` uses a batched
AutoLowRankMultivariateNormal distribution to construct a guide over the entire
latent space. The guide does not depend on the model's ``*args, **kwargs``.
Usage::
guide = AutoBatchedLowRankMultivariateNormal(model, batch_ndim=1, ...)
svi = SVI(model, guide, ...)
"""

scale_constraint = constraints.softplus_positive

def __init__(
self,
model,
*,
prefix="auto",
init_loc_fn=init_to_uniform,
init_scale=0.1,
rank=None,
batch_ndim=1,
):
if init_scale <= 0:
raise ValueError("Expected init_scale > 0. but got {}".format(init_scale))
self._init_scale = init_scale
self.rank = rank
super().__init__(
model, prefix=prefix, init_loc_fn=init_loc_fn, batch_ndim=batch_ndim,
)

def _get_batched_posterior(self):
rank = int(round(self._event_shape[0]**0.5)) if self.rank is None else self.rank
init_latent = self._init_latent.reshape(self._batch_shape + self._event_shape)
loc = numpyro.param("{}_loc".format(self.prefix), init_latent)
cov_factor = numpyro.param(
"{}_cov_factor".format(self.prefix),
jnp.zeros(self._batch_shape + self._event_shape + (rank,))
)
scale = numpyro.param(
"{}_scale".format(self.prefix),
jnp.full(self._batch_shape + self._event_shape, self._init_scale),
constraint=self.scale_constraint,
)
cov_diag = scale * scale
cov_factor = cov_factor * scale[..., None]
return dist.LowRankMultivariateNormal(loc, cov_factor, cov_diag)


class AutoLaplaceApproximation(AutoContinuous):
r"""
Laplace approximation (quadratic approximation) approximates the posterior
Expand Down
27 changes: 21 additions & 6 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from numpyro.handlers import substitute
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO
from numpyro.infer.autoguide import (
AutoBatchedLowRankMultivariateNormal,
AutoBatchedMultivariateNormal,
AutoBNAFNormal,
AutoDAIS,
Expand Down Expand Up @@ -1254,7 +1255,14 @@ def model():
)


def test_auto_batched() -> None:
@pytest.mark.parametrize(
"auto_class",
[
AutoBatchedMultivariateNormal,
AutoBatchedLowRankMultivariateNormal,
],
)
def test_auto_batched(auto_class) -> None:
# Model for batched multivariate normal.
off_diag = jnp.asarray([-0.2, 0, 0.5])
covs = off_diag[:, None, None] + jnp.eye(4)
Expand All @@ -1264,7 +1272,7 @@ def model():
numpyro.sample("x", dist.MultivariateNormal(0, covs))

# Run inference.
guide = AutoBatchedMultivariateNormal(model)
guide = auto_class(model)
svi = SVI(model, guide, optax.adam(0.001), Trace_ELBO())
result = svi.run(random.PRNGKey(0), 10000)
samples = guide.sample_posterior(
Expand All @@ -1279,7 +1287,14 @@ def model():
assert corrcoef > 0.99


def test_auto_batched_shapes() -> None:
@pytest.mark.parametrize(
"auto_class",
[
AutoBatchedMultivariateNormal,
AutoBatchedLowRankMultivariateNormal,
],
)
def test_auto_batched_shapes(auto_class) -> None:
def model(n, m):
distribution = dist.Normal().expand([7]).to_event(1)
with numpyro.plate("n", n):
Expand All @@ -1289,10 +1304,10 @@ def model(n, m):
return x, y

with numpyro.handlers.seed(rng_seed=0):
AutoBatchedMultivariateNormal(model)(3, 3)
auto_class(model)(3, 3)

with pytest.raises(ValueError, match="inconsistent batch shapes"):
AutoBatchedMultivariateNormal(model)(3, 4)
auto_class(model)(3, 4)

with pytest.raises(ValueError, match="Expected 3 batch dimensions"):
AutoBatchedMultivariateNormal(model, batch_ndim=3)(3, 3)
auto_class(model, batch_ndim=3)(3, 3)

0 comments on commit 45b95f8

Please sign in to comment.