Skip to content

Commit

Permalink
Update index text
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jan 19, 2022
1 parent 5d40f2c commit 4d5d676
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 17 deletions.
1 change: 0 additions & 1 deletion docs/README.md

This file was deleted.

4 changes: 2 additions & 2 deletions docs/api.rst
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
API
===
API Reference
=============

NUTS
~~~~
Expand Down
6 changes: 3 additions & 3 deletions docs/tutorials.rst → docs/examples.rst
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Tutorials
=========
Examples
========

.. toctree::
:caption: Tutorials
:caption: Examples

notebooks/Introduction.ipynb
notebooks/LogisticRegression.ipynb
Expand Down
137 changes: 131 additions & 6 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,17 +1,142 @@
Welcome to blackjax's documentation!
====================================
Welcome to BlackJAX
===================


BlackJAX is a library of samplers for `JAX <https://github.com/google/jax>`_ that
works on CPU as well as GPU.

It is *not* a probabilistic programming library. However it integrates really
well with PPLs as long as they can provide a (potentially unnormalized)
log-probability density function compatible with JAX.

Who should use BlackJAX?
------------------------

BlackJAX should appeal to those who:
- Have a logpdf and just need a sampler;
- Need more than a general-purpose sampler;
- Want to sample on GPU;
- Want to build upon robust elementary blocks for their research;
- Are building a probabilistic programming language;
- Want to learn how sampling algorithms work.

Quickstart
----------

Installation
~~~~~~~~~~~~

BlackJAX is written in pure Python but depends on XLA via JAX. Since the JAX
installation depends on your CUDA version BlackJAX does not list JAX as a
dependency. If you simply want to use JAX on CPU, install it with:

.. code-block:: bash
pip install jax jaxlib
Follow `these instructions <https://github.com/google/jax#installation>`_ to
install JAX with the relevant hardware acceleration support.

Then install BlackJAX

.. code-block:: bash
pip install blackjax
### Example

Let us look at a simple self-contained example sampling with NUTS:

.. code-block:: python
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import numpy as np
import blackjax.nuts as nuts
observed = np.random.normal(10, 20, size=1_000)
def logprob_fn(x):
logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"])
return jnp.sum(logpdf)
# Build the kernel
step_size = 1e-3
inverse_mass_matrix = jnp.array([1., 1.])
kernel = nuts.kernel(logprob_fn, step_size, inverse_mass_matrix)
kernel = jax.jit(kernel) # try without to see the speedup
# Initialize the state
initial_position = {"loc": 1., "scale": 2.}
state = nuts.new_state(initial_position, logprob_fn)
# Iterate
rng_key = jax.random.PRNGKey(0)
for _ in range(1_000):
_, rng_key = jax.random.split(rng_key)
state, _ = kernel(rng_key, state)
See `this notebook <https://github.com/blackjax-devs/blackjax/blob/master/notebooks/Introduction.ipynb>`_ for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.

Philosophy
----------

What is BlackJAX?
~~~~~~~~~~~~~~~~~

BlackJAX bridges the gap between "one liner" frameworks and modular, customizable
libraries.

Users can import the library and interact with robust, well-tested and performant
samplers with a few lines of code. These samplers are aimed at PPL developers,
or people who have a logpdf and just need a sampler that works.

But the true strength of BlackJAX lies in its internals and how they can be used
to experiment quickly on existing or new sampling schemes. This lower level
exposes the building blocks of inference algorithms: integrators, proposal,
momentum generators, etc and makes it easy to combine them to build new
algorithms. It provides an opportunity to accelerate research on sampling
algorithms by providing robust, performant and reusable code.

Why BlackJAX?
~~~~~~~~~~~~~

Sampling algorithms are too often integrated into PPLs and not decoupled from
the rest of the framework, making them hard to use for people who do not need
the modeling language to build their logpdf. Their implementation is most of
the time monolithic and it is impossible to reuse parts of the algorithm to
build custom kernels. BlackJAX solves both problems.

How does it work?
~~~~~~~~~~~~~~~~~

BlackJAX allows to build arbitrarily complex algorithms because it is built
around a very general pattern. Everything that takes a state and returns a state
is a transition kernel, and is implemented as:

.. code-block:: python
new_state, info = kernel(rng_key, state)
kernels are stateless functions and all follow the same API; state and
information related to the transition are returned separately. They can thus be
easily composed and exchanged. We specialize these kernels by closure instead of
passing parameters.

Documentation
-------------

.. toctree::
:maxdepth: 1
:glob:
:caption: Contents:

README.md
index
api
tutorials
examples

Indices and tables
==================
------------------

* :ref:`genindex`
* :ref:`modindex`
Expand Down
10 changes: 5 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ attrs==20.3.0
black==20.8b1
chex>=0.0.8
click==7.1.2
docutils>=0.17
execnet==1.7.1
flake8==3.8.4
furo
iniconfig==1.1.1
isort==5.6.4
mccabe==0.6.1
mypy==0.790
mypy-extensions==0.4.3
myst_nb
packaging==20.7
pathspec==0.8.1
pluggy==0.13.1
Expand All @@ -26,11 +29,8 @@ pytest-forked==1.3.0
pytest-html==3.1.1
pytest-xdist==2.1.0
regex==2020.11.13
sphinx==4.3.2
sphinx-autobuild==2021.3.14
toml==0.10.2
typed-ast==1.4.1
typing-extensions==3.7.4.3
sphinx==4.3.2
sphinx-autobuild==2021.3.14
furo
myst_nb
docutils>=0.17

0 comments on commit 4d5d676

Please sign in to comment.