diff --git a/numpyro/infer/inspect.py b/numpyro/infer/inspect.py index 2b97f4021..5232dfe11 100644 --- a/numpyro/infer/inspect.py +++ b/numpyro/infer/inspect.py @@ -385,7 +385,7 @@ def process_message(self, msg): samples = { name: site["value"] for name, site in trace.items() - if (site["type"] == "sample" and not site["is_observed"]) + if site["type"] == "sample" or site["type"] == "deterministic" }