diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index b756455e9..901093768 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -199,7 +199,9 @@ def body_fn(wrapped_carry, x, prefix=None): return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y) with ( - handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")), + handlers.block( + hide_fn=lambda site: not site.get("name", "nameless").startswith("_PREV_") + ), enum(first_available_dim=first_available_dim), ): wrapped_carry = (0, rng_key, init) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 1f739d5b6..ab5d5f72c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -310,8 +310,22 @@ def __init__( super(block, self).__init__(fn) def process_message(self, msg: Message) -> None: - if self.hide_fn(msg): - msg["stop"] = True + if not self.hide_fn(msg) or msg["type"] == "prng_key": + return + msg["stop"] = True + + # For specific message types, get a prng key from the stack if no key or value + # is available yet. These types match the implementation in `seed` except + # `prng_key` because it would lead to infinite recursion. The corresponding + # message reaches the seed handler because we always let messages of `prng_key` + # propagate. + allowed_types = {"sample", "plate", "control_flow"} + if ( + msg["type"] in allowed_types + and msg["value"] is None + and msg["kwargs"]["rng_key"] is None + ): + msg["kwargs"]["rng_key"] = numpyro.prng_key() class collapse(trace): @@ -748,7 +762,7 @@ class seed(Messenger): :param fn: Python callable with NumPyro primitives. :param rng_seed: a random number generator seed. :type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey - :param list hide_types: an optional list of side types to skip seeding, e.g. ['plate']. + :param list hide_types: an optional list of site types to skip seeding, e.g. ['plate']. .. note:: diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 694f7c8b3..3454fc37a 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -227,12 +227,12 @@ class AutoGuideList(AutoGuide): guide = AutoGuideList(my_model) guide.append( AutoNormal( - numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"]) + numpyro.handlers.block(model, hide=["coefs"]) ) ) guide.append( AutoDelta( - numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"]) + numpyro.handlers.block(model, expose=["coefs"]) ) ) svi = SVI(model, guide, optim, Trace_ELBO()) @@ -1403,7 +1403,7 @@ def _setup_prototype(self, *args, **kwargs): self._local_plate = (plate_name, plate_full_size, plate_subsample_size) if self.global_guide is not None: - with handlers.block(), handlers.seed(rng_seed=0): + with handlers.block(): local_args = (self.global_guide.model(*args, **kwargs),) local_kwargs = {} else: @@ -1411,11 +1411,11 @@ def _setup_prototype(self, *args, **kwargs): local_kwargs = kwargs.copy() if self.local_guide is not None: - with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0): + with handlers.block(), handlers.trace() as tr: self.local_guide(*local_args, **local_kwargs) self.prototype_local_guide_trace = tr - with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0): + with handlers.block(), handlers.trace() as tr: self.local_model(*local_args, **local_kwargs) self.prototype_local_model_trace = tr @@ -1462,12 +1462,7 @@ def _sample_latent(self, *args, **kwargs): if self.global_guide is not None: global_latents = self.global_guide(*args, **kwargs) - rng_key = numpyro.prng_key() - with ( - handlers.block(), - handlers.seed(rng_seed=rng_key), - handlers.substitute(data=global_latents), - ): + with handlers.block(), handlers.substitute(data=global_latents): global_outputs = self.global_guide.model(*args, **kwargs) local_args = (global_outputs,) local_kwargs = {} @@ -1575,12 +1570,10 @@ def fn(x): local_kwargs["_subsample_idx"] = {plate_name: idx} if self.local_guide is not None: - key = numpyro.prng_key() subsample_guide = partial(_subsample_model, self.local_guide) with ( handlers.block(), handlers.trace() as tr, - handlers.seed(rng_seed=key), handlers.substitute(data=local_guide_params), ): with warnings.catch_warnings(): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index ec26d0ae9..61be7f317 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -1055,26 +1055,14 @@ def model(x, y=None): y = a + x @ b + sigma * random.normal(random.PRNGKey(1), (N, 1)) guide = AutoGuideList(model) - guide.append( - auto_classes[0]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=0), expose=["a"] - ) - ) - ) + guide.append(auto_classes[0](numpyro.handlers.block(model, expose=["a"]))) # AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS if auto_classes[1] == AutoDAIS: with pytest.raises( ValueError, match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.", ): - guide.append( - auto_classes[1]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=1), hide=["a"] - ) - ) - ) + guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"]))) return if auto_classes[1] == AutoSemiDAIS: with pytest.raises( @@ -1083,9 +1071,7 @@ def model(x, y=None): ): guide.append( auto_classes[1]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=1), hide=["a"] - ), + numpyro.handlers.block(model, hide=["a"]), local_model=None, global_guide=None, ) @@ -1098,19 +1084,13 @@ def model(x, y=None): ): guide.append( auto_classes[1]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=1), hide=["a"] - ), + numpyro.handlers.block(model, hide=["a"]), surrogate_model=None, ) ) return - guide.append( - auto_classes[1]( - numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"]) - ) - ) + guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"]))) optimiser = numpyro.optim.Adam(step_size=0.1) svi = SVI(model, guide, optimiser, Elbo()) @@ -1195,7 +1175,7 @@ def model(): numpyro.deterministic("x2", x**2) guide = AutoGuideList(model) - blocked_model = handlers.block(handlers.seed(model, 7), hide=["x2"]) + blocked_model = handlers.block(model, hide=["x2"]) # AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS if auto_class == AutoDAIS: diff --git a/test/test_handlers.py b/test/test_handlers.py index 15121eb46..dbf7229b1 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -16,6 +16,7 @@ import numpyro from numpyro import handlers +from numpyro.contrib import control_flow import numpyro.distributions as dist from numpyro.distributions import constraints from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO @@ -617,6 +618,55 @@ def test_block_expose(): assert "x" in trace and "mu" not in trace and "sigma" in trace +@pytest.mark.parametrize( + "block_config, expected_sites", + [ + ({"hide": ["y"]}, {"x", "z", "n", "cluster", "a", "b"}), + ({"expose_types": ["prng_key"]}, set()), + ({"hide": ["n"]}, {"x", "y", "z", "cluster", "a", "b"}), + ({"hide": ["cluster", "b"]}, {"x", "y", "z", "n", "a"}), + ({"expose": ["x", "z"]}, {"x", "z"}), + ], +) +def test_block_seed(block_config: dict, expected_sites: set) -> None: + def fn(): + sample = {} + sample["x"] = numpyro.sample("x", dist.Normal()) + sample["y"] = numpyro.sample("y", dist.Normal(sample["x"])) + with numpyro.plate("n", 10, subsample_size=7) as sample["idx"]: + sample["z"] = numpyro.sample("z", dist.Normal(sample["y"])) + + def true_fun(_): + a = numpyro.sample("a", dist.Normal(4.0)) + b = numpyro.deterministic("b", a - 2.0) + return a, b + + def false_fun(_): + a = numpyro.sample("a", dist.Normal(0.0)) + b = numpyro.deterministic("b", a) + return a, b + + sample["cluster"] = numpyro.sample("cluster", dist.Normal()) + sample["a"], sample["b"] = control_flow.cond( + sample["cluster"] > 0, true_fun, false_fun, None + ) + return sample + + blocked_seeded = handlers.block(handlers.seed(fn, rng_seed=17), **block_config) + with handlers.trace() as trace1: + sample1 = blocked_seeded() + assert set(trace1) == expected_sites + + seeded_blocked = handlers.seed(handlers.block(fn, **block_config), rng_seed=17) + with handlers.trace() as trace2: + sample2 = seeded_blocked() + assert set(trace2) == expected_sites + + # Verify that the sample values are identical. + for key, value in sample1.items(): + assert jnp.allclose(value, sample2[key]) + + def test_scope(): def fn(): with numpyro.plate("N", 10):