From b01cd180d0e85b5f9dadbf7641e62be039c94191 Mon Sep 17 00:00:00 2001 From: EdanToledo Date: Mon, 17 Jun 2024 10:15:48 +0100 Subject: [PATCH 1/3] fix: linter for readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 2812ef4..e223539 100644 --- a/README.md +++ b/README.md @@ -272,7 +272,7 @@ benchmarks. - 🎶 [Reverb](https://github.com/google-deepmind/reverb): efficient replay buffers used for both local and large-scale distributed RL. - 🍰 [Dopamine](https://github.com/google/dopamine/blob/master/dopamine/replay_memory/): research framework for fast prototyping, providing several core replay buffers. - 🤖 [StableBaselines3](https://stable-baselines3.readthedocs.io/en/master/): suite of reliable RL baselines with its own, easy-to-use replay buffers. - + ### Example Usage Checkout some libraries from the community that utilise flashbax: - 🦁 [Mava](https://github.com/instadeepai/Mava): end-to-end JAX implementations of multi-agent algorithms utilising flashbax. From 57bb6d3c2822df19d63c0ebf7a062267f594f799 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Thu, 4 Jul 2024 17:18:29 +0200 Subject: [PATCH 2/3] first attempt to fix segfault on fbx vault tests. --- flashbax/vault/vault_test.py | 46 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py index 6a2bad7..30d7620 100644 --- a/flashbax/vault/vault_test.py +++ b/flashbax/vault/vault_test.py @@ -57,31 +57,31 @@ def test_write_to_vault( fake_transition: FbxTransition, max_length: int, ): - with TemporaryDirectory() as temp_dir_path: - # Get the buffer pure functions - buffer = fbx.make_flat_buffer( - max_length=max_length, - min_length=1, - sample_batch_size=1, - ) - buffer_add = jax.jit(buffer.add, donate_argnums=0) - buffer_state = buffer.init(fake_transition) # Initialise the state + # with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state + + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir="vaults", + ) - # Initialise the vault - v = Vault( - vault_name="test_vault", - experience_structure=buffer_state.experience, - rel_dir=temp_dir_path, + # Add to the vault up to the fbx buffer being full + for i in range(0, max_length): + assert v.vault_index == i + buffer_state = buffer_add( + buffer_state, + fake_transition, ) - - # Add to the vault up to the fbx buffer being full - for i in range(0, max_length): - assert v.vault_index == i - buffer_state = buffer_add( - buffer_state, - fake_transition, - ) - v.write(buffer_state) + v.write(buffer_state) def test_read_from_vault( From 43f7e3334a3b072880b4f96ed4f7b31c90a13327 Mon Sep 17 00:00:00 2001 From: Callum Tilbury Date: Thu, 4 Jul 2024 17:27:20 +0200 Subject: [PATCH 3/3] undoing previous attempt, as pytest is no longer failing? --- flashbax/vault/vault_test.py | 46 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py index 30d7620..6a2bad7 100644 --- a/flashbax/vault/vault_test.py +++ b/flashbax/vault/vault_test.py @@ -57,31 +57,31 @@ def test_write_to_vault( fake_transition: FbxTransition, max_length: int, ): - # with TemporaryDirectory() as temp_dir_path: - # Get the buffer pure functions - buffer = fbx.make_flat_buffer( - max_length=max_length, - min_length=1, - sample_batch_size=1, - ) - buffer_add = jax.jit(buffer.add, donate_argnums=0) - buffer_state = buffer.init(fake_transition) # Initialise the state - - # Initialise the vault - v = Vault( - vault_name="test_vault", - experience_structure=buffer_state.experience, - rel_dir="vaults", - ) + with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state - # Add to the vault up to the fbx buffer being full - for i in range(0, max_length): - assert v.vault_index == i - buffer_state = buffer_add( - buffer_state, - fake_transition, + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, ) - v.write(buffer_state) + + # Add to the vault up to the fbx buffer being full + for i in range(0, max_length): + assert v.vault_index == i + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) def test_read_from_vault(