From 1352bfa06494f45174f5d3498e0795d0d31be77c Mon Sep 17 00:00:00 2001
From: Mick van Gelderen <mickvangelderen@gmail.com>
Date: Tue, 10 Dec 2024 05:38:04 -0800
Subject: [PATCH] Pass max_length_time_axis instead of max_size (#43)

* Treat warnings as errors

* Pass max_length_time_axis instead of max_size

Ensures that the following warning:

```
Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`
```

is no longer triggered by legitimate use of `create_flat_buffer` and
`make_prioritised_flat_buffer`.
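
Both helpers now compute `max_length_time_axis = max_length // add_batch_size`
directly. Since the trajectory buffer's total capacity is
`add_batch_size * max_length_time_axis`, this matches what the `max_size` path
computed internally (as the warning text itself states), so behaviour is
unchanged. A minimal sketch of how the fix can be checked, assuming the public
`make_flat_buffer` wrapper and illustrative sizes:

```python
# Minimal sketch (illustrative sizes): with warnings escalated to errors,
# building a flat buffer must no longer raise the max_size UserWarning.
import warnings

import flashbax as fbx

with warnings.catch_warnings():
    warnings.simplefilter("error")  # turn any warning into an exception
    buffer = fbx.make_flat_buffer(
        max_length=1024,     # total transitions stored in the buffer
        min_length=32,
        sample_batch_size=4,
        add_batch_size=8,    # time axis becomes 1024 // 8 = 128
    )
```

The same reasoning applies to `make_prioritised_flat_buffer`, which forwards
the analogous arguments to `make_prioritised_trajectory_buffer`.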

---------

Co-authored-by: Simon Du Toit <90381208+SimonDuToit@users.noreply.github.com>
---
 flashbax/buffers/flat_buffer.py             | 28 ++++++------------
 flashbax/buffers/prioritised_flat_buffer.py | 32 +++++++--------------
 pyproject.toml                              |  1 -
 3 files changed, 20 insertions(+), 41 deletions(-)

diff --git a/flashbax/buffers/flat_buffer.py b/flashbax/buffers/flat_buffer.py
index 41305a7..409f022 100644
--- a/flashbax/buffers/flat_buffer.py
+++ b/flashbax/buffers/flat_buffer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 from typing import TYPE_CHECKING, Generic, Optional
 
 from chex import PRNGKey
@@ -113,24 +112,15 @@ def create_flat_buffer(
         add_batch_size=add_batch_size,
     )
 
-    with warnings.catch_warnings():
-        warnings.filterwarnings(
-            "ignore",
-            message="Setting max_size dynamically sets the `max_length_time_axis` to "
-            f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
-            "This allows one to control exactly how many transitions are stored in the buffer."
-            "Note that this overrides the `max_length_time_axis` argument.",
-        )
-
-        buffer = make_trajectory_buffer(
-            max_length_time_axis=None,  # Unused because max_size is specified
-            min_length_time_axis=min_length // add_batch_size + 1,
-            add_batch_size=add_batch_size,
-            sample_batch_size=sample_batch_size,
-            sample_sequence_length=2,
-            period=1,
-            max_size=max_length,
-        )
+    buffer = make_trajectory_buffer(
+        max_length_time_axis=max_length // add_batch_size,
+        min_length_time_axis=min_length // add_batch_size + 1,
+        add_batch_size=add_batch_size,
+        sample_batch_size=sample_batch_size,
+        sample_sequence_length=2,
+        period=1,
+        max_size=None,
+    )
 
     add_fn = buffer.add
 
diff --git a/flashbax/buffers/prioritised_flat_buffer.py b/flashbax/buffers/prioritised_flat_buffer.py
index a4f6a1d..274d69c 100644
--- a/flashbax/buffers/prioritised_flat_buffer.py
+++ b/flashbax/buffers/prioritised_flat_buffer.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import warnings
 from typing import TYPE_CHECKING, Optional
 
 from chex import PRNGKey
@@ -100,26 +99,17 @@ def make_prioritised_flat_buffer(
     if not validate_device(device):
         device = "cpu"
 
-    with warnings.catch_warnings():
-        warnings.filterwarnings(
-            "ignore",
-            message="Setting max_size dynamically sets the `max_length_time_axis` to "
-            f"be `max_size`//`add_batch_size = {max_length // add_batch_size}`."
-            "This allows one to control exactly how many transitions are stored in the buffer."
-            "Note that this overrides the `max_length_time_axis` argument.",
-        )
-
-        buffer = make_prioritised_trajectory_buffer(
-            max_length_time_axis=None,  # Unused because max_size is specified
-            min_length_time_axis=min_length // add_batch_size + 1,
-            add_batch_size=add_batch_size,
-            sample_batch_size=sample_batch_size,
-            sample_sequence_length=2,
-            period=1,
-            max_size=max_length,
-            priority_exponent=priority_exponent,
-            device=device,
-        )
+    buffer = make_prioritised_trajectory_buffer(
+        max_length_time_axis=max_length // add_batch_size,
+        min_length_time_axis=min_length // add_batch_size + 1,
+        add_batch_size=add_batch_size,
+        sample_batch_size=sample_batch_size,
+        sample_sequence_length=2,
+        period=1,
+        max_size=None,
+        priority_exponent=priority_exponent,
+        device=device,
+    )
 
     add_fn = buffer.add
 
diff --git a/pyproject.toml b/pyproject.toml
index 29fefe7..5a98cb2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -16,7 +16,6 @@ filterwarnings = [
     "error",
     "ignore:`sample_sequence_length` greater than `min_length_time_axis`:UserWarning:flashbax",
     "ignore:Setting period greater than sample_sequence_length will result in no overlap betweentrajectories:UserWarning:flashbax",
-    "ignore:Setting max_size dynamically sets the `max_length_time_axis` to be `max_size`//`add_batch_size = .*`:UserWarning:flashbax",
     "ignore:jax.tree_map is deprecated:DeprecationWarning:flashbax",
 ]