-
-
Notifications
You must be signed in to change notification settings - Fork 48
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from accuracy-maker/master
add sumtree,PER and PESR
- Loading branch information
Showing
6 changed files
with
476 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import pytest | ||
import random | ||
import torch | ||
from zeta.rl.PrioritizedReplayBuffer import PrioritizedReplayBuffer, SumTree # Replace 'your_module' with the actual module where classes are defined | ||
|
||
@pytest.fixture
def replay_buffer():
    """Provide an empty PrioritizedReplayBuffer (capacity 100) on the CPU."""
    return PrioritizedReplayBuffer(
        state_size=4,
        action_size=2,
        buffer_size=100,
        device=torch.device("cpu"),
    )
|
||
def test_initialization(replay_buffer):
    """A freshly built buffer exposes the default hyper-parameters and is empty."""
    expected_attrs = {
        "eps": 1e-2,
        "alpha": 0.1,
        "beta": 0.1,
        "max_priority": 1.0,
        "count": 0,
        "real_size": 0,
        "size": 100,
        "device": torch.device("cpu"),
    }
    for attr_name, expected in expected_attrs.items():
        assert getattr(replay_buffer, attr_name) == expected
|
||
def test_add(replay_buffer):
    """Adding one transition advances both the write cursor and the fill level."""
    state, action, next_state = torch.rand(4), torch.rand(2), torch.rand(4)
    replay_buffer.add((state, action, 1.0, next_state, False))

    assert (replay_buffer.count, replay_buffer.real_size) == (1, 1)
