Skip to content

Commit

Permalink
Add NextNElementWrapper virtual storage (#198)
Browse files Browse the repository at this point in the history
This PR adds a small virtual storage wrapper that offsets every index by `N`,
so that sampling returns the element `N` steps ahead. This is needed for AMP
training, where the policy does not execute actions at the same frequency as
the imitation reference.
  • Loading branch information
klashenriksson authored Apr 12, 2024
1 parent b78dce7 commit 327ac2f
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
46 changes: 46 additions & 0 deletions emote/memory/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,52 @@ def with_only_last(storage, shape, dtype):
return NextElementMapper(storage, shape, dtype, only_last=True)


class NextNElementWrapper(VirtualStorage):
    """Virtual storage whose items are indexed ``N`` elements ahead.

    Indexing this storage returns a :class:`Wrapper` around the underlying
    item; every index subsequently applied to that wrapper is offset by
    ``N`` before being forwarded to the item.

    The offset is read from ``self._n``, which this base class does not set.
    Build a concrete class with :meth:`with_n` rather than instantiating this
    class directly.
    """

    class Wrapper:
        """View over a single item that shifts every index by ``n``."""

        def __init__(self, item, n: int):
            self._item = item
            self._n = n

        def _shift_slice(self, key: slice) -> slice:
            """Return ``key`` with its explicit bounds moved ``n`` ahead.

            Omitted bounds (``None``) cannot be shifted arithmetically, so
            they are mapped to the bound that is equivalent in the shifted
            view: a forward slice starts at ``n``; a reverse slice stops
            just before ``n``.
            """
            backwards = key.step is not None and key.step < 0

            if key.start is not None:
                start = key.start + self._n
            else:
                # Reverse slices implicitly start at the last element, which
                # is the same element in the shifted view.
                start = None if backwards else self._n

            if key.stop is not None:
                stop = key.stop + self._n
            else:
                # ``n - 1`` is an exclusive lower bound, so the reverse walk
                # ends on index ``n``. With n == 0 there is nothing to cut.
                stop = self._n - 1 if backwards and self._n > 0 else None

            return slice(start, stop, key.step)

        def __getitem__(self, key):
            """Index the wrapped item at ``key`` shifted by ``n``.

            Integers are shifted directly, every component of a tuple is
            shifted, and slices have their bounds shifted.

            Raises:
                ValueError: if ``key`` is not an int, tuple or slice.
            """
            if isinstance(key, int):
                key += self._n
            elif isinstance(key, tuple):
                key = tuple(k + self._n for k in key)
            elif isinstance(key, slice):
                key = self._shift_slice(key)
            else:
                raise ValueError(
                    f"Invalid indexing type '{type(key)}'. Only integer, tuple or slices supported."
                )

            return self._item[key]

        @property
        def shape(self):
            """Shape of the wrapped item, unchanged by the offset."""
            return self._item.shape

    def __init__(self, storage, shape, dtype):
        super().__init__(storage, shape, dtype)
        self._wrapper = NextNElementWrapper.Wrapper

    def __getitem__(self, key: int | Tuple[int, ...] | slice):
        # NOTE(review): ``self._storage`` is presumably set by
        # VirtualStorage.__init__, which is not visible here - confirm.
        return self._wrapper(self._storage[key], self._n)

    def sequence_length_transform(self, length):
        """Return ``length`` unchanged."""
        return length

    @staticmethod
    def with_n(n):
        """Return a subclass whose instances use ``n`` as their offset."""

        class NextNElementW(NextNElementWrapper):
            def __init__(self, storage, shape, dtype):
                super().__init__(storage, shape, dtype)
                self._n = n

        return NextNElementW


class SyntheticDones(VirtualStorage):
"""Generates done or masks based on sequence length."""

Expand Down
54 changes: 54 additions & 0 deletions tests/test_next_n_element_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import numpy as np
import pytest

from emote.memory.storage import NextNElementWrapper


@pytest.fixture
def storage() -> np.ndarray:
    """Two rows holding the integers 0..31 in order."""
    values = np.arange(32)
    return np.reshape(values, (2, -1))


@pytest.mark.parametrize(
    ("batch_dim", "n"),
    [(row, offset) for row in (0, 1) for offset in (1, 2, 3, 4)],
)
def test_next_n_element_single(batch_dim, n, storage):
    """Integer indices on the wrapper resolve to the element ``n`` ahead."""
    make_storage = NextNElementWrapper.with_n(n)
    shifted = make_storage(storage, (1,), np.float32)[batch_dim]
    reference = storage[batch_dim]

    # Fetch everything first, then compare, one position at a time.
    fetched = {position: shifted[position] for position in (0, 1, 5)}

    for position, value in fetched.items():
        assert value == reference[position + n]


@pytest.mark.parametrize(
    ("batch_dim", "n"),
    [(row, offset) for row in (0, 1) for offset in (1, 2, 3, 4)],
)
def test_next_n_element_slice(batch_dim, n, storage):
    """Bounded slices on the wrapper are moved ``n`` elements ahead."""
    make_storage = NextNElementWrapper.with_n(n)
    shifted = make_storage(storage, (1,), np.float32)[batch_dim]
    reference = storage[batch_dim]

    # (start, stop, step) triples; a step of None is a plain start:stop slice.
    windows = ((0, 2, None), (1, 4, None), (2, 5, 2))
    fetched = [shifted[lo:hi:stride] for lo, hi, stride in windows]

    for (lo, hi, stride), got in zip(windows, fetched):
        expected = reference[lo + n : hi + n : stride]
        assert np.all(got == expected)


@pytest.mark.parametrize(
    ("batch_dim", "n"),
    [(row, offset) for row in (0, 1) for offset in (1, 2)],
)
def test_next_n_element_tuple(batch_dim, n, storage):
    """Tuple indices have every component moved ``n`` elements ahead."""
    cube = np.reshape(storage, (2, 4, 4))

    make_storage = NextNElementWrapper.with_n(n)
    shifted = make_storage(cube, (4, 4), np.float32)[batch_dim]
    reference = cube[batch_dim]

    coordinates = ((0, 0), (1, 0), (1, 1))
    fetched = [shifted[coord] for coord in coordinates]

    for coord, got in zip(coordinates, fetched):
        moved = tuple(axis + n for axis in coord)
        assert np.all(got == reference[moved])

0 comments on commit 327ac2f

Please sign in to comment.