
Add two ensemble sampling methods #1692

Merged
merged 25 commits on Jan 26, 2024
34 changes: 33 additions & 1 deletion docs/source/mcmc.rst
@@ -9,7 +9,9 @@ We provide a high-level overview of the MCMC algorithms in NumPyro:
* `BarkerMH <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.barker.BarkerMH>`_ is a gradient-based MCMC method that may be competitive with HMC and NUTS for some models. It is applicable to models with continuous latent variables.
* `HMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.HMCGibbs>`_ combines HMC/NUTS steps with custom Gibbs updates. Gibbs updates must be specified by the user.
* `DiscreteHMCGibbs <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.hmc_gibbs.DiscreteHMCGibbs>`_ combines HMC/NUTS steps with Gibbs updates for discrete latent variables. The corresponding Gibbs updates are computed automatically.
-* `SA <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA>`_ is the only MCMC method in NumPyro that does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side individual steps can be fast.
+* `SA <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.sa.SA>`_ does not leverage gradients. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities. Note that SA generally requires a *very* large number of samples, as mixing tends to be slow. On the plus side, individual steps can be fast.
+* `AIES <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.ensemble.AIES>`_ is a gradient-free ensemble method that informs Metropolis-Hastings proposals by sharing information between chains. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate. It may be a good choice for models with non-differentiable log densities and can be robust in likelihood-free settings. AIES generally requires the number of chains to be at least twice the number of latent parameters (and ideally larger).
+* `ESS <https://num.pyro.ai/en/latest/mcmc.html#numpyro.infer.ensemble.ESS>`_ is a gradient-free ensemble method that shares information between chains to find good slice sampling directions. It tends to be more sample-efficient than AIES. It is only applicable to models with continuous latent variables. It is expected to perform best for models whose latent dimension is low to moderate, and it may be a good choice for models with non-differentiable log densities. ESS generally requires the number of chains to be at least twice the number of latent parameters (and ideally larger).

Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables if possible (see `restrictions <https://pyro.ai/examples/enumeration.html#Restriction-1:-conditional-independence>`_). Enumerated sites need to be marked with `infer={'enumerate': 'parallel'}` like in the `annotation example <https://num.pyro.ai/en/stable/examples/annotation.html>`_.
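
To make the new API concrete, here is a minimal usage sketch (not part of this diff; the toy model, data, and chain count are illustrative assumptions). Both kernels plug into the standard `MCMC` API; because they share information across chains, they are run with many chains and the vectorized chain method:

```python
import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import AIES, MCMC  # the kernels added in this PR


def model(y):
    # Toy model (an illustrative assumption) with two latent parameters.
    mu = numpyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)


y = jnp.array([1.0, 2.0, 3.0])

# With 2 latent parameters, 20 chains comfortably exceeds the
# "at least twice the number of latent parameters" guideline.
kernel = AIES(model)  # AIES also accepts a `moves` dict (see its docstring)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=2000,
            num_chains=20, chain_method="vectorized")
mcmc.run(jax.random.PRNGKey(0), y)
mcmc.print_summary()
```

Swapping `ESS(model)` in for `AIES(model)` is the only change needed to try the slice-sampling variant.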

@@ -101,6 +103,30 @@ SA
    :show-inheritance:
    :member-order: bysource

+EnsembleSampler
+^^^^^^^^^^^^^^^
+.. autoclass:: numpyro.infer.ensemble.EnsembleSampler
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
+AIES
+^^^^
+.. autoclass:: numpyro.infer.ensemble.AIES
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
+ESS
+^^^
+.. autoclass:: numpyro.infer.ensemble.ESS
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
.. autofunction:: numpyro.infer.hmc.hmc

.. autofunction:: numpyro.infer.hmc.hmc.init_kernel
@@ -117,6 +143,12 @@

.. autodata:: numpyro.infer.sa.SAState

+.. autodata:: numpyro.infer.ensemble.EnsembleSamplerState
+
+.. autodata:: numpyro.infer.ensemble.AIESState
+
+.. autodata:: numpyro.infer.ensemble.ESSState
+

TensorFlow Kernels
------------------
3 changes: 3 additions & 0 deletions numpyro/infer/__init__.py
@@ -10,6 +10,7 @@
    TraceGraph_ELBO,
    TraceMeanField_ELBO,
)
+from numpyro.infer.ensemble import AIES, ESS
from numpyro.infer.hmc import HMC, NUTS
from numpyro.infer.hmc_gibbs import HMCECS, DiscreteHMCGibbs, HMCGibbs
from numpyro.infer.initialization import (
@@ -29,6 +30,7 @@
from . import autoguide, reparam

__all__ = [
"AIES",
"autoguide",
"init_to_feasible",
"init_to_mean",
@@ -41,6 +43,7 @@
"BarkerMH",
"DiscreteHMCGibbs",
"ELBO",
"ESS",
"HMC",
"HMCECS",
"HMCGibbs",