|
||
def test_sample(replay_buffer):
    """Sampling from a buffer holding 10 transitions yields three length-5 results."""
    for _ in range(10):
        replay_buffer.add(
            (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
        )

    sampled = replay_buffer.sample(5)
    batch, weights, tree_idxs = sampled
    # ``batch`` is the 5-field transition tuple, so its length is 5 as well.
    for part in (batch, weights, tree_idxs):
        assert len(part) == 5
|
||
def test_update_priorities(replay_buffer):
    """Updating priorities keeps ``max_priority`` in sync with the new values.

    The buffer stores ``(priority + eps) ** alpha`` for every updated entry
    and tracks the running maximum, which starts at 1.0.
    """
    for _ in range(10):
        transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
        replay_buffer.add(transition)

    batch, weights, tree_idxs = replay_buffer.sample(5)
    new_priorities = torch.rand(5)
    replay_buffer.update_priorities(tree_idxs, new_priorities)

    # Previously this test had no assertion and could only fail on an exception.
    scaled = (new_priorities + replay_buffer.eps) ** replay_buffer.alpha
    expected_max = max(1.0, scaled.max().item())
    assert float(replay_buffer.max_priority) == pytest.approx(expected_max, rel=1e-5)
|
||
def test_sample_with_invalid_batch_size(replay_buffer):
    """Requesting more samples than the buffer holds trips the size assertion."""
    oversized_batch = 101
    with pytest.raises(AssertionError):
        replay_buffer.sample(oversized_batch)
|
||
def test_add_with_max_size(replay_buffer):
    """Filling the buffer exactly to capacity wraps the write cursor back to 0."""
    capacity = 100
    for _ in range(capacity):
        replay_buffer.add(
            (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
        )

    assert replay_buffer.count == 0
    assert replay_buffer.real_size == 100
|
||
# Additional tests for edge cases, exceptions, and more scenarios can be added as needed. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
import random | ||
import torch | ||
from zeta.rl.PrioritizedSequenceReplayBuffer import PrioritizedSequenceReplayBuffer, SumTree # Replace 'your_module' with the actual module where classes are defined | ||
|
||
@pytest.fixture
def replay_buffer():
    """Provide an empty PrioritizedSequenceReplayBuffer (capacity 100) on the CPU."""
    return PrioritizedSequenceReplayBuffer(
        state_size=4,
        action_size=2,
        buffer_size=100,
        device=torch.device("cpu"),
    )
|
||
def test_initialization(replay_buffer):
    """A freshly built sequence buffer exposes its defaults and is empty."""
    expected_attrs = {
        "eps": 1e-5,
        "alpha": 0.1,
        "beta": 0.1,
        "max_priority": 1.0,
        "decay_window": 5,
        "decay_coff": 0.4,
        "pre_priority": 0.7,
        "count": 0,
        "real_size": 0,
        "size": 100,
        "device": torch.device("cpu"),
    }
    for attr_name, expected in expected_attrs.items():
        assert getattr(replay_buffer, attr_name) == expected
|
||
def test_add(replay_buffer):
    """Adding one transition advances both the write cursor and the fill level."""
    state, action, next_state = torch.rand(4), torch.rand(2), torch.rand(4)
    replay_buffer.add((state, action, 1.0, next_state, False))

    assert (replay_buffer.count, replay_buffer.real_size) == (1, 1)
|
||
def test_sample(replay_buffer):
    """Sampling from a buffer holding 10 transitions yields three length-5 results."""
    for _ in range(10):
        replay_buffer.add(
            (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
        )

    sampled = replay_buffer.sample(5)
    batch, weights, tree_idxs = sampled
    # ``batch`` is the 5-field transition tuple, so its length is 5 as well.
    for part in (batch, weights, tree_idxs):
        assert len(part) == 5
|
||
def test_update_priorities(replay_buffer):
    """Updating priorities runs cleanly and never lowers ``max_priority``.

    ``max_priority`` starts at 1.0 and is only ever replaced by
    ``max(max_priority, priority)``, so it must stay at or above 1.0.
    """
    for _ in range(10):
        transition = (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
        replay_buffer.add(transition)

    batch, weights, tree_idxs = replay_buffer.sample(5)
    new_priorities = torch.rand(5)
    replay_buffer.update_priorities(tree_idxs, new_priorities)

    # Previously this test had no assertion and could only fail on an exception.
    assert float(replay_buffer.max_priority) >= 1.0
|
||
def test_sample_with_invalid_batch_size(replay_buffer):
    """Requesting more samples than the buffer holds trips the size assertion."""
    oversized_batch = 101
    with pytest.raises(AssertionError):
        replay_buffer.sample(oversized_batch)
|
||
def test_add_with_max_size(replay_buffer):
    """Filling the buffer exactly to capacity wraps the write cursor back to 0."""
    capacity = 100
    for _ in range(capacity):
        replay_buffer.add(
            (torch.rand(4), torch.rand(2), 1.0, torch.rand(4), False)
        )

    assert replay_buffer.count == 0
    assert replay_buffer.real_size == 100
|
||
# Additional tests for edge cases, exceptions, and more scenarios can be added as needed. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import pytest | ||
from zeta.rl.sumtree import SumTree # Replace 'your_module' with the actual module where SumTree is defined | ||
|
||
# Fixture for initializing SumTree instances with a given size | ||
@pytest.fixture
def sum_tree():
    """Provide an empty SumTree with room for 10 entries."""
    capacity = 10  # adjust here if the tests need a different tree size
    return SumTree(capacity)
|
||
# Basic tests | ||
def test_initialization(sum_tree):
    """A new tree reports its capacity and has no entries or accumulated priority."""
    observed = (sum_tree.size, sum_tree.count, sum_tree.real_size)
    assert observed == (10, 0, 0)
    assert sum_tree.total == 0
|
||
def test_update_and_get(sum_tree):
    """A single added entry is the one returned when querying its full mass."""
    sum_tree.add(5, "data1")
    assert sum_tree.total == 5

    found = sum_tree.get(5)
    data_idx, priority, data = found
    assert data_idx == 0
    assert priority == 5
    assert data == "data1"
|
||
def test_add_overflow(sum_tree):
    """Adding 15 entries to a 10-slot tree wraps the cursor and caps the fill level."""
    for priority in range(15):
        sum_tree.add(priority, f"data{priority}")

    assert sum_tree.count == 5
    assert sum_tree.real_size == 10
|
||
# Parameterized testing for various scenarios | ||
@pytest.mark.parametrize(
    "values, expected_total",
    [
        ([1, 2, 3, 4, 5], 15),
        ([10, 20, 30, 40, 50], 150),
    ],
)
def test_multiple_updates(sum_tree, values, expected_total):
    """The tree total equals the sum of all priorities that were added."""
    for priority in values:
        sum_tree.add(priority, None)

    assert sum_tree.total == expected_total
|
||
# Exception testing | ||
def test_get_with_invalid_cumsum(sum_tree):
    """Querying a cumulative sum beyond the tree total raises AssertionError."""
    out_of_range_cumsum = 20
    with pytest.raises(AssertionError):
        sum_tree.get(out_of_range_cumsum)
|
||
# More tests for specific methods | ||
def test_get_priority(sum_tree):
    """The priority stored for the first added entry can be read back by index."""
    sum_tree.add(10, "data1")
    assert sum_tree.get_priority(0) == 10
|
||
def test_repr(sum_tree):
    """``repr`` shows the node array and the data array of the tree."""
    actual = repr(sum_tree)
    assert actual == f"SumTree(nodes={sum_tree.nodes}, data={sum_tree.data})"
|
||
# More test cases can be added as needed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from sumtree import SumTree | ||
import torch | ||
import random | ||
|
||
class PrioritizedReplayBuffer:
    """Fixed-capacity replay buffer that samples transitions by priority.

    Transitions live in preallocated tensors that are written as a ring
    buffer; a ``SumTree`` holds one priority per buffer slot and is used to
    draw samples proportionally to those priorities.
    """

    def __init__(self, state_size, action_size, buffer_size, device, eps=1e-2, alpha=0.1, beta=0.1):
        """Allocate storage for ``buffer_size`` transitions.

        Args:
            state_size: Length of a state vector.
            action_size: Length of an action vector.
            buffer_size: Maximum number of stored transitions.
            device: Device that sampled batches are moved to.
            eps: Constant added to a priority before exponentiation.
            alpha: Exponent applied to ``priority + eps`` on update.
            beta: Exponent used for the importance-sampling weights.
        """
        self.tree = SumTree(size=buffer_size)

        # Prioritisation hyper-parameters; new transitions get ``max_priority``.
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = 1.

        # Transition storage, one row per buffer slot.
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.uint8)

        # ``count`` is the next slot to write; ``real_size`` is the fill level.
        self.count = 0
        self.real_size = 0
        self.size = buffer_size

        self.device = device

    def add(self, transition):
        """Store ``(state, action, reward, next_state, done)`` in the next slot.

        The slot index is registered in the tree with the current maximum
        priority before the transition itself is written.
        """
        state, action, reward, next_state, done = transition
        slot = self.count

        self.tree.add(self.max_priority, slot)

        fields = (
            (self.state, state),
            (self.action, action),
            (self.reward, reward),
            (self.next_state, next_state),
            (self.done, done),
        )
        for storage, value in fields:
            storage[slot] = torch.as_tensor(value)

        # Advance the ring-buffer cursor and cap the fill level at capacity.
        self.count = (slot + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions, one from each equal slice of the total priority.

        Returns:
            A tuple ``(batch, weights, tree_idxs)``: ``batch`` is the
            5-tuple of tensors (state, action, reward, next_state, done)
            moved to ``self.device``; ``weights`` is a ``(batch_size, 1)``
            tensor of importance weights scaled so the largest is 1;
            ``tree_idxs`` lists the first value returned by ``tree.get`` for
            each draw, to be passed back to ``update_priorities``.

        Raises:
            AssertionError: If fewer than ``batch_size`` transitions are stored.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"

        priorities = torch.empty(batch_size, 1, dtype=torch.float)
        sample_idxs = []
        tree_idxs = []

        segment = self.tree.total / batch_size
        for slot in range(batch_size):
            lower = segment * slot
            upper = segment * (slot + 1)
            cumsum = random.uniform(lower, upper)

            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[slot] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        probs = priorities / self.tree.total
        raw_weights = (self.real_size * probs) ** -self.beta
        weights = raw_weights / raw_weights.max()

        storages = (self.state, self.action, self.reward, self.next_state, self.done)
        batch = tuple(storage[sample_idxs].to(self.device) for storage in storages)
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, priorities):
        """Write ``(priority + eps) ** alpha`` into the tree for each index.

        ``priorities`` may be a tensor (converted to a NumPy array first) or
        any other iterable. ``max_priority`` is raised whenever a larger
        value is written.
        """
        if isinstance(priorities, torch.Tensor):
            priorities = priorities.detach().cpu().numpy()

        for data_idx, raw_priority in zip(data_idxs, priorities):
            scaled = (raw_priority + self.eps) ** self.alpha
            self.tree.update(data_idx, scaled)
            self.max_priority = max(self.max_priority, scaled)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from sumtree import SumTree | ||
import torch | ||
import random | ||
|
||
class PrioritizedSequenceReplayBuffer:
    """Prioritized replay buffer that also spreads priority over a decay window.

    Transitions live in preallocated tensors written as a ring buffer, with
    one priority per slot held in a ``SumTree``. On every priority update
    the new priority is additionally propagated, with geometric decay, to a
    window of ``decay_window`` buffer slots.
    """

    def __init__(self, state_size, action_size, buffer_size, device, eps=1e-5, alpha=0.1, beta=0.1,
                 decay_window=5,
                 decay_coff=0.4,
                 pre_priority=0.7):
        """Allocate storage for ``buffer_size`` transitions.

        Args:
            state_size: Length of a state vector.
            action_size: Length of an action vector.
            buffer_size: Maximum number of stored transitions.
            device: Device that sampled batches are moved to.
            eps: Constant added to a TD error before exponentiation.
            alpha: Exponent applied to ``td_error + eps`` on update.
            beta: Exponent used for the importance-sampling weights.
            decay_window: Number of slots that receive a decayed priority.
            decay_coff: Per-step decay factor inside the window.
            pre_priority: Factor applied to the previous priority, which
                acts as a lower bound for the updated priority.
        """
        # Passed positionally: the other users of SumTree in this project
        # construct it with ``size`` (keyword or positional), and this class
        # reads ``self.tree.size``, so a ``data_size=`` keyword is inconsistent.
        self.tree = SumTree(buffer_size)

        # PESR params
        self.eps = eps
        self.alpha = alpha
        self.beta = beta
        self.max_priority = 1.
        self.decay_window = decay_window
        self.decay_coff = decay_coff
        self.pre_priority = pre_priority

        # buffer params
        self.state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.action = torch.empty(buffer_size, action_size, dtype=torch.float)
        self.reward = torch.empty(buffer_size, dtype=torch.float)
        self.next_state = torch.empty(buffer_size, state_size, dtype=torch.float)
        self.done = torch.empty(buffer_size, dtype=torch.uint8)

        # ``count`` is the next slot to write; ``real_size`` is the fill level.
        self.count = 0
        self.real_size = 0
        self.size = buffer_size

        # device
        self.device = device

    def add(self, transition):
        """Store ``(state, action, reward, next_state, done)`` in the next slot."""
        state, action, reward, next_state, done = transition

        # store transition index with maximum priority in sum tree
        self.tree.add(self.max_priority, self.count)

        # store transition in the buffer
        self.state[self.count] = torch.as_tensor(state)
        self.action[self.count] = torch.as_tensor(action)
        self.reward[self.count] = torch.as_tensor(reward)
        self.next_state[self.count] = torch.as_tensor(next_state)
        self.done[self.count] = torch.as_tensor(done)

        # update counters
        self.count = (self.count + 1) % self.size
        self.real_size = min(self.size, self.real_size + 1)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions, one from each equal slice of the total priority.

        Returns:
            A tuple ``(batch, weights, tree_idxs)``: ``batch`` is the
            5-tuple of tensors (state, action, reward, next_state, done)
            moved to ``self.device``; ``weights`` is a ``(batch_size, 1)``
            tensor of importance weights scaled so the largest is 1;
            ``tree_idxs`` lists the first value returned by ``tree.get`` for
            each draw, to be passed back to ``update_priorities``.

        Raises:
            AssertionError: If fewer than ``batch_size`` transitions are stored.
        """
        assert self.real_size >= batch_size, "buffer contains less samples than batch size"

        sample_idxs, tree_idxs = [], []
        priorities = torch.empty(batch_size, 1, dtype=torch.float)

        # ``total`` is the attribute name used by the sibling buffer and the
        # SumTree tests; ``total_priority`` is not used anywhere else.
        total = self.tree.total
        segment = total / batch_size
        for i in range(batch_size):
            a, b = segment * i, segment * (i + 1)

            cumsum = random.uniform(a, b)
            # sample_idx is a sample index in buffer, needed further to sample actual transitions
            # tree_idx is a index of a sample in the tree, needed further to update priorities
            tree_idx, priority, sample_idx = self.tree.get(cumsum)

            priorities[i] = priority
            tree_idxs.append(tree_idx)
            sample_idxs.append(sample_idx)

        # Note: the priorities stored in the sum tree already have the
        # ``alpha`` exponent applied (see ``update_priorities``).
        probs = priorities / total
        weights = (self.real_size * probs) ** -self.beta
        weights = weights / weights.max()
        batch = (
            self.state[sample_idxs].to(self.device),
            self.action[sample_idxs].to(self.device),
            self.reward[sample_idxs].to(self.device),
            self.next_state[sample_idxs].to(self.device),
            self.done[sample_idxs].to(self.device)
        )
        return batch, weights, tree_idxs

    def update_priorities(self, data_idxs, abs_td_errors):
        """Update sampled priorities from TD errors, then apply the decay window.

        For each index the new priority is
        ``max((td_error + eps) ** alpha, pre_priority * old_priority)``.
        When ``count >= decay_window`` the ``decay_window`` slots ending just
        before ``count`` are then raised to at least
        ``priority * decay_coff ** (i + 1)``, where ``i`` is the distance
        back from the newest slot.

        Args:
            data_idxs: Indices as returned in ``tree_idxs`` by ``sample``.
            abs_td_errors: Tensor or iterable of absolute TD errors.
        """
        if isinstance(abs_td_errors, torch.Tensor):
            abs_td_errors = abs_td_errors.detach().cpu().numpy()

        for data_idx, td_error in zip(data_idxs, abs_td_errors):
            # first update the batch: p_j
            # p_j <- max{(|delta_j| + eps) ** alpha, pre_priority * p_j}
            old_priority = self.pre_priority * self.tree.nodes[data_idx + self.tree.size - 1]
            priority = (td_error + self.eps) ** self.alpha
            priority = max(priority, old_priority)
            self.tree.update(data_idx, priority)
            self.max_priority = max(self.max_priority, priority)

            # And then apply decay
            # NOTE(review): the window is anchored at ``self.count`` (the most
            # recently written slots), not at ``data_idx``; confirm this is
            # the intended set of transitions to receive the decayed priority.
            if self.count >= self.decay_window:
                # count points to the next position
                # count means the idx in the buffer and number of transition
                for i in reversed(range(self.decay_window)):
                    idx = (self.count - i - 1) % self.size
                    decayed_priority = priority * (self.decay_coff ** (i + 1))
                    tree_idx = idx + self.tree.size - 1
                    existing_priority = self.tree.nodes[tree_idx]
                    self.tree.update(idx, max(decayed_priority, existing_priority))
Oops, something went wrong.