Skip to content

Commit

Permalink
Migrate from deprecated host_callback to io_callback
Browse files Browse the repository at this point in the history
  • Loading branch information
junpenglao committed Mar 27, 2024
1 parent 2ccdfb0 commit 674af39
Showing 1 changed file with 14 additions and 20 deletions.
34 changes: 14 additions & 20 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import host_callback
from jax.experimental import io_callback


def progress_bar_scan(num_samples, print_rate=None):
Expand All @@ -29,55 +29,49 @@ def progress_bar_scan(num_samples, print_rate=None):
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg, transform, device):
def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)

def _update_bar(arg, transform, device):
def _update_bar(arg):
progress_bars[0].update_bar(arg)

def _update_progress_bar(iter_num):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: host_callback.id_tap(
_define_bar, iter_num, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0),
lambda _: host_callback.id_tap(
_update_bar, iter_num, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update by `remainder`
iter_num == num_samples - 1,
lambda _: host_callback.id_tap(
_update_bar, num_samples, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
operand=None,
)

def _close_bar(arg, transform, device):
def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()
print()

def close_bar(result, iter_num):
return lax.cond(
iter_num == num_samples - 1,
lambda _: host_callback.id_tap(
_close_bar, None, result=result, tap_with_device=True
),
lambda _: result,
lambda _: io_callback(_close_bar, None, None),
lambda _: None,
operand=None,
)

Expand Down

0 comments on commit 674af39

Please sign in to comment.