Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from deprecated host_callback to io_callback #651

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 17 additions & 34 deletions blackjax/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""
from fastprogress.fastprogress import progress_bar
from jax import lax
from jax.experimental import host_callback
from jax.experimental import io_callback


def progress_bar_scan(num_samples, print_rate=None):
Expand All @@ -29,55 +29,39 @@ def progress_bar_scan(num_samples, print_rate=None):
else:
print_rate = 1 # if you run the sampler for less than 20 iterations

def _define_bar(arg, transform, device):
def _define_bar(arg):
del arg
progress_bars[0] = progress_bar(range(num_samples))
progress_bars[0].update(0)

def _update_bar(arg, transform, device):
progress_bars[0].update_bar(arg)
def _update_bar(arg):
progress_bars[0].update_bar(arg + 1)

def _close_bar(arg):
del arg
progress_bars[0].on_iter_end()

def _update_progress_bar(iter_num):
"Updates progress bar of a JAX scan or loop"
_ = lax.cond(
iter_num == 0,
lambda _: host_callback.id_tap(
_define_bar, iter_num, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
lambda _: io_callback(_define_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update every multiple of `print_rate` except at the end
(iter_num % print_rate == 0),
lambda _: host_callback.id_tap(
_update_bar, iter_num, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
(iter_num % print_rate == 0) | (iter_num == (num_samples - 1)),
lambda _: io_callback(_update_bar, None, iter_num),
lambda _: None,
operand=None,
)

_ = lax.cond(
# update by `remainder`
iter_num == num_samples - 1,
lambda _: host_callback.id_tap(
_update_bar, num_samples, result=iter_num, tap_with_device=True
),
lambda _: iter_num,
operand=None,
)

def _close_bar(arg, transform, device):
progress_bars[0].on_iter_end()
print()

def close_bar(result, iter_num):
return lax.cond(
iter_num == num_samples - 1,
lambda _: host_callback.id_tap(
_close_bar, None, result=result, tap_with_device=True
),
lambda _: result,
lambda _: io_callback(_close_bar, None, None),
lambda _: None,
operand=None,
)

Expand All @@ -94,8 +78,7 @@ def wrapper_progress_bar(carry, x):
else:
iter_num = x
_update_progress_bar(iter_num)
result = func(carry, x)
return close_bar(result, iter_num)
return func(carry, x)

return wrapper_progress_bar

Expand Down
Loading