From 4d5d67674e4abc1ebb2c0c6619f73a41f44f9c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Wed, 19 Jan 2022 06:45:57 +0100 Subject: [PATCH] Update index text --- docs/README.md | 1 - docs/api.rst | 4 +- docs/{tutorials.rst => examples.rst} | 6 +- docs/index.rst | 137 +++++++++++++++++++++++++-- requirements-dev.txt | 10 +- 5 files changed, 141 insertions(+), 17 deletions(-) delete mode 120000 docs/README.md rename docs/{tutorials.rst => examples.rst} (83%) diff --git a/docs/README.md b/docs/README.md deleted file mode 120000 index 32d46ee88..000000000 --- a/docs/README.md +++ /dev/null @@ -1 +0,0 @@ -../README.md \ No newline at end of file diff --git a/docs/api.rst b/docs/api.rst index 00cea2f17..adfd6d0b7 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,5 +1,5 @@ -API -=== +API Reference +============= NUTS ~~~~ diff --git a/docs/tutorials.rst b/docs/examples.rst similarity index 83% rename from docs/tutorials.rst rename to docs/examples.rst index e347e9d35..20b535b8a 100644 --- a/docs/tutorials.rst +++ b/docs/examples.rst @@ -1,8 +1,8 @@ -Tutorials -========= +Examples +======== .. toctree:: - :caption: Tutorials + :caption: Examples notebooks/Introduction.ipynb notebooks/LogisticRegression.ipynb diff --git a/docs/index.rst b/docs/index.rst index 7f8d2ebe0..dc36fa553 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,17 +1,142 @@ -Welcome to blackjax's documentation! -==================================== +Welcome to BlackJAX +=================== + + +BlackJAX is a library of samplers for `JAX `_ that +works on CPU as well as GPU. + +It is *not* a probabilistic programming library. However it integrates really +well with PPLs as long as they can provide a (potentially unnormalized) +log-probability density function compatible with JAX. + +Who should use BlackJAX? +------------------------ + +BlackJAX should appeal to those who: +- Have a logpdf and just need a sampler; +- Need more than a general-purpose sampler; +- Want to sample on GPU; +- Want to build upon robust elementary blocks for their research; +- Are building a probabilistic programming language; +- Want to learn how sampling algorithms work. + +Quickstart +---------- + +Installation +~~~~~~~~~~~~ + +BlackJAX is written in pure Python but depends on XLA via JAX. Since the JAX +installation depends on your CUDA version BlackJAX does not list JAX as a +dependency. If you simply want to use JAX on CPU, install it with: + +.. code-block:: bash + + pip install jax jaxlib + +Follow `these instructions `_ to +install JAX with the relevant hardware acceleration support. + +Then install BlackJAX + +.. code-block:: bash + + pip install blackjax + +### Example + +Let us look at a simple self-contained example sampling with NUTS: + +.. code-block:: python + + import jax + import jax.numpy as jnp + import jax.scipy.stats as stats + import numpy as np + + import blackjax.nuts as nuts + + observed = np.random.normal(10, 20, size=1_000) + def logprob_fn(x): + logpdf = stats.norm.logpdf(observed, x["loc"], x["scale"]) + return jnp.sum(logpdf) + + # Build the kernel + step_size = 1e-3 + inverse_mass_matrix = jnp.array([1., 1.]) + kernel = nuts.kernel(logprob_fn, step_size, inverse_mass_matrix) + kernel = jax.jit(kernel) # try without to see the speedup + + # Initialize the state + initial_position = {"loc": 1., "scale": 2.} + state = nuts.new_state(initial_position, logprob_fn) + + # Iterate + rng_key = jax.random.PRNGKey(0) + for _ in range(1_000): + _, rng_key = jax.random.split(rng_key) + state, _ = kernel(rng_key, state) + +See `this notebook `_ for more examples of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc. + +Philosophy +---------- + +What is BlackJAX? +~~~~~~~~~~~~~~~~~ + +BlackJAX bridges the gap between "one liner" frameworks and modular, customizable +libraries. + +Users can import the library and interact with robust, well-tested and performant +samplers with a few lines of code. These samplers are aimed at PPL developers, +or people who have a logpdf and just need a sampler that works. + +But the true strength of BlackJAX lies in its internals and how they can be used +to experiment quickly on existing or new sampling schemes. This lower level +exposes the building blocks of inference algorithms: integrators, proposal, +momentum generators, etc and makes it easy to combine them to build new +algorithms. It provides an opportunity to accelerate research on sampling +algorithms by providing robust, performant and reusable code. + +Why BlackJAX? +~~~~~~~~~~~~~ + +Sampling algorithms are too often integrated into PPLs and not decoupled from +the rest of the framework, making them hard to use for people who do not need +the modeling language to build their logpdf. Their implementation is most of +the time monolithic and it is impossible to reuse parts of the algorithm to +build custom kernels. BlackJAX solves both problems. + +How does it work? +~~~~~~~~~~~~~~~~~ + +BlackJAX allows to build arbitrarily complex algorithms because it is built +around a very general pattern. Everything that takes a state and returns a state +is a transition kernel, and is implemented as: + +.. code-block:: python + + new_state, info = kernel(rng_key, state) + +kernels are stateless functions and all follow the same API; state and +information related to the transition are returned separately. They can thus be +easily composed and exchanged. We specialize these kernels by closure instead of +passing parameters. + +Documentation +------------- .. toctree:: :maxdepth: 1 :glob: - :caption: Contents: - README.md + index api - tutorials + examples Indices and tables -================== +------------------ * :ref:`genindex` * :ref:`modindex` diff --git a/requirements-dev.txt b/requirements-dev.txt index eab597199..86b5b835f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -5,13 +5,16 @@ attrs==20.3.0 black==20.8b1 chex>=0.0.8 click==7.1.2 +docutils>=0.17 execnet==1.7.1 flake8==3.8.4 +furo iniconfig==1.1.1 isort==5.6.4 mccabe==0.6.1 mypy==0.790 mypy-extensions==0.4.3 +myst_nb packaging==20.7 pathspec==0.8.1 pluggy==0.13.1 @@ -26,11 +29,8 @@ pytest-forked==1.3.0 pytest-html==3.1.1 pytest-xdist==2.1.0 regex==2020.11.13 +sphinx==4.3.2 +sphinx-autobuild==2021.3.14 toml==0.10.2 typed-ast==1.4.1 typing-extensions==3.7.4.3 -sphinx==4.3.2 -sphinx-autobuild==2021.3.14 -furo -myst_nb -docutils>=0.17