Skip to content

Commit

Permalink
Implement Coverage Based Sampling Strategy (#178)
Browse files Browse the repository at this point in the history
This PR adds a sampling strategy that prioritises samples that haven't
been visited yet over samples that have. The chance of an episode being
sampled is inversely proportional to the number of times it has already
been sampled, raised to the power of an `alpha` parameter: a higher
alpha puts more weight on less-sampled episodes.

A test is also added to validate that recently added samples have a
higher chance of being chosen.
  • Loading branch information
jaxs-ribs authored Oct 17, 2023
1 parent 589efe3 commit 46f4165
Show file tree
Hide file tree
Showing 4 changed files with 206 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ have new arguments to handle versioning.
- `OnnxExporter` accepts a `device` argument to enable tracing on other devices.
- `FinalRewardTestCheck` can now be configured with another key and to use windowed data.
- `begin_training` has been split into `restore_state` followed by `begin_training`
- `CoverageBasedSampleStrategy` has been added which allows memory sampling that prioritises unvisited experiences. This can speed up training.

### Deprecations

Expand Down
18 changes: 17 additions & 1 deletion emote/memory/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import numpy as np
import torch

from emote.memory.strategy import SampleStrategy

from ..utils import MDPSpace
from .adaptors import DictObsAdaptor, TerminalAdaptor
from .column import Column, TagColumn, VirtualColumn
Expand All @@ -25,8 +27,12 @@ def __init__(
columns: List[Column],
maxlen: int,
length_key="actions",
sampler: SampleStrategy = None,
device: torch.device,
):
if sampler is None:
sampler = UniformSampleStrategy()

adaptors = [DictObsAdaptor(obs_keys)]
if use_terminal_column:
columns.append(
Expand All @@ -41,7 +47,7 @@ def __init__(
super().__init__(
columns=columns,
maxlen=maxlen,
sampler=UniformSampleStrategy(),
sampler=sampler,
ejector=FifoEjectionStrategy(),
length_key=length_key,
adaptors=adaptors,
Expand All @@ -64,6 +70,7 @@ def __init__(
device: torch.device,
dones_dtype=bool,
masks_dtype=np.float32,
sampler: SampleStrategy = None,
):
if spaces.rewards is not None:
reward_column = Column(
Expand Down Expand Up @@ -113,11 +120,15 @@ def __init__(
]
)

if sampler is None:
sampler = UniformSampleStrategy()

super().__init__(
use_terminal_column=use_terminal_column,
maxlen=maxlen,
columns=columns,
obs_keys=obs_keys,
sampler=sampler,
device=device,
)

Expand All @@ -134,6 +145,7 @@ def __init__(
spaces: MDPSpace,
use_terminal_column: bool,
maxlen: int,
sampler: SampleStrategy = None,
device: torch.device,
):
if spaces.rewards is not None:
Expand Down Expand Up @@ -184,10 +196,14 @@ def __init__(
]
)

if sampler is None:
sampler = UniformSampleStrategy()

super().__init__(
use_terminal_column=use_terminal_column,
maxlen=maxlen,
columns=columns,
obs_keys=obs_keys,
sampler=sampler,
device=device,
)
91 changes: 91 additions & 0 deletions emote/memory/coverage_based_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
"""

import random

from typing import Sequence

import numpy as np

from .core_types import SamplePoint
from .strategy import EjectionStrategy, SampleStrategy, Strategy


class CoverageBasedStrategy(Strategy):
    """A sampler intended to sample based on coverage of experiences,
    favoring less-visited states. This base class can be used for implementing
    various coverage-based sampling strategies.

    Every tracked sequence gets the weight
    ``sequence_length / (times_sampled + 1) ** alpha``. The weights are
    normalised into a probability distribution which is cached and only
    recomputed after ``track`` or ``forget`` has been called.

    :param alpha: Exponent applied to the visit count. With ``0`` the visit
        count is ignored and the weight depends on sequence length only;
        larger values shift more probability towards less-sampled sequences.
    """

    def __init__(self, alpha=0.5):
        super().__init__()
        # identity -> sequence length
        self._identities = {}
        # identity -> number of times the sequence has been sampled
        self._sample_count = {}
        # Cached, index-aligned identities and probabilities; rebuilt lazily
        # by `_rebalance` whenever `_dirty` is set.
        self._ids = np.empty(0, dtype=np.int64)
        self._prios = np.empty(0, dtype=np.float64)
        self._dirty = False
        self._alpha = alpha

    def track(self, identity: int, sequence_length: int):
        """Start tracking `identity`, or update the length of a tracked one.

        An identity that is already known keeps its current sample count.
        """
        self._dirty = True
        self._identities[identity] = sequence_length
        self._sample_count.setdefault(identity, 0)

    def forget(self, identity: int):
        """Stop tracking `identity` and drop its sample count.

        :raises KeyError: if `identity` is not currently tracked.
        """
        self._dirty = True
        del self._identities[identity]
        del self._sample_count[identity]

    def _rebalance(self):
        """Recompute the cached identities and their sampling probabilities."""
        self._dirty = False

        identities = tuple(self._identities)
        self._ids = np.array(identities, dtype=np.int64)

        lengths = np.array(tuple(self._identities.values()), dtype=np.float64)
        visits = np.array(
            [self._sample_count[identity] for identity in identities],
            dtype=np.float64,
        )

        # Longer sequences weigh more, frequently sampled ones weigh less.
        # A single normalisation is enough to obtain a distribution.
        weights = lengths / (visits + 1.0) ** self._alpha
        self._prios = weights / weights.sum()


class CoverageBasedSampleStrategy(CoverageBasedStrategy, SampleStrategy):
    """Sample strategy that draws sequences using the coverage-based priorities."""

    def __init__(self, alpha=0.5):
        super().__init__(alpha=alpha)

    def sample(self, count: int, transition_count: int) -> Sequence[SamplePoint]:
        """Draw `count` sample points, each spanning `transition_count` steps.

        Every drawn sequence has its sample count incremented. The cached
        probabilities are refreshed first only if the set of tracked sequences
        changed since the last refresh.

        :returns: a list of ``(identity, start, end)`` tuples where
            ``end - start == transition_count``.
        """
        if self._dirty:
            self._rebalance()

        chosen = np.random.choice(self._ids, size=count, p=self._prios)
        lengths = self._identities
        visits = self._sample_count
        # Subtracted from the sequence length to bound the random start offset.
        slack = transition_count - 1

        points = []
        for identity in chosen:
            visits[identity] += 1
            start = int(random.random() * (lengths[identity] - slack))
            points.append((identity, start, start + transition_count))

        return points


class CoverageBasedEjectionStrategy(CoverageBasedStrategy, EjectionStrategy):
    """Ejection strategy that picks sequences using the coverage-based priorities.

    NOTE(review): the probabilities are the same ones used for sampling, so
    sequences that have been sampled less often are *more* likely to be
    chosen here - confirm that this is the intended ejection behaviour.
    """

    def sample(self, count: int) -> Sequence[int]:
        """Pick distinct identities whose combined length is at least `count`.

        Selection stops early once every tracked identity has been picked, so
        asking for more than the memory holds returns all identities instead
        of looping forever. With nothing tracked the result is empty.

        :param count: Number of transitions that should be covered by the
            returned sequences.
        :returns: a list of the selected identities.
        """
        if self._dirty:
            self._rebalance()

        available = len(self._ids)
        identities = set()
        while count > 0 and len(identities) < available:
            identity = np.random.choice(self._ids, size=1, p=self._prios)[0]

            # Draws are made with replacement; skip repeats.
            if identity in identities:
                continue

            count -= self._identities[identity]
            identities.add(identity)

        return list(identities)
97 changes: 97 additions & 0 deletions tests/test_memory_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""
Test to validate the behavior of `CoverageBasedSampleStrategy`. Tests how the `alpha` parameter influences the sampling distribution between two waves of data.
Wave 1 and Wave 2: Two separate sets of data points added to the memory. After each wave, a series of samples are drawn from the memory.
Alpha modulates how much the sampling prioritizes less-visited states. A higher alpha results in a stronger bias towards less-visited states.
Intended Behavior:
- alpha=0.0: Sampling should be approximately uniform, with no strong bias towards either wave.
- alpha=1.0: Sampling should strongly prioritize the less-visited states (i.e., states from Wave 2 after it is added).
- Intermediate alpha values (e.g., alpha=0.5) should result in intermediate behaviors.
"""

import numpy as np
import torch

from emote.memory.builder import DictObsNStepTable
from emote.memory.coverage_based_strategy import CoverageBasedSampleStrategy
from emote.utils.spaces import BoxSpace, DictSpace, MDPSpace


# Passed as `maxlen` to the table; also used to derive how many sequences go into each wave.
TABLE_MAX_LEN = 4096
# Number of `table.sample` calls issued after each wave.
SAMPLE_AMOUNT = 1024
# Alpha values the test iterates over; each gets a fresh table and sampler.
ALPHAS = [0.0, 0.5, 1.0]
# Number of actions/rewards in every generated sequence (observations get one more).
SEQUENCE_LEN = 10


def create_sample_space() -> MDPSpace:
    """Build the MDP space used by the test: scalar reward, scalar int action
    and a dict state holding a single 2-element ``"obs"`` entry."""
    return MDPSpace(
        rewards=BoxSpace(dtype=np.float32, shape=(1,)),
        actions=BoxSpace(dtype=np.int32, shape=(1,)),
        state=DictSpace(
            spaces={"obs": BoxSpace(dtype=np.float32, shape=(2,))},
        ),
    )


def populate_table(table: DictObsNStepTable, sequence_len: int, start: int, end: int):
    """Add one random sequence to `table` for every identity in ``[start, end)``.

    Each sequence holds ``sequence_len`` actions and rewards and
    ``sequence_len + 1`` observations.
    """

    def random_steps(width, steps):
        # One random vector of `width` elements per step.
        return [np.random.rand(width) for _ in range(steps)]

    for identity in range(start, end):
        table.add_sequence(
            identity=identity,
            sequence={
                "obs": random_steps(2, sequence_len + 1),
                "actions": random_steps(1, sequence_len),
                "rewards": random_steps(1, sequence_len),
            },
        )


def sample_table(table: DictObsNStepTable, sample_amount: int, count: int, sequence_length: int):
    """Call ``table.sample(count, sequence_length)`` `sample_amount` times.

    The sampled data is discarded; the calls are made for their effect on
    the sampler's bookkeeping.
    """
    for _draw in range(sample_amount):
        table.sample(count, sequence_length)


def test_memory_export():
    """Check that, after a second wave of data is added, the share of samples
    drawn from that wave exceeds an alpha-dependent lower bound."""
    # Minimum share of post-wave-2 samples that must come from wave 2.
    min_second_wave_share = {0.0: 0.4, 0.5: 0.6, 1.0: 0.8}

    cpu = torch.device("cpu")
    mdp_space = create_sample_space()
    wave_length = int(TABLE_MAX_LEN / (2 * SEQUENCE_LEN))
    second_wave_ids = range(wave_length, wave_length * 2)

    for alpha in ALPHAS:
        table = DictObsNStepTable(
            spaces=mdp_space,
            use_terminal_column=False,
            maxlen=TABLE_MAX_LEN,
            sampler=CoverageBasedSampleStrategy(alpha=alpha),
            device=cpu,
        )

        # Wave 1: fill half of the table and sample from it.
        populate_table(table=table, sequence_len=SEQUENCE_LEN, start=0, end=wave_length)
        sample_table(table=table, sample_amount=SAMPLE_AMOUNT, count=5, sequence_length=8)
        counts_before = table._sampler._sample_count.copy()

        # Wave 2: add the second half and sample again.
        populate_table(
            table=table, sequence_len=SEQUENCE_LEN, start=wave_length, end=wave_length * 2
        )
        sample_table(table=table, sample_amount=SAMPLE_AMOUNT, count=5, sequence_length=8)
        counts_after = table._sampler._sample_count

        def gained(identity):
            # Samples this identity received since wave 2 was added.
            return counts_after[identity] - counts_before.get(identity, 0)

        second_wave_samples = sum(gained(identity) for identity in second_wave_ids)
        total_new_samples = sum(gained(identity) for identity in counts_after.keys())

        proportion_second_wave = second_wave_samples / total_new_samples

        lower_bound = min_second_wave_share.get(alpha)
        if lower_bound is not None:
            assert proportion_second_wave > lower_bound

0 comments on commit 46f4165

Please sign in to comment.