Skip to content

Commit

Permalink
Add a joint memory loader capable of sampling data from multiple memory loaders (#181)
Browse files Browse the repository at this point in the history

This PR adds a `JointMemoryLoader` object that can sample data from
multiple different `MemoryLoader`s. This is needed for a lot of
imitation-related training.
  • Loading branch information
klashenriksson authored Feb 13, 2024
1 parent 8cfdc55 commit 3ce45d1
Show file tree
Hide file tree
Showing 2 changed files with 175 additions and 0 deletions.
52 changes: 52 additions & 0 deletions emote/memory/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
from __future__ import annotations

import collections
import inspect
import logging
import os
Expand Down Expand Up @@ -486,6 +487,57 @@ def __iter__(self):
yield {self.data_group: data, self.size_key: data[self.size_key]}


class JointMemoryLoader:
    """A memory loader capable of loading data from multiple `MemoryLoader`s.

    Every yielded batch is a dict that holds each loader's data under that
    loader's own ``data_group`` key, plus ``size_key`` mapped to the sum of
    the individual loaders' batch sizes.

    :param loaders: The loaders to sample from. Their ``data_group`` values
        must be unique, since they are used as keys in the joint output.
    :param size_key: The key under which the summed batch size is stored.
    :raises ValueError: If two or more loaders share the same ``data_group``.
    """

    def __init__(self, loaders: list[MemoryLoader], size_key: str = "batch_size"):
        self._loaders = loaders
        self._size_key = size_key

        # Two loaders writing to the same key would silently overwrite each
        # other in the joint batch, so reject that configuration up front.
        counts = collections.Counter(loader.data_group for loader in loaders)
        counts_over_1 = {k: count for k, count in counts.items() if count > 1}
        if counts_over_1:
            raise ValueError(
                f"""JointMemoryLoader was provided MemoryLoaders that share the same datagroup. This will clobber the joint output data and is not allowed.
Here is a dict of each datagroup encountered more than once, and its occurance count: {counts_over_1}"""
            )

    def is_ready(self):
        """Return True only when every wrapped loader reports it is ready."""
        return all(loader.is_ready() for loader in self._loaders)

    def __iter__(self):
        """Yield joint batches indefinitely.

        :raises RuntimeError: On the first ``next()`` if any wrapped loader
            is not ready.
        """
        if not self.is_ready():
            raise RuntimeError(
                """memory loader(s) in JointMemoryLoader does not have enough data. Check `is_ready()`
before trying to iterate over data."""
            )

        # Create each loader's iterator once and keep advancing it, instead of
        # building (and throwing away) a fresh iterator for every batch.
        iterators = [(loader, iter(loader)) for loader in self._loaders]

        while True:
            out = {self._size_key: 0}

            for loader, iterator in iterators:
                data = next(iterator)
                out[loader.data_group] = data[loader.data_group]
                # for joint memory loaders we sum up all individual loader sizes
                out[self._size_key] += data[loader.size_key]

            yield out


class JointMemoryLoaderWithDataGroup(JointMemoryLoader):
    """A JointMemoryLoader that places its data inside of a user-specified datagroup.

    Each yielded batch has the shape
    ``{data_group: {<loader data groups>...}, size_key: <summed size>}``.

    :param loaders: The loaders to sample from; forwarded to `JointMemoryLoader`.
    :param data_group: The key under which the joint data is nested.
    :param size_key: The key under which the summed batch size is stored.
    """

    def __init__(self, loaders: list[MemoryLoader], data_group: str, size_key: str = "batch_size"):
        super().__init__(loaders, size_key)
        self._data_group = data_group

    def __iter__(self):
        """Yield nested joint batches for as long as the parent yields them."""
        # Loop over the parent stream rather than taking a single item from
        # it; otherwise this iterator is exhausted after one batch and a
        # plain `for batch in loader` loop ends immediately.
        for data in super().__iter__():
            # Lift the size out of the joint dict so it sits next to the
            # data group instead of inside it.
            total_size = data.pop(self._size_key)

            yield {self._data_group: data, self._size_key: total_size}


class MemoryWarmup(Callback):
"""A blocker to ensure memory has data.
Expand Down
123 changes: 123 additions & 0 deletions tests/test_memory_loading.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import numpy as np
import pytest

from emote.memory.column import Column
from emote.memory.fifo_strategy import FifoEjectionStrategy
from emote.memory.memory import JointMemoryLoader, JointMemoryLoaderWithDataGroup, MemoryLoader
from emote.memory.table import ArrayTable
from emote.memory.uniform_strategy import UniformSampleStrategy


@pytest.fixture
def a_dummy_table():
    """Provide a CPU `ArrayTable` with scalar float32 ``state`` and ``action``
    columns, pre-filled with a single sequence."""
    columns = [
        Column("state", (), np.float32),
        Column("action", (), np.float32),
    ]
    table = ArrayTable(
        columns=columns,
        maxlen=1_000,
        sampler=UniformSampleStrategy(),
        ejector=FifoEjectionStrategy(),
        length_key="action",
        device="cpu",
    )

    # One sequence is enough for the loaders in these tests to sample from.
    sequence = {"state": [5.0, 6.0], "action": [1.0]}
    table.add_sequence(0, sequence)

    return table


@pytest.fixture
def another_dummy_table():
    """Provide a second, independent `ArrayTable` with the same layout and
    contents as the first dummy table fixture's: scalar float32 ``state`` and
    ``action`` columns and one stored sequence."""
    schema = [Column(name, (), np.float32) for name in ("state", "action")]
    table = ArrayTable(
        columns=schema,
        maxlen=1_000,
        sampler=UniformSampleStrategy(),
        ejector=FifoEjectionStrategy(),
        length_key="action",
        device="cpu",
    )

    table.add_sequence(0, {"state": [5.0, 6.0], "action": [1.0]})

    return table


def test_joint_memory_loader(a_dummy_table: ArrayTable, another_dummy_table: ArrayTable):
    """A joint batch must contain the data group of every wrapped loader."""
    tables_by_group = (("a", a_dummy_table), ("another", another_dummy_table))
    loaders = [
        MemoryLoader(
            table=table,
            rollout_count=1,
            rollout_length=1,
            size_key="batch_size",
            data_group=group,
        )
        for group, table in tables_by_group
    ]

    batch = next(iter(JointMemoryLoader(loaders=loaders)))

    assert "a" in batch and "another" in batch, "JointMemoryLoader did not yield expected memory data"


def test_joint_memory_loader_datagroup(a_dummy_table: ArrayTable, another_dummy_table: ArrayTable):
    """The data-group variant must nest every loader's data under its own
    data group key."""
    a_loader = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )
    another_loader = MemoryLoader(
        table=another_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="another",
    )

    joint_loader = JointMemoryLoaderWithDataGroup(
        loaders=[a_loader, another_loader], data_group="joint_datagroup"
    )

    encapsulated_data = next(iter(joint_loader))

    # Check that the key exists before indexing with it, so a missing data
    # group fails with the message below rather than a bare KeyError.
    assert (
        "joint_datagroup" in encapsulated_data
    ), "Expected joint dataloader to place data in its own datagroup, but it does not exist."

    data = encapsulated_data["joint_datagroup"]
    assert (
        "a" in data and "another" in data
    ), "Expected joint dataloader to actually place data in its datagroup, but it is empty."


def test_joint_memory_loader_nonunique_loaders_trigger_exception(a_dummy_table: ArrayTable):
    """Constructing a joint loader from loaders that share a data group must
    be rejected."""
    loader1 = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )
    loader2 = MemoryLoader(
        table=a_dummy_table,
        rollout_count=1,
        rollout_length=1,
        size_key="batch_size",
        data_group="a",
    )

    # Pin the specific exception type the constructor raises, so an unrelated
    # error whose message happens to mention the class name cannot pass this test.
    with pytest.raises(ValueError, match="JointMemoryLoader"):
        JointMemoryLoader([loader1, loader2])

0 comments on commit 3ce45d1

Please sign in to comment.