Skip to content

Commit

Permalink
Migrate progress bar from fastprogress to tqdm, and multiple chain su…
Browse files Browse the repository at this point in the history
…pport
  • Loading branch information
zaxtax committed Apr 1, 2024
1 parent 7cf4f9d commit 4646d06
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 27 deletions.
66 changes: 40 additions & 26 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,70 +14,84 @@
"""Progress bar decorators for use with step functions.
Adapted from Jeremie Coullon's blog post :cite:p:`progress_bar`.
"""
from fastprogress.fastprogress import progress_bar
import tqdm
from tqdm.auto import tqdm as tqdm_auto
import jax
from jax import lax
from jax.experimental import io_callback


def progress_bar_scan(num_samples, print_rate=None):
"Progress bar for a JAX scan"
progress_bars = {}
def progress_bar_scan(num_samples, num_chains=1, print_rate=None):
"""Factory that builds a progress bar decorator along
with the `set_tqdm_description` and `close_tqdm` functions
"""

if print_rate is None:
if num_samples > 20:
print_rate = int(num_samples / 20)
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)
remainder = num_samples % print_rate

def _update_bar(arg):
progress_bars[0].update_bar(arg + 1)
tqdm_bars = {}
for chain in range(num_chains):
tqdm_bars[chain] = tqdm_auto(range(num_samples), position=chain)
tqdm_bars[chain].set_description("Compiling.. ", refresh=True)

def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()
def _update_tqdm(arg, chain):
chain = int(chain)
tqdm_bars[chain].set_description(f"Running chain {chain}", refresh=False)
tqdm_bars[chain].update(arg)

def _close_tqdm(arg, chain):
chain = int(chain)
tqdm_bars[chain].update(arg)
tqdm_bars[chain].close()

def _update_progress_bar(iter_num, chain):
"""Updates tqdm progress bar of a JAX loop only if the iteration number is a multiple of the print_rate
Usage: carry = progress_bar((iter_num, print_rate), carry)
"""

def _update_progress_bar(iter_num):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: jax.debug.callback(_update_tqdm, iter_num, chain),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
(iter_num % print_rate) == 0,
lambda _: jax.debug.callback(_update_tqdm, print_rate, chain),
lambda _: None,
operand=None,
)

_ = lax.cond(
iter_num == num_samples - 1,
lambda _: io_callback(_close_bar, None, None),
lambda _: jax.debug.callback(_close_tqdm, remainder, chain),
lambda _: None,
operand=None,
)

def _progress_bar_scan(func):
"""Decorator that adds a progress bar to `body_fun` used in `lax.scan`.
Note that `body_fun` must either be looping over `np.arange(num_samples)`,
or be looping over a tuple who's first element is `np.arange(num_samples)`
looping over a tuple whose elements are `np.arange(num_samples), and a
chain id defined as `chain * np.ones(num_samples)`, or be looping over a
tuple who's first element and second elements include iter_num and chain.
This means that `iter_num` is the current iteration number
"""

def wrapper_progress_bar(carry, x):
if type(x) is tuple:
iter_num, *_ = x
if num_chains > 1:
iter_num, chain, *_ = x
else:
iter_num, *_ = x
chain = 0
else:
iter_num = x
_update_progress_bar(iter_num)
chain = 0
_update_progress_bar(iter_num, chain)
return func(carry, x)

return wrapper_progress_bar
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
dependencies = [
"fastprogress>=1.0.0",
"jax>=0.4.16",
"jaxlib>=0.4.16",
"jaxopt>=0.8",
"optax>=0.1.7",
"tqdm",
"typing-extensions>=4.4.0",
]
dynamic = ["version"]
Expand Down

0 comments on commit 4646d06

Please sign in to comment.