Skip to content

Commit

Permalink
Unify sample_posterior() signatures (#1979)
Browse files Browse the repository at this point in the history
  • Loading branch information
tare authored Feb 17, 2025
1 parent fa2ecb3 commit af68b5f
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def sample_posterior(self, rng_key, params, *, sample_shape=()):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
"""
Generate samples from the approximate posterior over the latent
sites in the model.
Expand All @@ -141,7 +141,9 @@ def sample_posterior(self, rng_key, params, *, sample_shape=()):
:param dict params: Current parameters of model and autoguide.
The parameters can be obtained using :meth:`~numpyro.infer.svi.SVI.get_params`
method from :class:`~numpyro.infer.svi.SVI`.
:param args: Arguments to be provided to the model / guide.
:param tuple sample_shape: sample shape of each latent site, defaults to ().
:param kwargs: Keyword arguments to be provided to the model / guide.
:return: a dict containing samples drawn the this guide.
:rtype: dict
"""
Expand Down Expand Up @@ -317,18 +319,11 @@ def __iter__(self):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
result = {}
for part in self._guides:
# TODO: remove this when sample_posterior() signatures are consistent
# we know part is not AutoDAIS, AutoSemiDAIS, or AutoSurrogateLikelihoodDAIS
if isinstance(part, numpyro.infer.autoguide.AutoDelta):
result.update(
part.sample_posterior(
rng_key, params, *args, sample_shape=sample_shape, **kwargs
)
)
else:
result.update(
part.sample_posterior(rng_key, params, sample_shape=sample_shape)
result.update(
part.sample_posterior(
rng_key, params, *args, sample_shape=sample_shape, **kwargs
)
)
return result

def median(self, params):
Expand Down Expand Up @@ -469,7 +464,7 @@ def _constrain(self, latent_samples):
else:
return self._postprocess_fn(latent_samples)

def sample_posterior(self, rng_key, params, *, sample_shape=()):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs}
with handlers.seed(rng_seed=rng_key):
Expand Down Expand Up @@ -810,7 +805,7 @@ def get_posterior(self, params):
transform = self.get_transform(params)
return dist.TransformedDistribution(base_dist, transform)

def sample_posterior(self, rng_key, params, *, sample_shape=()):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
latent_sample = handlers.substitute(
handlers.seed(self._sample_latent, rng_key), params
)(sample_shape=sample_shape)
Expand Down Expand Up @@ -999,7 +994,7 @@ def scan_body(carry, eps_beta):

return z

def sample_posterior(self, rng_key, params, *, sample_shape=()):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
def _single_sample(_rng_key):
latent_sample = handlers.substitute(
handlers.seed(self._sample_latent, _rng_key), params
Expand Down Expand Up @@ -2175,7 +2170,7 @@ def get_posterior(self, params):
transform = self.get_transform(params)
return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)

def sample_posterior(self, rng_key, params, *, sample_shape=()):
def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
latent_sample = self.get_posterior(params).sample(rng_key, sample_shape)
return self._unpack_and_constrain(latent_sample, params)

Expand Down

0 comments on commit af68b5f

Please sign in to comment.