From cf9a52bfdb2ff8d868c873d6d0c92f026d553df2 Mon Sep 17 00:00:00 2001 From: yichao-liang Date: Wed, 1 Jan 2025 19:30:07 +0000 Subject: [PATCH] can make circuit demos --- .../approaches/random_actions_approach.py | 1 - .../envs/assets/urdf/bulb_box_snap.urdf | 6 +- .../envs/assets/urdf/snap_connector4.urdf | 14 +- predicators/envs/pybullet_circuit.py | 54 +++- predicators/envs/pybullet_grow.py | 8 +- .../ground_truth_models/circuit/nsrts.py | 34 +-- .../ground_truth_models/circuit/options.py | 230 +++++++++++++++++- 7 files changed, 301 insertions(+), 46 deletions(-) diff --git a/predicators/approaches/random_actions_approach.py b/predicators/approaches/random_actions_approach.py index 20459d5b73..e129a0f885 100644 --- a/predicators/approaches/random_actions_approach.py +++ b/predicators/approaches/random_actions_approach.py @@ -20,7 +20,6 @@ def is_learning_based(self) -> bool: def _solve(self, task: Task, timeout: int) -> Callable[[State], Action]: def _policy(_: State) -> Action: - breakpoint() return Action(self._action_space.sample()) return _policy diff --git a/predicators/envs/assets/urdf/bulb_box_snap.urdf b/predicators/envs/assets/urdf/bulb_box_snap.urdf index e5d6c9411f..f6d457890f 100644 --- a/predicators/envs/assets/urdf/bulb_box_snap.urdf +++ b/predicators/envs/assets/urdf/bulb_box_snap.urdf @@ -1535,21 +1535,21 @@ - + - + - + diff --git a/predicators/envs/assets/urdf/snap_connector4.urdf b/predicators/envs/assets/urdf/snap_connector4.urdf index ac46bae7cc..6690b5af5d 100644 --- a/predicators/envs/assets/urdf/snap_connector4.urdf +++ b/predicators/envs/assets/urdf/snap_connector4.urdf @@ -6,7 +6,7 @@ - + @@ -15,7 +15,7 @@ - + @@ -64,7 +64,7 @@ - + @@ -73,7 +73,7 @@ - + @@ -84,21 +84,21 @@ - + - + - + \ No newline at end of file diff --git a/predicators/envs/pybullet_circuit.py b/predicators/envs/pybullet_circuit.py index 3eea9e770c..4eeba7f4c4 100644 --- a/predicators/envs/pybullet_circuit.py +++ b/predicators/envs/pybullet_circuit.py @@ -3,6 +3,11 @@ In the simplest case, the lightbulb is automatically turned on when the light is connected to both the positive and negative terminals of the battery. The lightbulb and the battery are fixed, the wire is moveable. + +python predicators/main.py --approach oracle --env pybullet_circuit \ +--seed 0 --num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \ +--sesame_max_skeletons_optimized 1 --make_failure_videos --video_fps 20 \ +--pybullet_camera_height 900 --pybullet_camera_width 900 --debug """ import logging @@ -55,6 +60,7 @@ class PyBulletCircuitEnv(PyBulletEnv): [0.0, 0.0, np.pi / 2]) robot_init_tilt: ClassVar[float] = np.pi / 2 robot_init_wrist: ClassVar[float] = -np.pi / 2 + max_angular_vel: ClassVar[float] = np.pi / 4 # Hard-coded finger states for open/close open_fingers: ClassVar[float] = 0.4 @@ -64,7 +70,7 @@ class PyBulletCircuitEnv(PyBulletEnv): _bulb_on_color: ClassVar[Tuple[float, float, float, float]] = (1.0, 1.0, 0.0, 1.0) # yellow _bulb_off_color: ClassVar[Tuple[float, float, float, - float]] = (1.0, 1.0, 1.0, 1.0) # white + float]] = (0.8, 0.8, 0.8, 1.0) # white # Connector dimensions snap_width: ClassVar[float] = 0.05 @@ -76,13 +82,13 @@ class PyBulletCircuitEnv(PyBulletEnv): # Camera parameters _camera_distance: ClassVar[float] = 1.3 _camera_yaw: ClassVar[float] = 70 - _camera_pitch: ClassVar[float] = -38 + _camera_pitch: ClassVar[float] = -50 _camera_target: ClassVar[Pose3D] = (0.75, 1.25, 0.42) # --- CHANGED / ADDED --- # Added "rot" to both the battery and light types. _robot_type = Type("robot", ["x", "y", "z", "fingers", "tilt", "wrist"]) - _wire_type = Type("wire", ["x", "y", "z", "rot"]) + _wire_type = Type("wire", ["x", "y", "z", "rot", "is_held"]) _battery_type = Type("battery", ["x", "y", "z", "rot"]) _light_type = Type("light", ["x", "y", "z", "rot", "is_on"]) @@ -104,6 +110,11 @@ def __init__(self, use_gui: bool = True) -> None: # self._Connected = Predicate("Connected", # [self._light_type, self._battery_type], # self._Connected_holds) + self._Holding = Predicate("Holding", + [self._robot_type, self._wire_type], + self._Holding_holds) + self._HandEmpty = Predicate("HandEmpty", [self._robot_type], + self._HandEmpty_holds) self._ConnectedToLight = Predicate("ConnectedToLight", [self._wire_type, self._light_type], self._ConnectedToLight_holds) @@ -117,7 +128,8 @@ def __init__(self, use_gui: bool = True) -> None: # connected to the battery. # Normal version used in the simulator - self._CircuitClosed = Predicate("CircuitClosed", [], + self._CircuitClosed = Predicate("CircuitClosed", + [self._light_type, self._battery_type], self._CircuitClosed_holds) # self._CircuitClosed_abs = ConceptPredicate("CircuitClosed", # [self._wire_type, self._wire_type], @@ -134,6 +146,8 @@ def predicates(self) -> Set[Predicate]: return { # If you want to define self._Connected, re-add it here # self._Connected, + self._Holding, + self._HandEmpty, self._LightOn, self._ConnectedToLight, self._ConnectedToBattery, @@ -265,11 +279,13 @@ def _get_state(self) -> State: for wire_obj in [self._wire1, self._wire2]: (wx, wy, wz), orn = p.getBasePositionAndOrientation( wire_obj.id, physicsClientId=self._physics_client_id) + is_held_val = 1.0 if wire_obj.id == self._held_obj_id else 0.0 state_dict[wire_obj] = { "x": wx, "y": wy, "z": wz, "rot": p.getEulerFromQuaternion(orn)[2], + "is_held": is_held_val } # Convert dictionary to a PyBulletState @@ -319,6 +335,9 @@ def _reset_state(self, state: State) -> None: position=(wx, wy, wz), orientation=p.getQuaternionFromEuler([0, 0, rot]), physics_client_id=self._physics_client_id) + if state.get(wire_obj, "is_held") > 0.5: + self._attach(wire_obj.id, self._pybullet_robot) + self._held_obj_id = wire_obj.id # Check if re-creation matches reconstructed_state = self._get_state() @@ -344,6 +363,16 @@ def step(self, action: Action, render_obs: bool = False) -> State: # ------------------------------------------------------------------------- # Predicates + @staticmethod + def _Holding_holds(state: State, objects: Sequence[Object]) -> bool: + _, wire = objects + return state.get(wire, "is_held") > 0.5 + + @staticmethod + def _HandEmpty_holds(state: State, objects: Sequence[Object]) -> bool: + robot, = objects + return state.get(robot, "fingers") > 0.2 + @staticmethod def _ConnectedToLight_holds(state: State, objects: Sequence[Object]) -> bool: @@ -366,7 +395,7 @@ def _ConnectedToLight_holds(state: State, return False # Correct x and y differences for connection - target_x_diff = PyBulletCircuitEnv.bulb_snap_length / 2 - \ + target_x_diff = PyBulletCircuitEnv.wire_snap_length / 2 - \ PyBulletCircuitEnv.snap_width / 2 target_y_diff = PyBulletCircuitEnv.bulb_snap_length / 2 + \ PyBulletCircuitEnv.snap_width / 2 @@ -443,7 +472,7 @@ def _turn_bulb_on(self) -> None: if self._light.id is not None: p.changeVisualShape( self._light.id, - -1, # all link indices + 3, # all link indices rgbaColor=self._bulb_on_color, physicsClientId=self._physics_client_id) @@ -451,7 +480,7 @@ def _turn_bulb_off(self) -> None: if self._light.id is not None: p.changeVisualShape( self._light.id, - -1, # all link indices + 3, # all link indices rgbaColor=self._bulb_off_color, physicsClientId=self._physics_client_id) @@ -484,7 +513,7 @@ def _make_tasks(self, num_tasks: int, # For randomization, tweak or keep rot=0.0 as needed battery_dict = { "x": battery_x, - "y": 1.35, + "y": 1.3, "z": self.z_lb + self.snap_height / 2, "rot": np.pi / 2, } @@ -495,12 +524,14 @@ def _make_tasks(self, num_tasks: int, "y": 1.15, # lower region "z": self.z_lb + self.snap_height / 2, "rot": 0.0, + "is_held": 0.0, } wire2_dict = { "x": 0.75, - "y": 1.55, # upper region + "y": self.y_ub - self.init_padding * 3, # upper region "z": self.z_lb + self.snap_height / 2, "rot": 0.0, + "is_held": 0.0, } # Light near upper region @@ -508,7 +539,7 @@ def _make_tasks(self, num_tasks: int, # For randomization, tweak or keep rot=0.0 as needed light_dict = { "x": bulb_x, - "y": 1.35, + "y": 1.3, "z": self.z_lb + self.snap_height / 2, "rot": -np.pi / 2, "is_on": 0.0, @@ -525,7 +556,8 @@ def _make_tasks(self, num_tasks: int, # The goal can be that the light is on. goal_atoms = { - GroundAtom(self._LightOn, [self._light]), + # GroundAtom(self._LightOn, [self._light]), + GroundAtom(self._CircuitClosed, [self._light, self._battery]), } tasks.append(EnvironmentTask(init_state, goal_atoms)) diff --git a/predicators/envs/pybullet_grow.py b/predicators/envs/pybullet_grow.py index 9da9e812a1..2da0d3b135 100644 --- a/predicators/envs/pybullet_grow.py +++ b/predicators/envs/pybullet_grow.py @@ -1,5 +1,5 @@ -"""python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \ - +""" +python predicators/main.py --approach oracle --env pybullet_grow --seed 1 \ --num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \ --sesame_max_skeletons_optimized 1 --make_failure_videos --video_fps 20 \ --pybullet_camera_height 900 --pybullet_camera_width 900 @@ -17,8 +17,8 @@ from predicators.pybullet_helpers.objects import create_object, update_object from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.settings import CFG -from predicators.structs import Action, EnvironmentTask, GroundAtom, Object, \ - Predicate, State, Type +from predicators.structs import Action, EnvironmentTask, GroundAtom, \ + Object, Predicate, State, Type class PyBulletGrowEnv(PyBulletEnv): diff --git a/predicators/ground_truth_models/circuit/nsrts.py b/predicators/ground_truth_models/circuit/nsrts.py index 847011cf8b..be68f8c728 100644 --- a/predicators/ground_truth_models/circuit/nsrts.py +++ b/predicators/ground_truth_models/circuit/nsrts.py @@ -34,37 +34,38 @@ def get_nsrts(env_name: str, types: Dict[str, Type], CircuitClosed = predicates["CircuitClosed"] # Options - PickConnector = options["PickConnector"] + Pick = options["PickWire"] Connect = options["Connect"] nsrts = set() - # PickConnector + # PickWire robot = Variable("?robot", robot_type) - connector = Variable("?connector", wire_type) - parameters = [robot, connector] - option_vars = [robot, connector] - option = PickConnector + wire = Variable("?wire", wire_type) + parameters = [robot, wire] + option_vars = [robot, wire] + option = Pick preconditions = { LiftedAtom(HandEmpty, [robot]), } add_effects = { - LiftedAtom(Holding, [robot, connector]), + LiftedAtom(Holding, [robot, wire]), } delete_effects = { LiftedAtom(HandEmpty, [robot]), } - pick_connector_nsrt = NSRT("PickConnector", parameters, + pick_wire_nsrt = NSRT("PickWire", parameters, preconditions, add_effects, delete_effects, set(), option, option_vars, null_sampler) - nsrts.add(pick_connector_nsrt) + nsrts.add(pick_wire_nsrt) # ConnectFirstWire. Connect first wire to light and battery. + robot = Variable("?robot", robot_type) wire = Variable("?wire", wire_type) light = Variable("?light", light_type) battery = Variable("?battery", battery_type) - parameters = [wire, light, battery] - option_vars = [wire, light, battery] + parameters = [robot, wire, light, battery] + option_vars = [robot, wire, light, battery] option = Connect preconditions = { LiftedAtom(Holding, [robot, wire]), @@ -72,6 +73,7 @@ def get_nsrts(env_name: str, types: Dict[str, Type], # close enough } add_effects = { + LiftedAtom(HandEmpty, [robot]), LiftedAtom(ConnectedToLight, [wire, light]), LiftedAtom(ConnectedToBattery, [wire, battery]), } @@ -85,18 +87,20 @@ def get_nsrts(env_name: str, types: Dict[str, Type], nsrts.add(connect_first_wire_nsrt) # hacky: connect second wire to light and power + robot = Variable("?robot", robot_type) wire = Variable("?wire", wire_type) light = Variable("?light", light_type) battery = Variable("?battery", battery_type) - parameters = [wire, light, battery] - option_vars = [wire, light, battery] + parameters = [robot, wire, light, battery] + option_vars = [robot, wire, light, battery] option = Connect preconditions = { LiftedAtom(Holding, [robot, wire]), } add_effects = { - LiftedAtom(CircuitClosed, []), - LiftedAtom(LightOn, [light]), + LiftedAtom(HandEmpty, [robot]), + LiftedAtom(CircuitClosed, [light, battery]), + # LiftedAtom(LightOn, [light]), } delete_effects = { LiftedAtom(Holding, [robot, wire]), diff --git a/predicators/ground_truth_models/circuit/options.py b/predicators/ground_truth_models/circuit/options.py index bdb0b6b3ae..088cfa7ea9 100644 --- a/predicators/ground_truth_models/circuit/options.py +++ b/predicators/ground_truth_models/circuit/options.py @@ -1,5 +1,6 @@ """Ground-truth options for the coffee environment.""" +import logging from functools import lru_cache from typing import ClassVar, Dict, Sequence, Set from typing import Type as TypingType @@ -8,13 +9,14 @@ from gym.spaces import Box from predicators.envs.pybullet_coffee import PyBulletCoffeeEnv -from predicators.envs.pybullet_grow import PyBulletGrowEnv +from predicators.envs.pybullet_circuit import PyBulletCircuitEnv from predicators.ground_truth_models import GroundTruthOptionFactory from predicators.ground_truth_models.coffee.options import \ PyBulletCoffeeGroundTruthOptionFactory from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot from predicators.structs import Action, Array, Object, ParameterizedOption, \ ParameterizedPolicy, Predicate, State, Type +from predicators import utils @lru_cache @@ -27,10 +29,9 @@ def _get_pybullet_robot() -> SingleArmPyBulletRobot: class PyBulletCircuitGroundTruthOptionFactory(GroundTruthOptionFactory): """Ground-truth options for the grow environment.""" - env_cls: ClassVar[TypingType[PyBulletGrowEnv]] = PyBulletGrowEnv + env_cls: ClassVar[TypingType[PyBulletCircuitEnv]] = PyBulletCircuitEnv pick_policy_tol: ClassVar[float] = 1e-3 - pour_policy_tol: ClassVar[float] = 1e-3 - _finger_action_nudge_magnitude: ClassVar[float] = 1e-3 + place_policy_tol: ClassVar[float] = 1e-4 @classmethod def get_env_names(cls) -> Set[str]: @@ -40,5 +41,224 @@ def get_env_names(cls) -> Set[str]: def get_options(cls, env_name: str, types: Dict[str, Type], predicates: Dict[str, Predicate], action_space: Box) -> Set[ParameterizedOption]: + # Types + robot_type = types["robot"] + wire_type = types["wire"] + light_type = types["light"] + battery_type = types["battery"] - return set() \ No newline at end of file + # Predicates + Holding = predicates["Holding"] + ConnectedToLight = predicates["ConnectedToLight"] + ConnectedToBattery = predicates["ConnectedToBattery"] + + options = set() + # PickWire + def _PickWire_terminal(state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> bool: + del memory, params + robot, wire = objects + holds = Holding.holds(state, [robot, wire]) + return holds + + PickWire = ParameterizedOption( + "PickWire", + types=[robot_type, wire_type], + params_space=Box(0, 1, (0, )), + policy=cls._create_pick_wire_policy(), + initiable=lambda s, m, o, p: True, + terminal=_PickWire_terminal) + + def _RestoreForPickWire_terminal(state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> bool: + del memory, params + robot, wire = objects + robot_pos = (state.get(robot, "x"), state.get(robot, "y"), + state.get(robot, "z")) + wx = state.get(wire, "x") + wy = state.get(wire, "y") + target_pos = (wx, wy, cls.env_cls.robot_init_z) + return bool(np.allclose(robot_pos, target_pos, atol=1e-1)) + + RestoreForPickWire = ParameterizedOption( + "RestoreForPickWire", + types=[robot_type, wire_type], + params_space=Box(0, 1, (0, )), + policy=cls._create_move_to_above_position_policy(), + initiable=lambda s, m, o, p: True, + terminal=_RestoreForPickWire_terminal) + PickWire = utils.LinearChainParameterizedOption( + "PickWire", [ + # RestoreForPickWire, + PickWire]) + options.add(PickWire) + + # Connect + def _Connect_terminal(state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> bool: + del memory, params + robot, wire, light, battery = objects + # connected_to_light = ConnectedToLight.holds(state, [wire, light]) + # connected_to_battery = ConnectedToBattery.holds(state, [wire, + # battery]) + # return connected_to_light and connected_to_battery + return not Holding.holds(state, [robot, wire]) + + Connect = ParameterizedOption( + "Connect", + types=[robot_type, wire_type, light_type, battery_type], + params_space=Box(0, 1, (0, )), + policy=cls._create_connect_policy(), + initiable=lambda s, m, o, p: True, + terminal=_Connect_terminal) + + def _RestoreForConnect_terminal(state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> bool: + del memory, params + robot, wire, light, battery = objects + robot_pos = (state.get(robot, "x"), state.get(robot, "y"), + state.get(robot, "z")) + robot_init_pos = (robot_pos[0], + robot_pos[1], + cls.env_cls.robot_init_z) + return bool(np.allclose(robot_pos, robot_init_pos, atol=1e-1)) + + RestoreForConnect = ParameterizedOption( + "RestRestoreForConnecoreForPickWire", + types=[robot_type, wire_type, light_type, battery_type], + params_space=Box(0, 1, (0, )), + policy=cls._create_move_to_above_position_policy(), + initiable=lambda s, m, o, p: True, + terminal=_RestoreForConnect_terminal) + + Connect = utils.LinearChainParameterizedOption( + "Connect", [Connect, + RestoreForConnect + ]) + options.add(Connect) + + return options + @classmethod + def _create_move_to_above_position_policy(cls) -> ParameterizedPolicy: + + def policy(state: State, memory: Dict, objects: Sequence[Object], + params: Array) -> Action: + # This policy moves the robot to the initial position + del memory, params + robot, wire = objects[:2] + robot_pos = (state.get(robot, "x"), + state.get(robot, "y"), + state.get(robot, "z")) + wrot = state.get(wire, "rot") + rrot = state.get(robot, "wrist") + dwrist = wrot - rrot + target_pos = (robot_pos[0], robot_pos[1], + cls.env_cls.robot_init_z) + return PyBulletCoffeeGroundTruthOptionFactory._get_move_action(state, + target_pos, + robot_pos, + finger_status="open", + dwrist=dwrist) + + return policy + @classmethod + def _create_pick_wire_policy(cls) -> ParameterizedPolicy: + def pick_wire_policy(state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> Action: + """Pick wire by 1) Rotate, 2) Pick up + """ + del memory, params + robot, wire = objects + wx = state.get(wire, "x") + wy = state.get(wire, "y") + wz = state.get(wire, "z") + wr = state.get(wire, "rot") + wpos = (wx, wy, wz) + rx = state.get(robot, "x") + ry = state.get(robot, "y") + rz = state.get(robot, "z") + rr = state.get(robot, "wrist") + rpos = (rx, ry, rz) + dwrist = wr - rr + dwrist = np.clip(dwrist, -cls.env_cls.max_angular_vel, + cls.env_cls.max_angular_vel) + + way_point = (wpos[0], wpos[1], rz) #cls.env_cls.robot_init_z) + sq_dist_to_way_point = np.sum((np.array(rpos) - + np.array(way_point))**2) + if sq_dist_to_way_point > cls.pick_policy_tol: + return PyBulletCoffeeGroundTruthOptionFactory._get_move_action( + state, + way_point, + rpos, + finger_status="open", + dwrist=dwrist) + + sq_dist = np.sum((np.array(wpos) - np.array(rpos))**2) + if sq_dist < cls.pick_policy_tol: + return PyBulletCoffeeGroundTruthOptionFactory._get_pick_action( + state) + else: + return PyBulletCoffeeGroundTruthOptionFactory._get_move_action( + state, + wpos, + rpos, + finger_status="open", + dwrist=dwrist) + return pick_wire_policy + + @classmethod + def _create_connect_policy(cls) -> ParameterizedPolicy: + def connect_policy(state: State, memory: Dict, + objects: Sequence[Object], + params: Array) -> Action: + """Connect wire to light and battery by 1) Rotate, 2) Connect + """ + del memory, params + robot, wire, light, battery = objects + wx = state.get(wire, "x") + wy = state.get(wire, "y") + wz = state.get(wire, "z") + wr = state.get(wire, "rot") + target_rot = 0 + rx = state.get(robot, "x") + ry = state.get(robot, "y") + rz = state.get(robot, "z") + rr = state.get(robot, "wrist") + cur_pos = (rx, ry, rz) + # cur_pos = (wx, wy, wz) + dwrist = target_rot - rr + + lx = state.get(light, "x") + ly = state.get(light, "y") + lz = state.get(light, "z") + bx = state.get(battery, "x") + + at_top = 1 if (wy > ly) else -1 + target_x = (lx + bx) / 2 + target_y = ly + at_top * (cls.env_cls.bulb_snap_length / 2 + + cls.env_cls.snap_width / 2 - 0.01) + target_pos = (target_x, target_y, rz) + # logging.debug(f"current pos: {cur_pos}") + # logging.debug(f"target pos: {target_pos}") + # logging.debug(f"current wrist: {rr}") + # logging.debug(f"target wrist: {wr}") + + sq_dist = np.sum((np.array(cur_pos) - np.array(target_pos))**2) + if sq_dist < cls.place_policy_tol: + # logging.debug("Place") + return PyBulletCoffeeGroundTruthOptionFactory.\ + _get_place_action(state) + + return PyBulletCoffeeGroundTruthOptionFactory._get_move_action( + state, + target_pos, + cur_pos, + finger_status="closed", + dwrist=dwrist) + return connect_policy \ No newline at end of file