How to load saved buffer properly? #47

Open
4ku opened this issue Nov 26, 2024 · 4 comments


4ku commented Nov 26, 2024

I tried:

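import os

import flashbax as fbx
import gym
import jax
from flashbax.vault import Vault
from flax.core import frozen_dict


class ReplayBufferDataStore: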
    def __init__(
        self,
        env: gym.Env,
        capacity: int,
        sample_batch_size: int = 32,
        priority_exponent: float = 0.8,
        device: str = "gpu",
        name: str = "replay_buffer",
        checkpoint_path: str = None,
    ):
        self.sample_batch_size = sample_batch_size
        self.priority_exponent = priority_exponent
        self.device = jax.devices(device)[0]

        self.buffer = fbx.make_prioritised_flat_buffer(
            max_length=capacity,
            min_length=sample_batch_size,
            sample_batch_size=sample_batch_size,
            add_sequences=True,
            add_batch_size=None,
            priority_exponent=priority_exponent,
            device=device,
        )

        # Preprocess the transition once to avoid redundant transformations
        single_transition = self._initialize_single_transition(env)
        self.state = self.buffer.init(single_transition)
        self.state = jax.device_put(self.state, device=self.device)

        self.vault = Vault(
            vault_name=name,
            experience_structure=self.state.experience,
            rel_dir=os.path.join(os.path.dirname(checkpoint_path), "vaults"),
        )
...

    def save(self):
        self.vault.write(self.state)

    def load(self, vault_path: str):
        vault_name = vault_path.split("/")[-2]
        vault_uid = vault_path.split("/")[-1]
        vault_path = os.path.dirname(os.path.dirname(vault_path))
        vault = Vault(
            vault_name=vault_name,
            experience_structure=self.state.experience,
            rel_dir=vault_path,
            vault_uid=vault_uid,
        )
        state = vault.read()

        loaded_experience = frozen_dict.freeze(state.experience)
        self.state = _insert(self.buffer, self.state, loaded_experience)

The experience structure of the loaded state doesn't match that of self.state. For example, if the Vault holds 500 transitions, the loaded state has 500 timesteps, but I initialized the buffer with a capacity of 100_000.


lbeyers commented Jan 7, 2025

Thanks for the question! The proper way to load experience from a Vault into a buffer state is to use the buffer.add function. For example, this notebook uses:

    buffer_add = jax.jit(buffer.add, donate_argnums=0)
    buffer_state = buffer_add(buffer_state, new_experience)

In your case, "new_experience" would be the experience loaded from the Vault. Please let me know if you have further questions!

Contributor

callumtilbury commented Jan 7, 2025

Hey @4ku & @lbeyers! Sorry for not taking a look at this sooner :)

As Louise mentioned, you can load the Vault state using an add function, but this is actually inefficient—the read Vault state is exactly compatible with a normal flashbax state. So we don't need to load the vault state, and then use add. We can directly use the loaded state with our usual fbx functions.

It seems that the problem above is actually the number of timesteps in the loaded state—is that correct, @4ku?

If so: indeed, Vault only writes to disk the timesteps up to current_index. For example, if we haven't added any timesteps when we write to the vault and then load that vault, the state will have shape (B, 0, E), i.e. zero timesteps.

I had previously thought about adding this functionality: perhaps we know how many timesteps we want in the flashbax buffer state, even if the vault is smaller than that. I think it should be a pretty simple change, e.g. we could use the timesteps parameter:

# If time steps are provided, we read the last `timesteps` count of elements
elif timesteps is not None:
    read_interval = (self.vault_index - timesteps, self.vault_index)

Currently, if you ask for more timesteps than available (e.g. as in the original post, I ask for a buffer size of 100_000 but the vault only has 500 transitions), it breaks—oops 😂 But this could easily be fixed, and I think is a nice way to use the current API to achieve the desired functionality.
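
As a rough illustration of the fix described above, the start of the read interval could be clamped at zero so that requesting more timesteps than the vault holds returns everything that is stored. This is only a sketch: `clamped_read_interval` is a hypothetical standalone helper, not part of the Vault API, and it only mirrors the `read_interval` arithmetic quoted in this comment.

```python
from typing import Optional, Tuple


def clamped_read_interval(
    vault_index: int, timesteps: Optional[int] = None
) -> Tuple[int, int]:
    """Return (start, stop) for reading the last `timesteps` entries.

    Hypothetical helper (not a Vault method): the start index is clamped
    at 0, so asking for more than is stored yields the full range.
    """
    if timesteps is None:
        # No limit requested: read everything written so far.
        return (0, vault_index)
    if timesteps < 0:
        raise ValueError("timesteps must be non-negative")
    start = max(0, vault_index - timesteps)
    return (start, vault_index)


print(clamped_read_interval(500, 100_000))  # (0, 500)
print(clamped_read_interval(500, 200))      # (300, 500)
```

With the unclamped expression, the first call would produce a negative start index (500 - 100_000), which is the failing case from the original post.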

Haven't had a chance to give this a proper look but let me know if I have understood you correctly. Happy to draft something when I get a moment, probably sometime this week.

Thanks!

Author

4ku commented Jan 9, 2025

@callumtilbury Yeah, you are right. The actual problem is the number of timesteps in the loaded state.
It would be great if I could do:

state = vault.read(timesteps=100_000)

The state would then have size 100_000, with the first 500 transitions coming from the vault.
Looking forward to your fix! Thank you.

Author

4ku commented Jan 9, 2025

Also, I think current_index shouldn't be 0 (or there should be an extra parameter for it). I want to continue adding transitions after the end of the loaded state.
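
Until something like this exists in the library, the two requests in the comments above (a state padded to the buffer capacity, with writes resuming after the loaded data) could be approximated by hand. The sketch below uses plain NumPy rather than flashbax, and makes several assumptions: `pad_to_capacity` is a hypothetical helper, the experience is a flat dict of arrays shaped (B, T, ...) as described earlier in the thread, and the returned `current_index` and `is_full` values are meant to be placed into the buffer state by the caller.

```python
import numpy as np


def pad_to_capacity(experience: dict, capacity: int):
    """Zero-pad the time axis (axis 1) of every array up to `capacity`.

    Hypothetical helper, not part of flashbax. Returns the padded
    experience, the index where the next write should land, and whether
    the buffer is already full.
    """
    if capacity <= 0:
        raise ValueError("capacity must be positive")
    # Assumes every array shares the same number of timesteps on axis 1.
    loaded = next(iter(experience.values())).shape[1]
    if loaded > capacity:
        raise ValueError(f"loaded {loaded} timesteps but capacity is {capacity}")
    padded = {}
    for key, arr in experience.items():
        pad_width = [(0, 0)] * arr.ndim
        pad_width[1] = (0, capacity - loaded)  # pad only at the end of the time axis
        padded[key] = np.pad(arr, pad_width)   # default mode pads with zeros
    is_full = loaded == capacity
    current_index = loaded % capacity          # wraps to 0 when exactly full
    return padded, current_index, is_full


# Stand-in for experience read from a vault holding 500 transitions.
loaded = {
    "obs": np.ones((1, 500, 4), dtype=np.float32),
    "reward": np.ones((1, 500), dtype=np.float32),
}
padded, current_index, is_full = pad_to_capacity(loaded, 100_000)
print(padded["obs"].shape, current_index, is_full)  # (1, 100000, 4) 500 False
```

Starting from `current_index = 500` with `is_full = False` means new transitions land after the loaded ones instead of overwriting them from index 0. A nested experience pytree would need the same padding applied to each leaf.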


3 participants