Create documentation and github workflow
horaceg authored and rlouf committed Jan 19, 2022
1 parent 6f728a9 commit 094748a
Showing 18 changed files with 383 additions and 62 deletions.
33 changes: 33 additions & 0 deletions .github/workflows/build_doc.yml
@@ -0,0 +1,33 @@
name: Publish docs

on:
push:
branches:
- main

jobs:
build-and-deploy:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/[email protected]
with:
persist-credentials: false

- name: Set up Python 3.7
uses: actions/setup-python@v1
with:
python-version: 3.7

- name: Build docs
run: |
pip install -r requirements-dev.txt
sphinx-build -b html docs docs/build/html
- name: Publish docs
uses: JamesIves/[email protected]
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH: gh-pages
FOLDER: docs/build/html
CLEAN: true
1 change: 1 addition & 0 deletions .gitignore
@@ -44,6 +44,7 @@ pip-delete-this-directory.txt

# Sphinx documentation
docs/_build/
docs/_autosummary

# pyenv
.python-version
4 changes: 4 additions & 0 deletions Makefile
@@ -8,3 +8,7 @@ test:
publish:
git tag -a $(LIB_VERSION) -m $(LIB_VERSION)
git push --tag


build-docs:
sphinx-build -b html docs docs/_build/html
41 changes: 21 additions & 20 deletions blackjax/adaptation/step_size.py
@@ -63,8 +63,8 @@ def dual_averaging_adaptation(
the error at time t. We would like to find a procedure that adapts the
value of :math:`\\epsilon` such that :math:`h(x) =\\mathbb{E}\\left[H_t|\\epsilon\\right] = 0`
Following [1]_, the authors of [2]_ proposed the following update scheme. If
we note :math:``x = \\log \\epsilon` we follow:
Following [Nesterov2009]_, the authors of [Hoffman2014]_ proposed the following update scheme. If
we note :math:`x = \\log \\epsilon` we follow:
.. math::
x_{t+1} \\longleftarrow \\mu - \\frac{\\sqrt{t}}{\\gamma} \\frac{1}{t+t_0} \\sum_{i=1}^t H_i
@@ -74,21 +74,21 @@
:math:`h(\\overline{x}_t)` converges to 0, i.e. the Metropolis acceptance
rate converges to the desired rate.
See reference [2]_ (section 3.2.1) for a detailed discussion.
See reference [Hoffman2014]_ (section 3.2.1) for a detailed discussion.
Parameters
----------
t0: float >= 0
Free parameter that stabilizes the initial iterations of the algorithm.
Large values may slow down convergence. Introduced in [2]_ with a default
Large values may slow down convergence. Introduced in [Hoffman2014]_ with a default
value of 10.
gamma
Controls the speed of convergence of the scheme. The authors of [2]_ recommend
gamma:
Controls the speed of convergence of the scheme. The authors of [Hoffman2014]_ recommend
a value of 0.05.
kappa: float in ]0.5, 1]
Controls the weights of past steps in the current update. The scheme will
quickly forget earlier steps for a small value of `kappa`. Introduced
in [2]_, with a recommended value of .75
in [Hoffman2014]_, with a recommended value of .75
target:
Target acceptance rate.
@@ -102,11 +102,11 @@
References
----------
.. [1]: Nesterov, Yurii. "Primal-dual subgradient methods for convex
.. [Nesterov2009] Nesterov, Yurii. "Primal-dual subgradient methods for convex
problems." Mathematical programming 120.1 (2009): 221-259.
.. [2]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
adaptively setting path lengths in Hamiltonian Monte Carlo." Journal
of Machine Learning Research 15.1 (2014): 1593-1623.
.. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
adaptively setting path lengths in Hamiltonian Monte Carlo." Journal
of Machine Learning Research 15.1 (2014): 1593-1623.
"""
da_init, da_update, da_final = optimizers.dual_averaging(t0, gamma, kappa)

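As a minimal sketch of the dual averaging update scheme described in the docstring above (plain Python; the state layout and names are hypothetical stand-ins, not blackjax's `optimizers.dual_averaging`):

```python
import math

def dual_averaging(t0=10.0, gamma=0.05, kappa=0.75, mu=0.0):
    """Sketch of Nesterov dual averaging on x = log(step size)."""

    def init(x0):
        # state: (iteration, running average of H_t, x, averaged x)
        return (0, 0.0, x0, x0)

    def update(state, h):
        t, h_avg, x, x_avg = state
        t += 1
        # running average of the statistics H_t, weighted by 1 / (t + t0)
        h_avg = (1.0 - 1.0 / (t + t0)) * h_avg + h / (t + t0)
        # primal update: shrink towards mu, with gain sqrt(t) / gamma
        x = mu - (math.sqrt(t) / gamma) * h_avg
        # polynomially decaying average of the iterates, controlled by kappa
        eta = t ** (-kappa)
        x_avg = eta * x + (1.0 - eta) * x_avg
        return (t, h_avg, x, x_avg)

    return init, update
```

At convergence the averaged iterate `x_avg` is the value of :math:`\log \epsilon` one would actually use.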
@@ -183,7 +183,7 @@ def find_reasonable_step_size(
value for the step size starting from any value, choosing a good first
value can speed up the convergence. This heuristic doubles and halves the
step size until the acceptance probability of the HMC proposal crosses the
target value.
target value [Hoffman2014]_.
Parameters
----------
@@ -208,11 +208,12 @@
float
A reasonable first value for the step size.
Reference
---------
.. [1]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
adaptively setting path lengths in Hamiltonian Monte Carlo." Journal
of Machine Learning Research 15.1 (2014): 1593-1623.
References
----------
.. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler:
adaptively setting path lengths in Hamiltonian Monte Carlo." Journal
of Machine Learning Research 15.1 (2014): 1593-1623.
"""
fp_limit = jnp.finfo(jax.lax.dtype(initial_step_size))

@@ -228,9 +229,9 @@ def do_continue(rss_state: ReasonableStepSizeState) -> bool:
incur any performance penalty when calling it repeatedly inside this
function.
Reference
---------
.. [1]: jax.numpy.finfo documentation. https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.finfo.html
References
----------
.. [1] jax.numpy.finfo documentation. https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.finfo.html
"""
_, direction, previous_direction, step_size = rss_state
18 changes: 9 additions & 9 deletions blackjax/diagnostics.py
@@ -41,8 +41,8 @@ def potential_scale_reduction(
References
----------
.. [1]: https://mc-stan.org/docs/2_27/reference-manual/notation-for-samples-chains-and-draws.html#potential-scale-reduction
.. [2]: Gelman, Andrew, and Donald B. Rubin. (1992) “Inference from Iterative Simulation Using Multiple Sequences.” Statistical Science 7 (4): 457–72.
.. [1] https://mc-stan.org/docs/2_27/reference-manual/notation-for-samples-chains-and-draws.html#potential-scale-reduction
.. [2] Gelman, Andrew, and Donald B. Rubin. (1992) “Inference from Iterative Simulation Using Multiple Sequences.” Statistical Science 7 (4): 457–72.
"""
assert (
@@ -95,19 +95,19 @@ def effective_sample_size(
.. math:: \\hat{\\tau} = -1 + 2 \\sum_{t'=0}^K \\hat{P}_{t'}
where :math:`M` is the number of chains, :math:`N` the number of draws,
:math:`\\hat{\rho}_t` is the estimated _autocorrelation at lag :math:`t`, and
:math:`K` is the last integer for which :math:`\\hat{P}_{K} = \\hat{\rho}_{2K} +
\\hat{\rho}_{2K+1}` is still positive.
:math:`\\hat{\\rho}_t` is the estimated _autocorrelation at lag :math:`t`, and
:math:`K` is the last integer for which :math:`\\hat{P}_{K} = \\hat{\\rho}_{2K} +
\\hat{\\rho}_{2K+1}` is still positive.
The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
criterion (Geyer, 1992; Geyer, 2011).
References
----------
.. [1]: https://mc-stan.org/docs/2_27/reference-manual/effective-sample-size-section.html
.. [2]: Gelman, Andrew, J. B. Carlin, Hal S. Stern, David B. Dunson, Aki Vehtari, and Donald B. Rubin. (2013). Bayesian Data Analysis. Third Edition. Chapman; Hall/CRC.
.. [3]: Geyer, Charles J. (1992). “Practical Markov Chain Monte Carlo.” Statistical Science, 473–83.
.. [4]: Geyer, Charles J. (2011). “Introduction to Markov Chain Monte Carlo.” In Handbook of Markov Chain Monte Carlo, edited by Steve Brooks, Andrew Gelman, Galin L. Jones, and Xiao-Li Meng, 3–48. Chapman; Hall/CRC.
.. [1] https://mc-stan.org/docs/2_27/reference-manual/effective-sample-size-section.html
.. [2] Gelman, Andrew, J. B. Carlin, Hal S. Stern, David B. Dunson, Aki Vehtari, and Donald B. Rubin. (2013). Bayesian Data Analysis. Third Edition. Chapman; Hall/CRC.
.. [3] Geyer, Charles J. (1992). “Practical Markov Chain Monte Carlo.” Statistical Science, 473–83.
.. [4] Geyer, Charles J. (2011). “Introduction to Markov Chain Monte Carlo.” In Handbook of Markov Chain Monte Carlo, edited by Steve Brooks, Andrew Gelman, Galin L. Jones, and Xiao-Li Meng, 3–48. Chapman; Hall/CRC.
"""
input_shape = input_array.shape
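To make the formula for :math:`\hat{\tau}` above concrete, here is a rough single-chain version in NumPy (a simplified sketch under Geyer's initial positive sequence criterion; not the batched implementation in this module):

```python
import numpy as np

def ess_sketch(chain):
    """Single-chain effective sample size: sum the paired
    autocorrelations rho_{2k} + rho_{2k+1} while they remain positive,
    then divide the number of draws by tau = -1 + 2 * sum of pairs."""
    n = chain.size
    x = chain - chain.mean()
    # autocovariances at lags 0..n-1, normalized to autocorrelations
    autocov = np.correlate(x, x, mode="full")[n - 1 :] / n
    rho = autocov / autocov[0]
    tau = -1.0
    k = 0
    while 2 * k + 1 < n:
        pair = rho[2 * k] + rho[2 * k + 1]
        if pair <= 0.0:
            break
        tau += 2.0 * pair
        k += 1
    return n / tau
```

For independent draws `tau` is close to 1, so the effective sample size is close to the number of draws.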
47 changes: 24 additions & 23 deletions blackjax/nuts.py
@@ -67,13 +67,36 @@ def kernel(
) -> Callable:
"""Build an iterative NUTS kernel.
This algorithm is an iteration on the original NUTS algorithm [Hoffman2014]_ with two major differences:
- We do not use slice sampling but multinomial sampling for the proposal [Betancourt2017]_;
- The trajectory expansion is not recursive but iterative [Phan2019]_, [Lao2020]_.
The implementation can seem unusual for those familiar with similar
algorithms. Indeed, we do not conceptualize the trajectory construction as
building a tree. We feel that the tree lingo, inherited from the recursive
version, is unnecessarily complicated and hides the more general concepts
on which the NUTS algorithm is built.
NUTS, in essence, consists in sampling a trajectory by iteratively choosing
a direction at random and integrating in this direction a number of times
that doubles at every step. From this trajectory we continuously sample a
proposal. When the trajectory turns on itself or when we have reached the
maximum trajectory length we return the current proposal.
Parameters
----------
logprob_fn
Log probability function we wish to sample from.
parameters
A NamedTuple that contains the parameters of the kernel to be built.
References
----------
.. [Hoffman2014] Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
.. [Betancourt2017] Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
.. [Phan2019] Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019).
.. [Lao2020] Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
"""

def potential_fn(x):
@@ -105,23 +128,7 @@ def iterative_nuts_proposal(
max_num_expansions: int = 10,
divergence_threshold: float = 1000,
) -> Callable:
"""Iterative NUTS algorithm.
This algorithm is an iteration on the original NUTS algorithm [1]_ with two major differences:
- We do not use slice samplig but multinomial sampling for the proposal [2]_;
- The trajectory expansion is not recursive but iterative [3,4]_.
The implementation can seem unusual for those familiar with similar
algorithms. Indeed, we do not conceptualize the trajectory construction as
building a tree. We feel that the tree lingo, inherited from the recursive
version, is unnecessarily complicated and hides the more general concepts
on which the NUTS algorithm is built.
NUTS, in essence, consists in sampling a trajectory by iteratively choosing
a direction at random and integrating in this direction a number of times
that doubles at every step. From this trajectory we continuously sample a
proposal. When the trajectory turns on itself or when we have reached the
maximum trajectory length we return the current proposal.
"""Iterative NUTS proposal.
Parameters
----------
@@ -142,12 +149,6 @@
-------
A kernel that generates a new chain state and information about the transition.
References
----------
.. [1]: Hoffman, Matthew D., and Andrew Gelman. "The No-U-Turn sampler: adaptively setting path lengths in Hamiltonian Monte Carlo." J. Mach. Learn. Res. 15.1 (2014): 1593-1623.
.. [2]: Betancourt, Michael. "A conceptual introduction to Hamiltonian Monte Carlo." arXiv preprint arXiv:1701.02434 (2017).
.. [3]: Phan, Du, Neeraj Pradhan, and Martin Jankowiak. "Composable effects for flexible and accelerated probabilistic programming in NumPyro." arXiv preprint arXiv:1912.11554 (2019).
.. [4]: Lao, Junpeng, et al. "tfp. mcmc: Modern markov chain monte carlo tools built for modern hardware." arXiv preprint arXiv:2002.01184 (2020).
"""
(
new_termination_state,
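The trajectory-doubling idea described in the kernel docstring can be caricatured as follows. Every name here (`integrate`, `is_turning`, `sample_proposal`) is a hypothetical stand-in, not a blackjax function, and the sketch ignores detailed balance bookkeeping and divergence checks entirely:

```python
import random

def nuts_trajectory_sketch(
    integrate, is_turning, sample_proposal, initial_state, max_num_expansions=10
):
    """Caricature of iterative NUTS: at expansion k, pick a direction at
    random, integrate 2**k steps from the matching end of the trajectory,
    update the running proposal, and stop when the trajectory turns on
    itself (or the maximum number of expansions is reached)."""
    trajectory = [initial_state]
    proposal = initial_state
    for k in range(max_num_expansions):
        direction = random.choice((-1, 1))
        edge = trajectory[-1] if direction == 1 else trajectory[0]
        new_states = integrate(edge, direction, 2 ** k)
        if direction == 1:
            trajectory = trajectory + new_states
        else:
            trajectory = list(reversed(new_states)) + trajectory
        # continuously sample a proposal from the growing trajectory
        proposal = sample_proposal(proposal, new_states)
        if is_turning(trajectory):
            break
    return proposal
```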
5 changes: 2 additions & 3 deletions blackjax/stan_warmup.py
@@ -115,12 +115,11 @@ def stan_warmup(
Schematically:
```
+---------+---+------+------------+------------------------+------+
| fast | s | slow | slow | slow | fast |
+---------+---+------+------------+------------------------+------+
1 2 3 3 3 3
```
|1 |2 |3 |3 |3 |3 |
+---------+---+------+------------+------------------------+------+
Step (1) consists in finding a "reasonable" first step size that is used to
initialize the dual averaging scheme. In (2) we initialize the mass matrix
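The window layout in the schematic above can be sketched as a schedule builder. The numbers follow Stan's defaults (a 75-step fast initial buffer, a 25-step first slow window that doubles, a 50-step fast final buffer), but the function and its name are purely illustrative, not blackjax's API:

```python
def warmup_window_schedule(num_steps=1000, initial_buffer=75, first_window=25, final_buffer=50):
    """Return (start, end) pairs for the slow adaptation windows that sit
    between the fast initial and final buffers, each window twice as long
    as the previous one, with leftover steps collected in a last window."""
    slow_end = num_steps - final_buffer
    windows = []
    start, size = initial_buffer, first_window
    while start + size < slow_end:
        windows.append((start, start + size))
        start += size
        size *= 2
    if start < slow_end:
        windows.append((start, slow_end))
    return windows
```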
2 changes: 1 addition & 1 deletion blackjax/tempered_smc.py
@@ -116,7 +116,7 @@ def tempered_smc(
Tempered SMC uses tempering to sample from a distribution given by
:math..
.. math::
p(x) \\propto p_0(x) \\exp(-V(x)) \\mathrm{d}x
where :math:`p_0` is the prior distribution, typically easy to sample from and for which the density
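The tempering idea behind the density above can be written down in a few lines (a sketch: `logprior` and `potential` are hypothetical user-supplied callables, not part of the module's API):

```python
def tempered_logdensity(logprior, potential, lmbda):
    """Log-density of p_lambda(x) proportional to p0(x) exp(-lambda V(x)):
    the prior at lambda = 0, the target p(x) ∝ p0(x) exp(-V(x)) at lambda = 1.
    Tempered SMC moves particles through a sequence of such densities."""
    def logdensity(x):
        return logprior(x) - lmbda * potential(x)
    return logdensity
```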
4 changes: 4 additions & 0 deletions blackjax/types.py
@@ -3,10 +3,14 @@
import jax.numpy as jnp
import numpy as np

#: JAX or Numpy array
Array = Union[np.ndarray, jnp.ndarray]

#: JAX PyTrees
PyTree = Union[Array, Iterable[Array], Mapping[Any, Array]]
# It is not currently tested but we also support recursive PyTrees.
# Once recursive typing is fully supported (https://github.com/python/mypy/issues/731), we can uncomment the line below.
# PyTree = Union[Array, Iterable["PyTree"], Mapping[Any, "PyTree"]]

#: JAX PRNGKey
PRNGKey = jnp.ndarray
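As a usage note (an illustrative snippet, not from the library): under the `PyTree` alias above, a value can be a bare array or a container of arrays, and the aliases mainly serve as documentation for type checkers:

```python
import numpy as np

# Any of these is a valid PyTree under the alias: a bare array,
# an iterable of arrays, or a mapping from keys to arrays.
position_array = np.zeros(3)
position_list = [np.zeros(3), np.ones(2)]
position_dict = {"loc": np.zeros(3), "log_scale": np.ones(3)}
```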
Binary file added docs/_static/blackjax.png
6 changes: 6 additions & 0 deletions docs/_static/custom.css
@@ -0,0 +1,6 @@
#site-navigation{
background: #E6E7E8;
h1.site-logo {
font-weight: bold;
}
}
86 changes: 86 additions & 0 deletions docs/api.rst
@@ -0,0 +1,86 @@
Common Kernels
==============

.. currentmodule:: blackjax

.. autosummary::

hmc
nuts
rmh
tempered_smc

HMC
~~~

.. automodule:: blackjax.hmc
:members: HMCInfo, kernel, new_state

NUTS
~~~~

.. automodule:: blackjax.nuts
:members: NUTSInfo, kernel, new_state

RMH
~~~

.. automodule:: blackjax.rmh
:members:
:undoc-members:

Tempered SMC
~~~~~~~~~~~~

.. automodule:: blackjax.tempered_smc
:members: TemperedSMCState, adaptive_tempered_smc, tempered_smc


Adaptation
==========


Stan full warmup
~~~~~~~~~~~~~~~~

.. currentmodule:: blackjax

.. automodule:: blackjax.stan_warmup
:members: run

Step-size adaptation
~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: blackjax.adaptation.step_size

.. autofunction:: dual_averaging_adaptation

.. autofunction:: find_reasonable_step_size

Mass matrix adaptation
~~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: blackjax.adaptation.mass_matrix

.. autofunction:: mass_matrix_adaptation

Diagnostics
===========

.. currentmodule:: blackjax.diagnostics

.. autosummary::

effective_sample_size
potential_scale_reduction

Effective sample size
~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: effective_sample_size


Potential scale reduction
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: potential_scale_reduction