From 8a672691a17c2691324350e71c42bdaec0ab1f02 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 17 Jan 2025 04:55:21 -0500 Subject: [PATCH] Make `Delta.log_prob` jit-able on metal and raise explicit error for `Delta` sampling site during initialization. (#1950) * Use `jnp.where` for `Delta.log_prob` (cf. jax-ml/jax#25935). * Raise explicit error for unobserved `Delta` sample sites during intialization. --- numpyro/distributions/distribution.py | 2 +- numpyro/infer/util.py | 13 +++++++++++++ test/infer/test_autoguide.py | 15 +++++++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 0d701ba0a..3aa6e69a2 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -1250,7 +1250,7 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): - log_prob = jnp.log(value == self.v) + log_prob = jnp.where(value == self.v, 0, -jnp.inf) log_prob = sum_rightmost(log_prob, len(self.event_shape)) return log_prob + self.log_density diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index a3b7425d0..fe39a87bf 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -17,6 +17,7 @@ import jax.numpy as jnp import numpyro +from numpyro import distributions as dist from numpyro.distributions import constraints from numpyro.distributions.transforms import biject_to from numpyro.distributions.util import is_identically_one, sum_rightmost @@ -685,6 +686,18 @@ def initialize_model( has_enumerate_support, model_trace, ) = _get_model_transforms(substituted_model, model_args, model_kwargs) + + for name, site in model_trace.items(): + if ( + site["type"] == "sample" + and isinstance(site["fn"], dist.Delta) + and not site["is_observed"] + ): + raise ValueError( + f"Sample site '{name}' has a delta distribution; use " + "`numpyro.deterministic` to add this value to the trace instead." + ) + # substitute param sites from model_trace to model so # we don't need to generate again parameters of `numpyro.module` model = substitute( diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 0e2e53aa4..ec26d0ae9 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -1350,3 +1350,18 @@ def model(): assert jnp.allclose( jnp.stack(list(median.values())).ravel(), params["auto_loc"].ravel() ) + + +def test_autoguide_with_delta_site() -> None: + def model(x): + numpyro.sample("x", dist.Delta(3.0), obs=x) + # Need to sample a latent variable so the guide is not empty. + numpyro.sample("y", dist.Normal()) + + guide = AutoDiagonalNormal(lambda: model(None)) + with pytest.raises(ValueError, match="has a delta distribution"): + numpyro.handlers.seed(guide, 9)() + + # Check delta distributions are fine if observed. + guide = AutoDiagonalNormal(lambda: model(3.0)) + numpyro.handlers.seed(guide, 9)()