From 8f00d32a4023d2e94f57dace1c89072ec092192a Mon Sep 17 00:00:00 2001 From: johannahaffner <38662446+johannahaffner@users.noreply.github.com> Date: Wed, 31 Jul 2024 15:50:47 +0200 Subject: [PATCH] Update index.md The jitted step remained unused, leading to the example running with an uncompiled nuts.step. Changing this reduces the execution time by a factor of 30 on my system and showcases blackjax' speed. --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index edc02631c..fca4787c4 100644 --- a/docs/index.md +++ b/docs/index.md @@ -41,7 +41,7 @@ rng_key = jax.random.key(0) step = jax.jit(nuts.step) for i in range(1_000): nuts_key = jax.random.fold_in(rng_key, i) - state, _ = nuts.step(nuts_key, state) + state, _ = step(nuts_key, state) ``` :::{note}