Skip to content

Commit

Permalink
Fetch rng_key using prng_key message in block handler. (#1957)
Browse files Browse the repository at this point in the history
* Fetch `rng_key` using `prng_key` message in `block` handler.

* Do not block site if nameless in `scan`.

* Remove redundant seeding for auto guides.

* Do not block `prng_key` messages.
  • Loading branch information
tillahoffmann authored Jan 30, 2025
1 parent 93e11c2 commit 5f3bdd1
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 43 deletions.
4 changes: 3 additions & 1 deletion numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def body_fn(wrapped_carry, x, prefix=None):
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

with (
handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")),
handlers.block(
hide_fn=lambda site: not site.get("name", "nameless").startswith("_PREV_")
),
enum(first_available_dim=first_available_dim),
):
wrapped_carry = (0, rng_key, init)
Expand Down
20 changes: 17 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,22 @@ def __init__(
super(block, self).__init__(fn)

def process_message(self, msg: Message) -> None:
if self.hide_fn(msg):
msg["stop"] = True
if not self.hide_fn(msg) or msg["type"] == "prng_key":
return
msg["stop"] = True

# For specific message types, get a prng key from the stack if no key or value
# is available yet. These types match the implementation in `seed` except
# `prng_key` because it would lead to infinite recursion. The corresponding
# message reaches the seed handler because we always let messages of `prng_key`
# propagate.
allowed_types = {"sample", "plate", "control_flow"}
if (
msg["type"] in allowed_types
and msg["value"] is None
and msg["kwargs"]["rng_key"] is None
):
msg["kwargs"]["rng_key"] = numpyro.prng_key()


class collapse(trace):
Expand Down Expand Up @@ -748,7 +762,7 @@ class seed(Messenger):
:param fn: Python callable with NumPyro primitives.
:param rng_seed: a random number generator seed.
:type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey
:param list hide_types: an optional list of side types to skip seeding, e.g. ['plate'].
:param list hide_types: an optional list of site types to skip seeding, e.g. ['plate'].
.. note::
Expand Down
19 changes: 6 additions & 13 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ class AutoGuideList(AutoGuide):
guide = AutoGuideList(my_model)
guide.append(
AutoNormal(
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"])
numpyro.handlers.block(model, hide=["coefs"])
)
)
guide.append(
AutoDelta(
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"])
numpyro.handlers.block(model, expose=["coefs"])
)
)
svi = SVI(model, guide, optim, Trace_ELBO())
Expand Down Expand Up @@ -1403,19 +1403,19 @@ def _setup_prototype(self, *args, **kwargs):
self._local_plate = (plate_name, plate_full_size, plate_subsample_size)

if self.global_guide is not None:
with handlers.block(), handlers.seed(rng_seed=0):
with handlers.block():
local_args = (self.global_guide.model(*args, **kwargs),)
local_kwargs = {}
else:
local_args = args
local_kwargs = kwargs.copy()

if self.local_guide is not None:
with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
with handlers.block(), handlers.trace() as tr:
self.local_guide(*local_args, **local_kwargs)
self.prototype_local_guide_trace = tr

with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
with handlers.block(), handlers.trace() as tr:
self.local_model(*local_args, **local_kwargs)
self.prototype_local_model_trace = tr

Expand Down Expand Up @@ -1462,12 +1462,7 @@ def _sample_latent(self, *args, **kwargs):

if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with (
handlers.block(),
handlers.seed(rng_seed=rng_key),
handlers.substitute(data=global_latents),
):
with handlers.block(), handlers.substitute(data=global_latents):
global_outputs = self.global_guide.model(*args, **kwargs)
local_args = (global_outputs,)
local_kwargs = {}
Expand Down Expand Up @@ -1575,12 +1570,10 @@ def fn(x):

local_kwargs["_subsample_idx"] = {plate_name: idx}
if self.local_guide is not None:
key = numpyro.prng_key()
subsample_guide = partial(_subsample_model, self.local_guide)
with (
handlers.block(),
handlers.trace() as tr,
handlers.seed(rng_seed=key),
handlers.substitute(data=local_guide_params),
):
with warnings.catch_warnings():
Expand Down
32 changes: 6 additions & 26 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,26 +1055,14 @@ def model(x, y=None):
y = a + x @ b + sigma * random.normal(random.PRNGKey(1), (N, 1))

guide = AutoGuideList(model)
guide.append(
auto_classes[0](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=0), expose=["a"]
)
)
)
guide.append(auto_classes[0](numpyro.handlers.block(model, expose=["a"])))
# AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
if auto_classes[1] == AutoDAIS:
with pytest.raises(
ValueError,
match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.",
):
guide.append(
auto_classes[1](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=1), hide=["a"]
)
)
)
guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"])))
return
if auto_classes[1] == AutoSemiDAIS:
with pytest.raises(
Expand All @@ -1083,9 +1071,7 @@ def model(x, y=None):
):
guide.append(
auto_classes[1](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=1), hide=["a"]
),
numpyro.handlers.block(model, hide=["a"]),
local_model=None,
global_guide=None,
)
Expand All @@ -1098,19 +1084,13 @@ def model(x, y=None):
):
guide.append(
auto_classes[1](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=1), hide=["a"]
),
numpyro.handlers.block(model, hide=["a"]),
surrogate_model=None,
)
)
return

guide.append(
auto_classes[1](
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"])
)
)
guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"])))

optimiser = numpyro.optim.Adam(step_size=0.1)
svi = SVI(model, guide, optimiser, Elbo())
Expand Down Expand Up @@ -1195,7 +1175,7 @@ def model():
numpyro.deterministic("x2", x**2)

guide = AutoGuideList(model)
blocked_model = handlers.block(handlers.seed(model, 7), hide=["x2"])
blocked_model = handlers.block(model, hide=["x2"])

# AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
if auto_class == AutoDAIS:
Expand Down
50 changes: 50 additions & 0 deletions test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpyro
from numpyro import handlers
from numpyro.contrib import control_flow
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
Expand Down Expand Up @@ -617,6 +618,55 @@ def test_block_expose():
assert "x" in trace and "mu" not in trace and "sigma" in trace


@pytest.mark.parametrize(
"block_config, expected_sites",
[
({"hide": ["y"]}, {"x", "z", "n", "cluster", "a", "b"}),
({"expose_types": ["prng_key"]}, set()),
({"hide": ["n"]}, {"x", "y", "z", "cluster", "a", "b"}),
({"hide": ["cluster", "b"]}, {"x", "y", "z", "n", "a"}),
({"expose": ["x", "z"]}, {"x", "z"}),
],
)
def test_block_seed(block_config: dict, expected_sites: set) -> None:
def fn():
sample = {}
sample["x"] = numpyro.sample("x", dist.Normal())
sample["y"] = numpyro.sample("y", dist.Normal(sample["x"]))
with numpyro.plate("n", 10, subsample_size=7) as sample["idx"]:
sample["z"] = numpyro.sample("z", dist.Normal(sample["y"]))

def true_fun(_):
a = numpyro.sample("a", dist.Normal(4.0))
b = numpyro.deterministic("b", a - 2.0)
return a, b

def false_fun(_):
a = numpyro.sample("a", dist.Normal(0.0))
b = numpyro.deterministic("b", a)
return a, b

sample["cluster"] = numpyro.sample("cluster", dist.Normal())
sample["a"], sample["b"] = control_flow.cond(
sample["cluster"] > 0, true_fun, false_fun, None
)
return sample

blocked_seeded = handlers.block(handlers.seed(fn, rng_seed=17), **block_config)
with handlers.trace() as trace1:
sample1 = blocked_seeded()
assert set(trace1) == expected_sites

seeded_blocked = handlers.seed(handlers.block(fn, **block_config), rng_seed=17)
with handlers.trace() as trace2:
sample2 = seeded_blocked()
assert set(trace2) == expected_sites

# Verify that the sample values are identical.
for key, value in sample1.items():
assert jnp.allclose(value, sample2[key])


def test_scope():
def fn():
with numpyro.plate("N", 10):
Expand Down

0 comments on commit 5f3bdd1

Please sign in to comment.