Skip to content

Commit

Permalink
Merge pull request #199 from ARISE-Initiative/obs_register_fix
Browse files Browse the repository at this point in the history
Obs register fix
  • Loading branch information
danfeiX authored Oct 8, 2024
2 parents 29d6ca2 + 07dc68a commit 9273f9c
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
15 changes: 15 additions & 0 deletions robomimic/models/base_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torchvision import models as vision_models

import robomimic.utils.tensor_utils as TensorUtils
import robomimic.utils.obs_utils as ObsUtils


CONV_ACTIVATIONS = {
Expand Down Expand Up @@ -463,6 +464,20 @@ class ConvBase(Module):
def __init__(self):
super(ConvBase, self).__init__()

def __init_subclass__(cls, **kwargs):
"""
Hook method to automatically register all valid subclasses so we can keep track of valid observation encoders
in a global dict.
This global dict stores mapping from observation encoder network name to class.
We keep track of these registries to enable automated class inference at runtime, allowing
users to simply extend our base encoder class and refer to that class in string form
in their config, without having to manually register their class internally.
This also future-proofs us for any additional encoder classes we would
like to add ourselves.
"""
ObsUtils.register_encoder_backbone(cls)

# dirty hack - re-implement to pass the buck onto subclasses from ABC parent
def output_shape(self, input_shape):
"""
Expand Down
2 changes: 1 addition & 1 deletion robomimic/models/obs_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(

# extract only relevant kwargs for this specific backbone
backbone_kwargs = extract_class_init_kwargs_from_dict(
cls = ObsUtils.OBS_ENCODER_CORES[backbone_class],
cls = ObsUtils.OBS_ENCODER_BACKBONES[backbone_class],
dic=backbone_kwargs, copy=True)

# visual backbone
Expand Down
9 changes: 7 additions & 2 deletions robomimic/utils/obs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@
# in their config, without having to manually register their class internally.
# This also future-proofs us for any additional encoder / randomizer classes we would
# like to add ourselves.
OBS_ENCODER_CORES = {"None": None} # Include default None
OBS_RANDOMIZERS = {"None": None} # Include default None
OBS_ENCODER_CORES = {"None": None} # Per-modality core net as defined in obs_cores.py, e.g., "VisualCore"
OBS_RANDOMIZERS = {"None": None} # Obs randomizer defined in obs_cores.py, e.g., "CropRandomizer"
OBS_ENCODER_BACKBONES = {"None": None} # Architecture backbones for encoding obervation, e.g., "ResNet18Conv"


def register_obs_key(target_class):
Expand All @@ -59,6 +60,10 @@ def register_randomizer(target_class):
assert target_class not in OBS_RANDOMIZERS, f"Already registered obs randomizer {target_class}!"
OBS_RANDOMIZERS[target_class.__name__] = target_class

def register_encoder_backbone(target_class):
assert target_class not in OBS_ENCODER_BACKBONES, f"Already registered obs encoder backbone {target_class}!"
OBS_ENCODER_BACKBONES[target_class.__name__] = target_class


class ObservationKeyToModalityDict(dict):
"""
Expand Down

0 comments on commit 9273f9c

Please sign in to comment.