From 3ce45d1a0616a8c7edd09662259060a744aef1d0 Mon Sep 17 00:00:00 2001
From: klashenriksson <6956114+klashenriksson@users.noreply.github.com>
Date: Tue, 13 Feb 2024 11:31:37 +0100
Subject: [PATCH] Add a joint memory loader capable of sampling data from
 multiple memory loaders (#181)

This PR adds a `JointMemoryLoader` object that can sample data from
multiple different `MemoryLoader`s. This is needed for a lot of
imitation-related training.
---
Review notes (text in this section is ignored by `git am`):
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; blank context lines are a best-effort guess.
- The contents differ from the original commit, so the commit and blob
  hashes in the headers are stale.
- `JointMemoryLoaderWithDataGroup.__iter__` now keeps yielding batches
  instead of stopping after the first one.

 emote/memory/memory.py       |  75 ++++++++++++++++++++++
 tests/test_memory_loading.py | 104 +++++++++++++++++++++++++++++++
 2 files changed, 179 insertions(+)
 create mode 100644 tests/test_memory_loading.py

diff --git a/emote/memory/memory.py b/emote/memory/memory.py
index f48312d5..94e74495 100644
--- a/emote/memory/memory.py
+++ b/emote/memory/memory.py
@@ -7,6 +7,7 @@
 """
 from __future__ import annotations
 
+import collections
 import inspect
 import logging
 import os
@@ -486,6 +487,80 @@ def __iter__(self):
             yield {self.data_group: data, self.size_key: data[self.size_key]}
 
 
+class JointMemoryLoader:
+    """A memory loader capable of loading data from multiple `MemoryLoader`s.
+
+    Every yielded batch maps each wrapped loader's data group to that loader's
+    data, and maps ``size_key`` to the sum of the individual loader sizes.
+
+    :param loaders: The loaders to sample from. Their data groups must be unique.
+    :param size_key: The key under which the summed size is stored.
+    :raises ValueError: If several loaders share the same data group.
+    """
+
+    def __init__(self, loaders: list[MemoryLoader], size_key: str = "batch_size"):
+        self._loaders = loaders
+        self._size_key = size_key
+
+        counts = collections.Counter(loader.data_group for loader in loaders)
+        counts_over_1 = {k: count for k, count in counts.items() if count > 1}
+        if counts_over_1:
+            raise ValueError(
+                "JointMemoryLoader was provided MemoryLoaders that share the same "
+                "datagroup. This will clobber the joint output data and is not "
+                "allowed. Here is a dict of each datagroup encountered more than "
+                f"once, and its occurrence count: {counts_over_1}"
+            )
+
+    def is_ready(self):
+        """Return True only when every wrapped loader reports that it is ready."""
+        return all(loader.is_ready() for loader in self._loaders)
+
+    def __iter__(self):
+        """Yield joint batches indefinitely.
+
+        :raises RuntimeError: When iteration starts while a loader is not ready.
+        """
+        if not self.is_ready():
+            raise RuntimeError(
+                "memory loader(s) in JointMemoryLoader does not have enough data. "
+                "Check `is_ready()` before trying to iterate over data."
+            )
+
+        while True:
+            out = {self._size_key: 0}
+
+            for loader in self._loaders:
+                data = next(iter(loader))
+                out[loader.data_group] = data[loader.data_group]
+                # for joint memory loaders we sum up all individual loader sizes
+                out[self._size_key] += data[loader.size_key]
+
+            yield out
+
+
+class JointMemoryLoaderWithDataGroup(JointMemoryLoader):
+    """A JointMemoryLoader that places its data inside of a user-specified datagroup.
+
+    :param loaders: The loaders to sample from. Their data groups must be unique.
+    :param data_group: The key under which the joint data dict is nested.
+    :param size_key: The key under which the summed size is stored.
+    """
+
+    def __init__(self, loaders: list[MemoryLoader], data_group: str, size_key: str = "batch_size"):
+        super().__init__(loaders, size_key)
+        self._data_group = data_group
+
+    def __iter__(self):
+        """Yield nested joint batches for as long as the parent iterator does."""
+        # Loop over the parent iterator rather than pulling a single batch, so
+        # that iterating this loader does not stop after the first batch.
+        for data in super().__iter__():
+            total_size = data.pop(self._size_key)
+
+            yield {self._data_group: data, self._size_key: total_size}
+
+
 class MemoryWarmup(Callback):
     """A blocker to ensure memory has data.
 
diff --git a/tests/test_memory_loading.py b/tests/test_memory_loading.py
new file mode 100644
index 00000000..6598eacc
--- /dev/null
+++ b/tests/test_memory_loading.py
@@ -0,0 +1,104 @@
+import numpy as np
+import pytest
+
+from emote.memory.column import Column
+from emote.memory.fifo_strategy import FifoEjectionStrategy
+from emote.memory.memory import JointMemoryLoader, JointMemoryLoaderWithDataGroup, MemoryLoader
+from emote.memory.table import ArrayTable
+from emote.memory.uniform_strategy import UniformSampleStrategy
+
+
+def _make_dummy_table() -> ArrayTable:
+    """Build a small table that holds a single sequence."""
+    tab = ArrayTable(
+        columns=[Column("state", (), np.float32), Column("action", (), np.float32)],
+        maxlen=1_000,
+        sampler=UniformSampleStrategy(),
+        ejector=FifoEjectionStrategy(),
+        length_key="action",
+        device="cpu",
+    )
+    tab.add_sequence(
+        0,
+        {
+            "state": [5.0, 6.0],
+            "action": [1.0],
+        },
+    )
+
+    return tab
+
+
+def _make_loader(table: ArrayTable, data_group: str) -> MemoryLoader:
+    """Build a loader with rollout_count=1 and rollout_length=1 for `data_group`."""
+    return MemoryLoader(
+        table=table,
+        rollout_count=1,
+        rollout_length=1,
+        size_key="batch_size",
+        data_group=data_group,
+    )
+
+
+@pytest.fixture
+def a_dummy_table():
+    return _make_dummy_table()
+
+
+@pytest.fixture
+def another_dummy_table():
+    return _make_dummy_table()
+
+
+def test_joint_memory_loader(a_dummy_table: ArrayTable, another_dummy_table: ArrayTable):
+    joint_loader = JointMemoryLoader(
+        loaders=[
+            _make_loader(a_dummy_table, "a"),
+            _make_loader(another_dummy_table, "another"),
+        ]
+    )
+
+    data = next(iter(joint_loader))
+    assert "a" in data and "another" in data, "JointMemoryLoader did not yield expected memory data"
+
+
+def test_joint_memory_loader_datagroup(a_dummy_table: ArrayTable, another_dummy_table: ArrayTable):
+    joint_loader = JointMemoryLoaderWithDataGroup(
+        loaders=[
+            _make_loader(a_dummy_table, "a"),
+            _make_loader(another_dummy_table, "another"),
+        ],
+        data_group="joint_datagroup",
+    )
+
+    encapsulated_data = next(iter(joint_loader))
+
+    assert (
+        "joint_datagroup" in encapsulated_data
+    ), "Expected joint dataloader to place data in its own datagroup, but it does not exist."
+
+    data = encapsulated_data["joint_datagroup"]
+    assert (
+        "a" in data and "another" in data
+    ), "Expected joint dataloader to actually place data in its datagroup, but it is empty."
+
+
+def test_joint_memory_loader_datagroup_yields_multiple_batches(
+    a_dummy_table: ArrayTable, another_dummy_table: ArrayTable
+):
+    joint_loader = JointMemoryLoaderWithDataGroup(
+        loaders=[
+            _make_loader(a_dummy_table, "a"),
+            _make_loader(another_dummy_table, "another"),
+        ],
+        data_group="joint_datagroup",
+    )
+
+    batches = iter(joint_loader)
+    for _ in range(2):
+        assert "joint_datagroup" in next(batches), "Expected a batch on every iteration step."
+
+
+def test_joint_memory_loader_nonunique_loaders_trigger_exception(a_dummy_table: ArrayTable):
+    with pytest.raises(ValueError, match="JointMemoryLoader"):
+        JointMemoryLoader([_make_loader(a_dummy_table, "a"), _make_loader(a_dummy_table, "a")])