From 1f7f524e8eeb723733745aff383855526c426223 Mon Sep 17 00:00:00 2001 From: Andi Peng Date: Thu, 16 May 2024 18:22:07 -0400 Subject: [PATCH] first commit --- predicators/envs/spot_env.py | 53 +++++++++++++++++++ .../ground_truth_models/spot_env/nsrts.py | 2 +- .../ground_truth_models/spot_env/options.py | 1 + predicators/perception/spot_perceiver.py | 29 ++++++---- .../graph_nav_maps/b45-621/metadata.yaml | 12 ++++- predicators/utils.py | 1 + 6 files changed, 84 insertions(+), 14 deletions(-) diff --git a/predicators/envs/spot_env.py b/predicators/envs/spot_env.py index 60ab8e33d2..57b14d1f8c 100644 --- a/predicators/envs/spot_env.py +++ b/predicators/envs/spot_env.py @@ -3079,3 +3079,56 @@ def _generate_goal_description(self) -> GoalDescription: def _get_dry_task(self, train_or_test: str, task_idx: int) -> EnvironmentTask: raise NotImplementedError("Dry task generation not implemented.") + + +############################################################################### +# Test plant demo # +############################################################################### + + +class TestPlantEnv(SpotRearrangementEnv): + """TODO; basic demo + """ + + def __init__(self, use_gui: bool = True) -> None: + super().__init__(use_gui) + + op_to_name = {o.name: o for o in _create_operators()} + op_names_to_keep = { + "MoveToReachObject", + "MoveToHandViewObject", + "PickObjectFromTop", + "PlaceObjectOnTop", + } + self._strips_operators = {op_to_name[o] for o in op_names_to_keep} + + @classmethod + def get_name(cls) -> str: + return "plant_test_env" + + @property + def _detection_id_to_obj(self) -> Dict[ObjectDetectionID, Object]: + + detection_id_to_obj: Dict[ObjectDetectionID, Object] = {} + + green_apple = Object("green_apple", _movable_object_type) + green_apple_detection = LanguageObjectDetectionID( + "green apple/tennis ball") + detection_id_to_obj[green_apple_detection] = green_apple + plant = Object("plant", _immovable_object_type) + plant_detection = LanguageObjectDetectionID( + "potted plant") + detection_id_to_obj[plant_detection] = plant + + for obj, pose in get_known_immovable_objects().items(): + detection_id = KnownStaticObjectDetectionID(obj.name, pose) + detection_id_to_obj[detection_id] = obj + + return detection_id_to_obj + + def _generate_goal_description(self) -> GoalDescription: + return "place the green apple on the plant" + + def _get_dry_task(self, train_or_test: str, + task_idx: int) -> EnvironmentTask: + raise NotImplementedError("Dry task generation not implemented.") diff --git a/predicators/ground_truth_models/spot_env/nsrts.py b/predicators/ground_truth_models/spot_env/nsrts.py index 8ab36470a6..245a8066e6 100644 --- a/predicators/ground_truth_models/spot_env/nsrts.py +++ b/predicators/ground_truth_models/spot_env/nsrts.py @@ -288,7 +288,7 @@ def get_env_names(cls) -> Set[str]: "spot_cube_env", "spot_soda_floor_env", "spot_soda_table_env", "spot_soda_bucket_env", "spot_soda_chair_env", "spot_main_sweep_env", "spot_ball_and_cup_sticky_table_env", - "spot_brush_shelf_env", "lis_spot_block_floor_env" + "spot_brush_shelf_env", "lis_spot_block_floor_env", "plant_test_env" } @staticmethod diff --git a/predicators/ground_truth_models/spot_env/options.py b/predicators/ground_truth_models/spot_env/options.py index 4729a50493..1ae19b48dc 100644 --- a/predicators/ground_truth_models/spot_env/options.py +++ b/predicators/ground_truth_models/spot_env/options.py @@ -996,6 +996,7 @@ def get_env_names(cls) -> Set[str]: "spot_ball_and_cup_sticky_table_env", "spot_brush_shelf_env", "lis_spot_block_floor_env", + "plant_test_env" } @classmethod diff --git a/predicators/perception/spot_perceiver.py b/predicators/perception/spot_perceiver.py index d39107a060..a462776c81 100644 --- a/predicators/perception/spot_perceiver.py +++ b/predicators/perception/spot_perceiver.py @@ -269,17 +269,17 @@ def _create_state(self) -> State: } # Uncomment for debugging. - # logging.info("Percept state:") - # logging.info(percept_state.pretty_str()) - # logging.info("Percept atoms:") - # atom_str = "\n".join( - # map( - # str, - # sorted(utils.abstract(percept_state, - # self._percept_predicates)))) - # logging.info(atom_str) - # logging.info("Simulator state:") - # logging.info(simulator_state) + logging.info("Percept state:") + logging.info(percept_state.pretty_str()) + logging.info("Percept atoms:") + atom_str = "\n".join( + map( + str, + sorted(utils.abstract(percept_state, + self._percept_predicates)))) + logging.info(atom_str) + logging.info("Simulator state:") + logging.info(simulator_state) # Now finish the state. state = _PartialPerceptionState(percept_state.data, @@ -500,6 +500,13 @@ def _create_goal(self, state: State, GroundAtom(ContainerReadyForSweeping, [bucket, black_table]), GroundAtom(IsSweeper, [brush]) } + if goal_description == "place the green apple on the plant": + plant = Object("plant", _immovable_object_type) + apple = Object("green_apple", _movable_object_type) + On = pred_name_to_pred["On"] + return { + GroundAtom(On, [apple, plant]), + } raise NotImplementedError("Unrecognized goal description") def render_mental_images(self, observation: Observation, diff --git a/predicators/spot_utils/graph_nav_maps/b45-621/metadata.yaml b/predicators/spot_utils/graph_nav_maps/b45-621/metadata.yaml index cb37e2d951..771bd8e237 100644 --- a/predicators/spot_utils/graph_nav_maps/b45-621/metadata.yaml +++ b/predicators/spot_utils/graph_nav_maps/b45-621/metadata.yaml @@ -40,10 +40,18 @@ static-object-features: length: 10000000 # effectively infinite width: 10000000 flat_top_surface: 1 - red_block: + green_apple: shape: 2 height: 0.1 length: 0.1 width: 0.1 placeable: 1 - is_sweeper: 0 \ No newline at end of file + is_sweeper: 0 + plant: + shape: 2 + height: 0.25 + length: 0.25 + width: 0.25 + placeable: 1 + is_sweeper: 0 + flat_top_surface: 1 \ No newline at end of file diff --git a/predicators/utils.py b/predicators/utils.py index 7293848f78..ca2260edf2 100644 --- a/predicators/utils.py +++ b/predicators/utils.py @@ -526,6 +526,7 @@ def sample_random_point( self, rng: np.random.Generator, min_dist_from_edge: float = 0.0) -> Tuple[float, float]: + import ipdb; ipdb.set_trace() assert min_dist_from_edge < self.radius, "min_dist_from_edge is " + \ "greater than radius" rand_mag = rng.uniform(0, self.radius - min_dist_from_edge)