replay_buffer.py
import torch
import numpy as np
from typing_extensions import Self

from utils import DEVICE, D4RLDatasetKeys, StateUtils


class ReplayBuffer(object):
    def __init__(
        self,
        max_steps: int,
        state_dim: int,
        action_dim: int,
        norm_epsilon: float,
    ):
        # Preallocate (max_steps, dim) arrays for each transition field.
        self.state = np.zeros((max_steps, state_dim))
        self.action = np.zeros((max_steps, action_dim))
        self.next_state = np.zeros((max_steps, state_dim))
        self.reward = np.zeros((max_steps, 1))
        self.not_done = np.zeros((max_steps, 1))
        self.norm_epsilon = norm_epsilon  # stored; not passed to StateUtils.normalize below
        self.size = 0
        self.mean = 0
        self.std = 0

    def random_sample(self, batch_size: int):
        # Sample a batch of transitions uniformly at random (with replacement)
        # and return them as float tensors on the training device.
        rand_index = np.random.randint(0, self.size, size=batch_size)
        return (
            torch.FloatTensor(self.state[rand_index]).to(DEVICE),
            torch.FloatTensor(self.action[rand_index]).to(DEVICE),
            torch.FloatTensor(self.next_state[rand_index]).to(DEVICE),
            torch.FloatTensor(self.reward[rand_index]).to(DEVICE),
            torch.FloatTensor(self.not_done[rand_index]).to(DEVICE),
        )

    def init_states_from_D4RL_dataset(self, dataset) -> Self:
        # Load an entire offline D4RL dataset into the buffer at once,
        # replacing the preallocated arrays.
        self.state = dataset[D4RLDatasetKeys.STATE]
        self.action = dataset[D4RLDatasetKeys.ACTIONS]
        self.next_state = dataset[D4RLDatasetKeys.NEXT_STATE]
        self.reward = dataset[D4RLDatasetKeys.REWARD].reshape(-1, 1)
        # Stored as a "not done" mask computed as 1 - flag; this assumes the
        # key yields per-step terminal ("done") flags.
        self.not_done = 1.0 - dataset[D4RLDatasetKeys.NOT_DONE].reshape(-1, 1)
        self.size = self.state.shape[0]

        # Always normalize states
        (self.mean, self.std) = StateUtils.normalize(
            self.state, self.next_state
        )
        return self
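
A minimal usage sketch, not part of the original file: it assumes the `DEVICE`, `D4RLDatasetKeys`, and `StateUtils` names resolve as imported above, and that the `D4RLDatasetKeys` values line up with the dictionary keys produced by `d4rl.qlearning_dataset`. The environment name, `norm_epsilon`, and batch size are illustrative.

# Usage sketch (assumed, not from the repository).
import gym
import d4rl  # registers the offline D4RL environments with gym

env = gym.make("halfcheetah-medium-v2")
dataset = d4rl.qlearning_dataset(env)  # dict of per-transition numpy arrays

buffer = ReplayBuffer(
    max_steps=dataset["observations"].shape[0],
    state_dim=env.observation_space.shape[0],
    action_dim=env.action_space.shape[0],
    norm_epsilon=1e-3,
).init_states_from_D4RL_dataset(dataset)

# Draw a training batch; every tensor is already on DEVICE.
state, action, next_state, reward, not_done = buffer.random_sample(batch_size=256)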