From b6c1c670da7379fa7a179c95f385b28c1220a0e1 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 24 Jan 2025 15:42:43 -0500 Subject: [PATCH 1/4] Fetch `rng_key` using `prng_key` message in `block` handler. --- numpyro/handlers.py | 15 +++++++++--- test/test_handlers.py | 56 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 3 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 1f739d5b6..d8abaaa24 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -310,8 +310,17 @@ def __init__( super(block, self).__init__(fn) def process_message(self, msg: Message) -> None: - if self.hide_fn(msg): - msg["stop"] = True + if not self.hide_fn(msg): + return + msg["stop"] = True + + # For specific message types, get a prng key from the stack. These types match + # the implementation in `seed` except `prng_key` because it would lead to + # infinite recursion. The corresponding message reaches the seed handler unless + # messages of `prng_key` type are blocked. + allowed_types = {"sample", "plate", "control_flow"} + if msg["type"] in allowed_types and msg["kwargs"]["rng_key"] is None: + msg["kwargs"]["rng_key"] = numpyro.prng_key() class collapse(trace): @@ -748,7 +757,7 @@ class seed(Messenger): :param fn: Python callable with NumPyro primitives. :param rng_seed: a random number generator seed. :type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey - :param list hide_types: an optional list of side types to skip seeding, e.g. ['plate']. + :param list hide_types: an optional list of site types to skip seeding, e.g. ['plate']. .. note:: diff --git a/test/test_handlers.py b/test/test_handlers.py index 15121eb46..65e188843 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -16,6 +16,7 @@ import numpyro from numpyro import handlers +from numpyro.contrib import control_flow import numpyro.distributions as dist from numpyro.distributions import constraints from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO @@ -617,6 +618,61 @@ def test_block_expose(): assert "x" in trace and "mu" not in trace and "sigma" in trace +@pytest.mark.parametrize( + "block_config, expected_sites", + [ + ({"hide": ["y"]}, {"x", "z", "n", "cluster", "a", "b"}), + ({"expose_types": ["prng_key"]}, set()), + ({"hide": ["n"]}, {"x", "y", "z", "cluster", "a", "b"}), + ({"hide": ["cluster", "b"]}, {"x", "y", "z", "n", "a"}), + pytest.param( + {"expose": ["x", "z"]}, + {"x", "z"}, + marks=pytest.mark.xfail( + reason="Exposing by name blocks prng_key messages." + ), + ), + ], +) +def test_block_seed(block_config: dict, expected_sites: set) -> None: + def fn(): + sample = {} + sample["x"] = numpyro.sample("x", dist.Normal()) + sample["y"] = numpyro.sample("y", dist.Normal(sample["x"])) + with numpyro.plate("n", 10, subsample_size=7) as sample["idx"]: + sample["z"] = numpyro.sample("z", dist.Normal(sample["y"])) + + def true_fun(_): + a = numpyro.sample("a", dist.Normal(4.0)) + b = numpyro.deterministic("b", a - 2.0) + return a, b + + def false_fun(_): + a = numpyro.sample("a", dist.Normal(0.0)) + b = numpyro.deterministic("b", a) + return a, b + + sample["cluster"] = numpyro.sample("cluster", dist.Normal()) + sample["a"], sample["b"] = control_flow.cond( + sample["cluster"] > 0, true_fun, false_fun, None + ) + return sample + + blocked_seeded = handlers.block(handlers.seed(fn, rng_seed=17), **block_config) + with handlers.trace() as trace1: + sample1 = blocked_seeded() + assert set(trace1) == expected_sites + + seeded_blocked = handlers.seed(handlers.block(fn, **block_config), rng_seed=17) + with handlers.trace() as trace2: + sample2 = seeded_blocked() + assert set(trace2) == expected_sites + + # Verify that the sample values are identical. + for key, value in sample1.items(): + assert jnp.allclose(value, sample2[key]) + + def test_scope(): def fn(): with numpyro.plate("N", 10): From 16267060440b28a33b15f8c89d5a0a65f5476af8 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 24 Jan 2025 18:45:40 -0500 Subject: [PATCH 2/4] Do not block site if nameless in `scan`. --- numpyro/contrib/control_flow/scan.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/numpyro/contrib/control_flow/scan.py b/numpyro/contrib/control_flow/scan.py index b756455e9..901093768 100644 --- a/numpyro/contrib/control_flow/scan.py +++ b/numpyro/contrib/control_flow/scan.py @@ -199,7 +199,9 @@ def body_fn(wrapped_carry, x, prefix=None): return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y) with ( - handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")), + handlers.block( + hide_fn=lambda site: not site.get("name", "nameless").startswith("_PREV_") + ), enum(first_available_dim=first_available_dim), ): wrapped_carry = (0, rng_key, init) From 28c996e5e806e019ee0010ca5f7fa840f82bf31c Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 24 Jan 2025 18:47:01 -0500 Subject: [PATCH 3/4] Remove redundant seeding for auto guides. --- numpyro/infer/autoguide.py | 2 +- test/infer/test_autoguide.py | 24 +++++------------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 694f7c8b3..aba1ba583 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -227,7 +227,7 @@ class AutoGuideList(AutoGuide): guide = AutoGuideList(my_model) guide.append( AutoNormal( - numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"]) + numpyro.handlers.block(model, hide=["coefs"]) ) ) guide.append( diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index ec26d0ae9..f4dd9dd97 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -1068,13 +1068,7 @@ def model(x, y=None): ValueError, match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.", ): - guide.append( - auto_classes[1]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=1), hide=["a"] - ) - ) - ) + guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"]))) return if auto_classes[1] == AutoSemiDAIS: with pytest.raises( @@ -1083,9 +1077,7 @@ def model(x, y=None): ): guide.append( auto_classes[1]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=1), hide=["a"] - ), + numpyro.handlers.block(model, hide=["a"]), local_model=None, global_guide=None, ) @@ -1098,19 +1090,13 @@ def model(x, y=None): ): guide.append( auto_classes[1]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=1), hide=["a"] - ), + numpyro.handlers.block(model, hide=["a"]), surrogate_model=None, ) ) return - guide.append( - auto_classes[1]( - numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"]) - ) - ) + guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"]))) optimiser = numpyro.optim.Adam(step_size=0.1) svi = SVI(model, guide, optimiser, Elbo()) @@ -1195,7 +1181,7 @@ def model(): numpyro.deterministic("x2", x**2) guide = AutoGuideList(model) - blocked_model = handlers.block(handlers.seed(model, 7), hide=["x2"]) + blocked_model = handlers.block(model, hide=["x2"]) # AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS if auto_class == AutoDAIS: From c0ac274ad0ba7db5a82cbc5f84fe415d4c749ee4 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Mon, 27 Jan 2025 11:13:17 -0500 Subject: [PATCH 4/4] Do not block `prng_key` messages. --- numpyro/handlers.py | 17 +++++++++++------ numpyro/infer/autoguide.py | 17 +++++------------ test/infer/test_autoguide.py | 8 +------- test/test_handlers.py | 8 +------- 4 files changed, 18 insertions(+), 32 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index d8abaaa24..ab5d5f72c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -310,16 +310,21 @@ def __init__( super(block, self).__init__(fn) def process_message(self, msg: Message) -> None: - if not self.hide_fn(msg): + if not self.hide_fn(msg) or msg["type"] == "prng_key": return msg["stop"] = True - # For specific message types, get a prng key from the stack. These types match - # the implementation in `seed` except `prng_key` because it would lead to - # infinite recursion. The corresponding message reaches the seed handler unless - # messages of `prng_key` type are blocked. + # For specific message types, get a prng key from the stack if no key or value + # is available yet. These types match the implementation in `seed` except + # `prng_key` because it would lead to infinite recursion. The corresponding + # message reaches the seed handler because we always let messages of `prng_key` + # propagate. allowed_types = {"sample", "plate", "control_flow"} - if msg["type"] in allowed_types and msg["kwargs"]["rng_key"] is None: + if ( + msg["type"] in allowed_types + and msg["value"] is None + and msg["kwargs"]["rng_key"] is None + ): msg["kwargs"]["rng_key"] = numpyro.prng_key() diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index aba1ba583..3454fc37a 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -232,7 +232,7 @@ class AutoGuideList(AutoGuide): ) guide.append( AutoDelta( - numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"]) + numpyro.handlers.block(model, expose=["coefs"]) ) ) svi = SVI(model, guide, optim, Trace_ELBO()) @@ -1403,7 +1403,7 @@ def _setup_prototype(self, *args, **kwargs): self._local_plate = (plate_name, plate_full_size, plate_subsample_size) if self.global_guide is not None: - with handlers.block(), handlers.seed(rng_seed=0): + with handlers.block(): local_args = (self.global_guide.model(*args, **kwargs),) local_kwargs = {} else: @@ -1411,11 +1411,11 @@ def _setup_prototype(self, *args, **kwargs): local_kwargs = kwargs.copy() if self.local_guide is not None: - with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0): + with handlers.block(), handlers.trace() as tr: self.local_guide(*local_args, **local_kwargs) self.prototype_local_guide_trace = tr - with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0): + with handlers.block(), handlers.trace() as tr: self.local_model(*local_args, **local_kwargs) self.prototype_local_model_trace = tr @@ -1462,12 +1462,7 @@ def _sample_latent(self, *args, **kwargs): if self.global_guide is not None: global_latents = self.global_guide(*args, **kwargs) - rng_key = numpyro.prng_key() - with ( - handlers.block(), - handlers.seed(rng_seed=rng_key), - handlers.substitute(data=global_latents), - ): + with handlers.block(), handlers.substitute(data=global_latents): global_outputs = self.global_guide.model(*args, **kwargs) local_args = (global_outputs,) local_kwargs = {} @@ -1575,12 +1570,10 @@ def fn(x): local_kwargs["_subsample_idx"] = {plate_name: idx} if self.local_guide is not None: - key = numpyro.prng_key() subsample_guide = partial(_subsample_model, self.local_guide) with ( handlers.block(), handlers.trace() as tr, - handlers.seed(rng_seed=key), handlers.substitute(data=local_guide_params), ): with warnings.catch_warnings(): diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index f4dd9dd97..61be7f317 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -1055,13 +1055,7 @@ def model(x, y=None): y = a + x @ b + sigma * random.normal(random.PRNGKey(1), (N, 1)) guide = AutoGuideList(model) - guide.append( - auto_classes[0]( - numpyro.handlers.block( - numpyro.handlers.seed(model, rng_seed=0), expose=["a"] - ) - ) - ) + guide.append(auto_classes[0](numpyro.handlers.block(model, expose=["a"]))) # AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS if auto_classes[1] == AutoDAIS: with pytest.raises( diff --git a/test/test_handlers.py b/test/test_handlers.py index 65e188843..dbf7229b1 100644 --- a/test/test_handlers.py +++ b/test/test_handlers.py @@ -625,13 +625,7 @@ def test_block_expose(): ({"expose_types": ["prng_key"]}, set()), ({"hide": ["n"]}, {"x", "y", "z", "cluster", "a", "b"}), ({"hide": ["cluster", "b"]}, {"x", "y", "z", "n", "a"}), - pytest.param( - {"expose": ["x", "z"]}, - {"x", "z"}, - marks=pytest.mark.xfail( - reason="Exposing by name blocks prng_key messages." - ), - ), + ({"expose": ["x", "z"]}, {"x", "z"}), ], ) def test_block_seed(block_config: dict, expected_sites: set) -> None: