Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fetch rng_key using prng_key message in block handler. #1957

Merged
merged 4 commits into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,9 @@ def body_fn(wrapped_carry, x, prefix=None):
return (i + 1, rng_key, new_carry), (PytreeTrace(trace), y)

with (
handlers.block(hide_fn=lambda site: not site["name"].startswith("_PREV_")),
handlers.block(
hide_fn=lambda site: not site.get("name", "nameless").startswith("_PREV_")
),
enum(first_available_dim=first_available_dim),
):
wrapped_carry = (0, rng_key, init)
Expand Down
20 changes: 17 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,22 @@ def __init__(
super(block, self).__init__(fn)

def process_message(self, msg: Message) -> None:
if self.hide_fn(msg):
msg["stop"] = True
if not self.hide_fn(msg) or msg["type"] == "prng_key":
return
msg["stop"] = True

# For specific message types, get a prng key from the stack if no key or value
# is available yet. These types match the implementation in `seed` except
# `prng_key` because it would lead to infinite recursion. The corresponding
# message reaches the seed handler because we always let messages of `prng_key`
# propagate.
allowed_types = {"sample", "plate", "control_flow"}
if (
msg["type"] in allowed_types
and msg["value"] is None
and msg["kwargs"]["rng_key"] is None
):
msg["kwargs"]["rng_key"] = numpyro.prng_key()


class collapse(trace):
Expand Down Expand Up @@ -748,7 +762,7 @@ class seed(Messenger):
:param fn: Python callable with NumPyro primitives.
:param rng_seed: a random number generator seed.
:type rng_seed: int, jnp.ndarray scalar, or jax.random.PRNGKey
:param list hide_types: an optional list of side types to skip seeding, e.g. ['plate'].
:param list hide_types: an optional list of site types to skip seeding, e.g. ['plate'].

.. note::

Expand Down
19 changes: 6 additions & 13 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,12 @@ class AutoGuideList(AutoGuide):
guide = AutoGuideList(my_model)
guide.append(
AutoNormal(
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=0), hide=["coefs"])
numpyro.handlers.block(model, hide=["coefs"])
)
)
guide.append(
AutoDelta(
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), expose=["coefs"])
numpyro.handlers.block(model, expose=["coefs"])
)
)
svi = SVI(model, guide, optim, Trace_ELBO())
Expand Down Expand Up @@ -1403,19 +1403,19 @@ def _setup_prototype(self, *args, **kwargs):
self._local_plate = (plate_name, plate_full_size, plate_subsample_size)

if self.global_guide is not None:
with handlers.block(), handlers.seed(rng_seed=0):
with handlers.block():
local_args = (self.global_guide.model(*args, **kwargs),)
local_kwargs = {}
else:
local_args = args
local_kwargs = kwargs.copy()

if self.local_guide is not None:
with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
with handlers.block(), handlers.trace() as tr:
self.local_guide(*local_args, **local_kwargs)
self.prototype_local_guide_trace = tr

with handlers.block(), handlers.trace() as tr, handlers.seed(rng_seed=0):
with handlers.block(), handlers.trace() as tr:
self.local_model(*local_args, **local_kwargs)
self.prototype_local_model_trace = tr

Expand Down Expand Up @@ -1462,12 +1462,7 @@ def _sample_latent(self, *args, **kwargs):

if self.global_guide is not None:
global_latents = self.global_guide(*args, **kwargs)
rng_key = numpyro.prng_key()
with (
handlers.block(),
handlers.seed(rng_seed=rng_key),
handlers.substitute(data=global_latents),
):
with handlers.block(), handlers.substitute(data=global_latents):
global_outputs = self.global_guide.model(*args, **kwargs)
local_args = (global_outputs,)
local_kwargs = {}
Expand Down Expand Up @@ -1575,12 +1570,10 @@ def fn(x):

local_kwargs["_subsample_idx"] = {plate_name: idx}
if self.local_guide is not None:
key = numpyro.prng_key()
subsample_guide = partial(_subsample_model, self.local_guide)
with (
handlers.block(),
handlers.trace() as tr,
handlers.seed(rng_seed=key),
handlers.substitute(data=local_guide_params),
):
with warnings.catch_warnings():
Expand Down
32 changes: 6 additions & 26 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -1055,26 +1055,14 @@ def model(x, y=None):
y = a + x @ b + sigma * random.normal(random.PRNGKey(1), (N, 1))

guide = AutoGuideList(model)
guide.append(
auto_classes[0](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=0), expose=["a"]
)
)
)
guide.append(auto_classes[0](numpyro.handlers.block(model, expose=["a"])))
# AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
if auto_classes[1] == AutoDAIS:
with pytest.raises(
ValueError,
match="AutoDAIS, AutoSemiDAIS, and AutoSurrogateLikelihoodDAIS are not supported.",
):
guide.append(
auto_classes[1](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=1), hide=["a"]
)
)
)
guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"])))
return
if auto_classes[1] == AutoSemiDAIS:
with pytest.raises(
Expand All @@ -1083,9 +1071,7 @@ def model(x, y=None):
):
guide.append(
auto_classes[1](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=1), hide=["a"]
),
numpyro.handlers.block(model, hide=["a"]),
local_model=None,
global_guide=None,
)
Expand All @@ -1098,19 +1084,13 @@ def model(x, y=None):
):
guide.append(
auto_classes[1](
numpyro.handlers.block(
numpyro.handlers.seed(model, rng_seed=1), hide=["a"]
),
numpyro.handlers.block(model, hide=["a"]),
surrogate_model=None,
)
)
return

guide.append(
auto_classes[1](
numpyro.handlers.block(numpyro.handlers.seed(model, rng_seed=1), hide=["a"])
)
)
guide.append(auto_classes[1](numpyro.handlers.block(model, hide=["a"])))

optimiser = numpyro.optim.Adam(step_size=0.1)
svi = SVI(model, guide, optimiser, Elbo())
Expand Down Expand Up @@ -1195,7 +1175,7 @@ def model():
numpyro.deterministic("x2", x**2)

guide = AutoGuideList(model)
blocked_model = handlers.block(handlers.seed(model, 7), hide=["x2"])
blocked_model = handlers.block(model, hide=["x2"])

# AutoGuideList does not support AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
if auto_class == AutoDAIS:
Expand Down
50 changes: 50 additions & 0 deletions test/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpyro
from numpyro import handlers
from numpyro.contrib import control_flow
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
Expand Down Expand Up @@ -617,6 +618,55 @@ def test_block_expose():
assert "x" in trace and "mu" not in trace and "sigma" in trace


@pytest.mark.parametrize(
    "block_config, expected_sites",
    [
        # Each case pairs keyword arguments for `handlers.block` with the set of
        # site names the outer trace is expected to record.
        ({"hide": ["y"]}, {"x", "z", "n", "cluster", "a", "b"}),
        # Only `prng_key` messages are exposed, so no site is expected in the trace.
        ({"expose_types": ["prng_key"]}, set()),
        ({"hide": ["n"]}, {"x", "y", "z", "cluster", "a", "b"}),
        ({"hide": ["cluster", "b"]}, {"x", "y", "z", "n", "a"}),
        ({"expose": ["x", "z"]}, {"x", "z"}),
    ],
)
def test_block_seed(block_config: dict, expected_sites: set) -> None:
    """Check that `block` and `seed` can be nested in either order.

    The same model is run as ``block(seed(fn))`` and as ``seed(block(fn))`` with
    the same ``rng_seed``. Both runs must expose exactly ``expected_sites`` to an
    outer ``trace`` and must return identical values for every entry of the
    dictionary produced by the model.

    :param block_config: keyword arguments forwarded to ``handlers.block``.
    :param expected_sites: names of the sites the outer trace should contain.
    """

    def fn():
        # Collect every drawn value so the two handler orderings can be compared
        # entry by entry after the run.
        sample = {}
        sample["x"] = numpyro.sample("x", dist.Normal())
        sample["y"] = numpyro.sample("y", dist.Normal(sample["x"]))
        # The value yielded by the plate is stored under "idx" so that it takes
        # part in the comparison too (presumably the subsample indices, since
        # `subsample_size` is smaller than the plate size -- confirm against the
        # plate documentation).
        with numpyro.plate("n", 10, subsample_size=7) as sample["idx"]:
            sample["z"] = numpyro.sample("z", dist.Normal(sample["y"]))

        # Two branches for `control_flow.cond`; both create sites named "a" and
        # "b" but with different distributions / transformations.
        def true_fun(_):
            a = numpyro.sample("a", dist.Normal(4.0))
            b = numpyro.deterministic("b", a - 2.0)
            return a, b

        def false_fun(_):
            a = numpyro.sample("a", dist.Normal(0.0))
            b = numpyro.deterministic("b", a)
            return a, b

        sample["cluster"] = numpyro.sample("cluster", dist.Normal())
        # The branch taken depends on the sign of the "cluster" draw.
        sample["a"], sample["b"] = control_flow.cond(
            sample["cluster"] > 0, true_fun, false_fun, None
        )
        return sample

    # Ordering 1: `block` is the outer handler, `seed` the inner one.
    blocked_seeded = handlers.block(handlers.seed(fn, rng_seed=17), **block_config)
    with handlers.trace() as trace1:
        sample1 = blocked_seeded()
    assert set(trace1) == expected_sites

    # Ordering 2: `seed` is the outer handler, `block` the inner one.
    seeded_blocked = handlers.seed(handlers.block(fn, **block_config), rng_seed=17)
    with handlers.trace() as trace2:
        sample2 = seeded_blocked()
    assert set(trace2) == expected_sites

    # Verify that the sample values are identical.
    for key, value in sample1.items():
        assert jnp.allclose(value, sample2[key])


def test_scope():
def fn():
with numpyro.plate("N", 10):
Expand Down
Loading