diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index 0548e2cc4..26f45038c 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -132,7 +132,7 @@ def __call__(self, *args, **kwargs): raise NotImplementedError @abstractmethod - def sample_posterior(self, rng_key, params, *, sample_shape=()): + def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): """ Generate samples from the approximate posterior over the latent sites in the model. @@ -141,7 +141,9 @@ def sample_posterior(self, rng_key, params, *, sample_shape=()): :param dict params: Current parameters of model and autoguide. The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params` method from :class:`~numpyro.infer.svi.SVI`. + :param args: Arguments to be provided to the model / guide. :param tuple sample_shape: sample shape of each latent site, defaults to (). + :param kwargs: Keyword arguments to be provided to the model / guide. :return: a dict containing samples drawn the this guide. :rtype: dict """ @@ -317,18 +319,11 @@ def __iter__(self): def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): result = {} for part in self._guides: - # TODO: remove this when sample_posterior() signatures are consistent - # we know part is not AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS - if isinstance(part, numpyro.infer.autoguide.AutoDelta): - result.update( - part.sample_posterior( - rng_key, params, *args, sample_shape=sample_shape, **kwargs - ) - ) - else: - result.update( - part.sample_posterior(rng_key, params, sample_shape=sample_shape) + result.update( + part.sample_posterior( + rng_key, params, *args, sample_shape=sample_shape, **kwargs ) + ) return result def median(self, params): @@ -469,7 +464,7 @@ def _constrain(self, latent_samples): else: return self._postprocess_fn(latent_samples) - def sample_posterior(self, rng_key, params, *, sample_shape=()): + def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs} scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs} with handlers.seed(rng_seed=rng_key): @@ -810,7 +805,7 @@ def get_posterior(self, params): transform = self.get_transform(params) return dist.TransformedDistribution(base_dist, transform) - def sample_posterior(self, rng_key, params, *, sample_shape=()): + def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): latent_sample = handlers.substitute( handlers.seed(self._sample_latent, rng_key), params )(sample_shape=sample_shape) @@ -999,7 +994,7 @@ def scan_body(carry, eps_beta): return z - def sample_posterior(self, rng_key, params, *, sample_shape=()): + def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): def _single_sample(_rng_key): latent_sample = handlers.substitute( handlers.seed(self._sample_latent, _rng_key), params @@ -2175,7 +2170,7 @@ def get_posterior(self, params): transform = self.get_transform(params) return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril) - def sample_posterior(self, rng_key, params, *, sample_shape=()): + def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): latent_sample = self.get_posterior(params).sample(rng_key, sample_shape) return self._unpack_and_constrain(latent_sample, params)