diff --git a/flashbax/vault/vault_test.py b/flashbax/vault/vault_test.py index 30d7620..6a2bad7 100644 --- a/flashbax/vault/vault_test.py +++ b/flashbax/vault/vault_test.py @@ -57,31 +57,31 @@ def test_write_to_vault( fake_transition: FbxTransition, max_length: int, ): - # with TemporaryDirectory() as temp_dir_path: - # Get the buffer pure functions - buffer = fbx.make_flat_buffer( - max_length=max_length, - min_length=1, - sample_batch_size=1, - ) - buffer_add = jax.jit(buffer.add, donate_argnums=0) - buffer_state = buffer.init(fake_transition) # Initialise the state - - # Initialise the vault - v = Vault( - vault_name="test_vault", - experience_structure=buffer_state.experience, - rel_dir="vaults", - ) + with TemporaryDirectory() as temp_dir_path: + # Get the buffer pure functions + buffer = fbx.make_flat_buffer( + max_length=max_length, + min_length=1, + sample_batch_size=1, + ) + buffer_add = jax.jit(buffer.add, donate_argnums=0) + buffer_state = buffer.init(fake_transition) # Initialise the state - # Add to the vault up to the fbx buffer being full - for i in range(0, max_length): - assert v.vault_index == i - buffer_state = buffer_add( - buffer_state, - fake_transition, + # Initialise the vault + v = Vault( + vault_name="test_vault", + experience_structure=buffer_state.experience, + rel_dir=temp_dir_path, ) - v.write(buffer_state) + + # Add to the vault up to the fbx buffer being full + for i in range(0, max_length): + assert v.vault_index == i + buffer_state = buffer_add( + buffer_state, + fake_transition, + ) + v.write(buffer_state) def test_read_from_vault(