diff --git a/flashbax/buffers/prioritised_trajectory_buffer.py b/flashbax/buffers/prioritised_trajectory_buffer.py index d93d6b1..2337bf8 100644 --- a/flashbax/buffers/prioritised_trajectory_buffer.py +++ b/flashbax/buffers/prioritised_trajectory_buffer.py @@ -798,6 +798,7 @@ def make_prioritised_trajectory_buffer( if max_size is not None: max_length_time_axis = max_size // add_batch_size + assert max_length_time_axis is not None init_fn = functools.partial( prioritised_init, add_batch_size=add_batch_size, diff --git a/flashbax/buffers/trajectory_buffer.py b/flashbax/buffers/trajectory_buffer.py index a06ad45..73b33b6 100644 --- a/flashbax/buffers/trajectory_buffer.py +++ b/flashbax/buffers/trajectory_buffer.py @@ -585,6 +585,7 @@ def make_trajectory_buffer( if max_size is not None: max_length_time_axis = max_size // add_batch_size + assert max_length_time_axis is not None init_fn = functools.partial( init, add_batch_size=add_batch_size,