Implement Coverage Based Sampling Strategy (#178)
This PR adds a sampling strategy that prioritises samples that have not yet been visited over samples that have. The chance of an episode being sampled is inversely proportional to the number of times it has already been sampled, additionally weighted by an `alpha` parameter: a higher alpha puts more weight on less-sampled episodes. A test is also added to validate that recently added samples have a higher chance of being chosen.
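For intuition, here is a minimal sketch (not part of the diff) of the weighting scheme described above. The helper `sampling_probabilities` and the sample counts are hypothetical, but the formula mirrors the one in `CoverageBasedStrategy._rebalance` below:

import numpy as np

def sampling_probabilities(sample_counts, sequence_lengths, alpha):
    # Base weight: proportional to episode length, as in CoverageBasedStrategy._rebalance.
    base = np.asarray(sequence_lengths, dtype=np.float64)
    base = base / base.sum()
    # Coverage weight: episodes sampled more often are down-weighted by (count + 1) ** -alpha.
    coverage = 1.0 / (np.asarray(sample_counts, dtype=np.float64) + 1.0) ** alpha
    combined = base * coverage
    return combined / combined.sum()

# Three equally long episodes, sampled 0, 10, and 100 times (hypothetical counts).
for alpha in (0.0, 0.5, 1.0):
    print(alpha, sampling_probabilities([0, 10, 100], [10, 10, 10], alpha))
# alpha=0.0 -> uniform; alpha=1.0 -> the never-sampled episode dominates.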
Showing 4 changed files with 206 additions and 1 deletion.
emote/memory/coverage_based_strategy.py (new file, 91 lines added):
""" | ||
""" | ||
|
||
import random | ||
|
||
from typing import Sequence | ||
|
||
import numpy as np | ||
|
||
from .core_types import SamplePoint | ||
from .strategy import EjectionStrategy, SampleStrategy, Strategy | ||
|
||
|
||
class CoverageBasedStrategy(Strategy): | ||
"""A sampler intended to sample based on coverage of experiences, | ||
favoring less-visited states. This base class can be used for implementing | ||
various coverage-based sampling strategies.""" | ||
|
||
def __init__(self, alpha=0.5): | ||
super().__init__() | ||
self._identities = {} | ||
self._sample_count = {} | ||
self._ids = [] | ||
self._prios = [] | ||
self._dirty = False | ||
self._alpha = alpha | ||
|
||
def track(self, identity: int, sequence_length: int): | ||
self._dirty = True | ||
self._identities[identity] = sequence_length | ||
self._sample_count[identity] = self._sample_count.get(identity, 0) | ||
|
||
def forget(self, identity: int): | ||
self._dirty = True | ||
del self._identities[identity] | ||
del self._sample_count[identity] | ||
|
||
def _rebalance(self): | ||
self._dirty = False | ||
original_prios = np.array(tuple(self._identities.values())) / sum(self._identities.values()) | ||
self._ids = np.array(tuple(self._identities.keys()), dtype=np.int64) | ||
|
||
sample_prios = np.array( | ||
[1 / (self._sample_count[id] + 1) ** self._alpha for id in self._ids] | ||
) | ||
combined_prios = original_prios * sample_prios | ||
|
||
sum_prios = sum(combined_prios) | ||
self._prios = combined_prios / sum_prios | ||
|
||
|
||
class CoverageBasedSampleStrategy(CoverageBasedStrategy, SampleStrategy): | ||
def __init__(self, alpha=0.5): | ||
super().__init__(alpha=alpha) | ||
|
||
def sample(self, count: int, transition_count: int) -> Sequence[SamplePoint]: | ||
if self._dirty: | ||
self._rebalance() | ||
|
||
identities = np.random.choice(self._ids, size=count, p=self._prios) | ||
ids = self._identities | ||
output = [] | ||
app = output.append | ||
r = random.random | ||
tm1 = transition_count - 1 | ||
for k in identities: | ||
self._sample_count[k] += 1 | ||
offset = int(r() * (ids[k] - tm1)) | ||
app((k, offset, offset + transition_count)) | ||
|
||
return output | ||
|
||
|
||
class CoverageBasedEjectionStrategy(CoverageBasedStrategy, EjectionStrategy): | ||
def sample(self, count: int) -> Sequence[int]: | ||
if self._dirty: | ||
self._rebalance() | ||
|
||
identities = set() | ||
while count > 0: | ||
identity = np.random.choice(self._ids, size=1, p=self._prios)[0] | ||
|
||
if identity in identities: | ||
continue | ||
|
||
length = self._identities[identity] | ||
count -= length | ||
identities.add(identity) | ||
|
||
return list(identities) |
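A rough usage sketch of the sample strategy on its own, outside of a memory table; the identities, lengths, and counts below are made up for illustration:

from emote.memory.coverage_based_strategy import CoverageBasedSampleStrategy

strategy = CoverageBasedSampleStrategy(alpha=1.0)

# Register two episodes: identity -> number of stored transitions.
strategy.track(identity=0, sequence_length=10)
strategy.track(identity=1, sequence_length=10)

# Draw 4 sample points of 3 consecutive transitions each; each point is an
# (identity, start, end) tuple, and drawing increments that episode's sample count.
points = strategy.sample(count=4, transition_count=3)

# Once an episode is evicted from memory it is dropped from the bookkeeping.
strategy.forget(identity=0)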
New test for `CoverageBasedSampleStrategy` (97 lines added):
""" | ||
Test to validate the behavior of `CoverageBasedSampleStrategy`. Tests how the `alpha` parameter influences the sampling distribution between two waves of data. | ||
Wave 1 and Wave 2: Two separate sets of data points added to the memory. After each wave, a series of samples are drawn from the memory. | ||
Alpha modulates how much the sampling prioritizes less-visited states. A higher alpha results in a stronger bias towards less-visited states. | ||
Intended Behavior: | ||
- alpha=0.0: Sampling should be approximately uniform, with no strong bias towards either wave. | ||
- alpha=1.0: Sampling should strongly prioritize the less-visited states (i.e., states from Wave 2 after it is added). | ||
- Intermediate alpha values (e.g., alpha=0.5) should result in intermediate behaviors. | ||
""" | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from emote.memory.builder import DictObsNStepTable | ||
from emote.memory.coverage_based_strategy import CoverageBasedSampleStrategy | ||
from emote.utils.spaces import BoxSpace, DictSpace, MDPSpace | ||
|
||
|
||
TABLE_MAX_LEN = 4096 | ||
SAMPLE_AMOUNT = 1024 | ||
ALPHAS = [0.0, 0.5, 1.0] | ||
SEQUENCE_LEN = 10 | ||
|
||
|
||
def create_sample_space() -> MDPSpace: | ||
reward_space = BoxSpace(dtype=np.float32, shape=(1,)) | ||
action_space = BoxSpace(dtype=np.int32, shape=(1,)) | ||
obs_space = BoxSpace(dtype=np.float32, shape=(2,)) | ||
state_space_dict = {"obs": obs_space} | ||
state_space = DictSpace(spaces=state_space_dict) | ||
return MDPSpace(rewards=reward_space, actions=action_space, state=state_space) | ||
|
||
|
||
def populate_table(table: DictObsNStepTable, sequence_len: int, start: int, end: int): | ||
for i in range(start, end): | ||
sequence = { | ||
"obs": [np.random.rand(2) for _ in range(sequence_len + 1)], | ||
"actions": [np.random.rand(1) for _ in range(sequence_len)], | ||
"rewards": [np.random.rand(1) for _ in range(sequence_len)], | ||
} | ||
|
||
table.add_sequence( | ||
identity=i, | ||
sequence=sequence, | ||
) | ||
|
||
|
||
def sample_table(table: DictObsNStepTable, sample_amount: int, count: int, sequence_length: int): | ||
for _ in range(sample_amount): | ||
table.sample(count, sequence_length) | ||
|
||
|
||
def test_memory_export(): | ||
device = torch.device("cpu") | ||
space = create_sample_space() | ||
for alpha in ALPHAS: | ||
table = DictObsNStepTable( | ||
spaces=space, | ||
use_terminal_column=False, | ||
maxlen=TABLE_MAX_LEN, | ||
sampler=CoverageBasedSampleStrategy(alpha=alpha), | ||
device=device, | ||
) | ||
|
||
wave_length = int(TABLE_MAX_LEN / (2 * SEQUENCE_LEN)) | ||
|
||
# Wave 1 | ||
populate_table(table=table, sequence_len=SEQUENCE_LEN, start=0, end=wave_length) | ||
sample_table(table=table, sample_amount=SAMPLE_AMOUNT, count=5, sequence_length=8) | ||
pre_second_wave_sample_counts = table._sampler._sample_count.copy() | ||
|
||
# Wave 2 | ||
populate_table( | ||
table=table, sequence_len=SEQUENCE_LEN, start=wave_length, end=wave_length * 2 | ||
) | ||
sample_table(table=table, sample_amount=SAMPLE_AMOUNT, count=5, sequence_length=8) | ||
|
||
second_wave_samples = sum( | ||
table._sampler._sample_count[id] - pre_second_wave_sample_counts.get(id, 0) | ||
for id in range(wave_length, wave_length * 2) | ||
) | ||
total_new_samples = sum( | ||
table._sampler._sample_count[id] - pre_second_wave_sample_counts.get(id, 0) | ||
for id in table._sampler._sample_count.keys() | ||
) | ||
|
||
proportion_second_wave = second_wave_samples / total_new_samples | ||
|
||
if alpha == 0.0: | ||
assert proportion_second_wave > 0.4 | ||
elif alpha == 0.5: | ||
assert proportion_second_wave > 0.6 | ||
elif alpha == 1.0: | ||
assert proportion_second_wave > 0.8 |
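For intuition about the asserted thresholds, here is a back-of-the-envelope sketch (my own estimate, not part of the test) using the constants above. Since every episode has the same length, the length-based base weight cancels and only the coverage weight matters:

wave_length = 4096 // (2 * 10)             # ~204 episodes per wave
avg_wave1_count = 1024 * 5 / wave_length   # ~25 samples per Wave-1 episode before Wave 2

for alpha in (0.0, 0.5, 1.0):
    w1 = wave_length * (avg_wave1_count + 1) ** -alpha  # total weight of Wave-1 episodes
    w2 = wave_length * 1.0                               # Wave-2 episodes start with count 0
    print(alpha, w2 / (w1 + w2))
# alpha=0.0 -> 0.5, alpha=0.5 -> ~0.84, alpha=1.0 -> ~0.96 at the start of Wave-2 sampling;
# the proportion drifts down as Wave-2 counts grow, hence the looser thresholds in the asserts.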