diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index d93e0564e..0fd0107c8 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -40,8 +40,8 @@ class PyBulletBalanceEnv(PyBulletEnv): _table_mid_half_extents = [0.1, _table_mid_w / 2, _table_height / 2] # Plate - _plate_height: ClassVar[float] = 0.01 - _plate_z = _table_height - _plate_height + _plate_height: ClassVar[float] = 0.02 + _plate_z = _table_height - _plate_height * 3 _plate1_pose: ClassVar[Pose3D] = (_table_x, _table2_y - _table_mid_w / 2 - _table_side_w / 2 - _table_gap, _plate_z) _plate3_pose: ClassVar[Pose3D] = (_table_x, _table2_y + _table_mid_w / 2 + @@ -83,7 +83,7 @@ class PyBulletBalanceEnv(PyBulletEnv): _camera_target: ClassVar[Pose3D] = (1.65, 0.75, 0.52) - _block_mass: ClassVar[float] = 0.5 + _block_mass: ClassVar[float] = 1 _block_size = CFG.balance_block_size _num_blocks_train = CFG.balance_num_blocks_train _num_blocks_test = CFG.balance_num_blocks_test @@ -222,7 +222,7 @@ def initialize_pybullet( plate3_id = create_pybullet_block( (.9, .9, .9, 1), cls._plate_half_extents, - 0.0, + 1.0, 1.0, cls._plate3_pose, cls._table_orientation, @@ -232,7 +232,7 @@ def initialize_pybullet( plate1_id = create_pybullet_block( (.9, .9, .9, 1), cls._plate_half_extents, - 0.0, + 1.0, 1.0, cls._plate1_pose, cls._table_orientation, @@ -243,7 +243,7 @@ def initialize_pybullet( beam1_id = create_pybullet_block( (0.9, 0.9, 0.9, 1), cls._beam_half_extents, - 0.0, + 1.0, 1.0, cls._beam1_pose, cls._table_orientation, @@ -252,19 +252,21 @@ def initialize_pybullet( beam2_id = create_pybullet_block( (0.9, 0.9, 0.9, 1), cls._beam_half_extents, - 0.0, + 1.0, 1.0, cls._beam2_pose, cls._table_orientation, physics_client_id, ) bodies["beam_ids"] = [beam1_id, beam2_id] + cls.fix_plates_and_beams_in_place(physics_client_id, table2_id, plate1_id, + plate3_id, beam1_id, beam2_id) button_id = create_pybullet_block( cls._button_color_off, [cls._button_radius] * 3, - 0.0, + 1.0, 1.0, (cls.button_x, cls.button_y, cls.button_z), cls._table_orientation, @@ -293,34 +295,34 @@ def initialize_pybullet( return physics_client_id, pybullet_robot, bodies - # @staticmethod - # def fix_plates_and_beams_in_place(physics_client_id, table_id, plate1_id, - # plate3_id, beam1_id, beam2_id): - # # Doesn't work for some reason - # for child_id in [plate1_id, plate3_id, beam1_id, beam2_id]: - # parent_pos, parent_orn = p.getBasePositionAndOrientation(table_id, - # physicsClientId=physics_client_id) - # child_pos, child_orn = p.getBasePositionAndOrientation(child_id, - # physicsClientId=physics_client_id) - # rel_pos, rel_orn = p.multiplyTransforms( - # p.invertTransform(parent_pos, parent_orn)[0], - # p.invertTransform(parent_pos, parent_orn)[1], - # child_pos, - # child_orn - # ) - # p.createConstraint( - # parentBodyUniqueId=table_id, - # parentLinkIndex=-1, - # childBodyUniqueId=child_id, - # childLinkIndex=-1, - # jointType=p.JOINT_FIXED, - # jointAxis=(0, 0, 0), - # parentFramePosition=rel_pos, - # parentFrameOrientation=rel_orn, - # childFramePosition=(0, 0, 0), - # childFrameOrientation=(0, 0, 0), - # physicsClientId=physics_client_id - # ) + @staticmethod + def fix_plates_and_beams_in_place(physics_client_id, table_id, plate1_id, + plate3_id, beam1_id, beam2_id): + # Doesn't work for some reason + for child_id in [plate1_id, plate3_id, beam1_id, beam2_id]: + parent_pos, parent_orn = p.getBasePositionAndOrientation(table_id, + physicsClientId=physics_client_id) + child_pos, child_orn = p.getBasePositionAndOrientation(child_id, + physicsClientId=physics_client_id) + rel_pos, rel_orn = p.multiplyTransforms( + p.invertTransform(parent_pos, parent_orn)[0], + p.invertTransform(parent_pos, parent_orn)[1], + child_pos, + child_orn + ) + p.createConstraint( + parentBodyUniqueId=table_id, + parentLinkIndex=-1, + childBodyUniqueId=child_id, + childLinkIndex=-1, + jointType=p.JOINT_FIXED, + jointAxis=(0, 0, 0), + parentFramePosition=rel_pos, + parentFrameOrientation=rel_orn, + childFramePosition=(0, 0, 0), + childFrameOrientation=(0, 0, 0), + physicsClientId=physics_client_id + ) def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None: self._plate1.id = pybullet_bodies["table_ids"][0] @@ -400,6 +402,9 @@ def step(self, action: Action, render_obs: bool = False) -> State: state = super().step(action, render_obs=render_obs) self._update_balance_beam(state) + self.fix_plates_and_beams_in_place(self._physics_client_id, self._table_id, + self._plate1.id, self._plate3.id, + self._beam_ids[0], self._beam_ids[1]) # Turn machine on if self._PressingButton_holds(state, [self._robot, self._machine]): @@ -481,6 +486,9 @@ def _reset_state(self, state: State) -> None: self._prev_diff = 0 # Also do one beam update to make sure the initial positions match self._update_balance_beam(state) + self.fix_plates_and_beams_in_place(self._physics_client_id, self._table_id, + self._plate1.id, self._plate3.id, + self._beam_ids[0], self._beam_ids[1]) # Update the button color if self._MachineOn_holds(state, [self._machine, self._robot]): @@ -503,95 +511,83 @@ def _reset_state(self, state: State) -> None: def _update_balance_beam(self, state: State) -> None: """Shift the plates, beams, *and blocks on them* to simulate a - balance.""" - - # Count how many blocks are on each plate by comparing x to midpoint_x. - left_count = 0 - right_count = 0 - midpoint_y = self._table2_y - - block_objs = state.get_objects(self._block_type) + balance, ensuring rising sides move blocks first then plate, + and dropping sides move plate first then blocks.""" left_count = self.count_num_blocks(state, self._plate1) right_count = self.count_num_blocks(state, self._plate3) - diff = left_count - right_count if diff == self._prev_diff: - return # No change in distribution, no need to reset positions + return - # Compute plate/beam shifts shift_per_block = 0.01 - shift_amount = diff * shift_per_block - - # Update blocks - block_id_map = {obj: obj.id for obj in block_objs} # Object -> int - - for block_obj in block_objs: - # Skip out-of-view blocks - old_bz = state.get(block_obj, "pose_z") - if old_bz < 0 or self._held_obj_id == block_obj.id: - continue - by = state.get(block_obj, "pose_y") - block_id = block_id_map[block_obj] - - # If block is left of midpoint => shift with plate1 - # If block is right of midpoint => shift with plate3 - if by < midpoint_y: - new_bz = old_bz - shift_amount + shift_amount = abs(diff) * shift_per_block + block_objs = state.get_objects(self._block_type) + left_dropping = diff > 0 + right_dropping = diff < 0 + + def shift_blocks(is_left: bool, dropping: bool): + """Shift blocks for one side, dropping or rising.""" + sign = -1 if dropping else 1 + midpoint_y = self._table2_y + for block_obj in block_objs: + # Skip out-of-view or held + if state.get(block_obj, "pose_z") < 0 or \ + self._held_obj_id == block_obj.id: + continue + by = state.get(block_obj, "pose_y") + belongs_to_side = (by < midpoint_y) if is_left else (by > midpoint_y) + if belongs_to_side: + old_z = state.get(block_obj, "pose_z") + padding = 0 + new_z = old_z + (sign * shift_amount) + (sign * padding) + block_pos, block_orn = p.getBasePositionAndOrientation( + block_obj.id, physicsClientId=self._physics_client_id) + p.resetBasePositionAndOrientation( + block_obj.id, [block_pos[0], block_pos[1], new_z], + block_orn, physicsClientId=self._physics_client_id) + + def shift_plate(is_left: bool, dropping: bool): + """Shift plate & beam, dropping or rising.""" + sign = -1 if dropping else 1 + if is_left: + plate_id, beam_id = self._plate1.id, self._beam_ids[0] + base_plate_z, base_beam_z = self._plate1_pose[2], self._beam1_pose[2] else: - new_bz = old_bz + shift_amount - # logging.debug(f"Current holding block: {self._held_obj_id}, shifting block {block_obj.id}") + plate_id, beam_id = self._plate3.id, self._beam_ids[1] + base_plate_z, base_beam_z = self._plate3_pose[2], self._beam2_pose[2] - # Update in PyBullet - block_pos, block_orn = p.getBasePositionAndOrientation( - block_id, physicsClientId=self._physics_client_id) - new_block_pos = [block_pos[0], block_pos[1], new_bz] - p.resetBasePositionAndOrientation( - block_id, - new_block_pos, - block_orn, - physicsClientId=self._physics_client_id) - - # Update plates - new_plate1_z = self._plate1_pose[2] - shift_amount - new_beam1_z = self._beam1_pose[2] - shift_amount - new_plate3_z = self._plate3_pose[2] + shift_amount - new_beam2_z = self._beam2_pose[2] + shift_amount - - # Reset the base positions of the plates and beams - plate1_id = self._plate1.id - beam1_id = self._beam_ids[0] - plate3_id = self._plate3.id - beam2_id = self._beam_ids[1] + new_plate_z = base_plate_z + (sign * shift_amount) + new_beam_z = base_beam_z + (sign * shift_amount) - plate1_pos, plate1_orn = p.getBasePositionAndOrientation( - plate1_id, physicsClientId=self._physics_client_id) - p.resetBasePositionAndOrientation( - plate1_id, [plate1_pos[0], plate1_pos[1], new_plate1_z], - plate1_orn, - physicsClientId=self._physics_client_id) - - beam1_pos, beam1_orn = p.getBasePositionAndOrientation( - beam1_id, physicsClientId=self._physics_client_id) - p.resetBasePositionAndOrientation( - beam1_id, [beam1_pos[0], beam1_pos[1], new_beam1_z], - beam1_orn, - physicsClientId=self._physics_client_id) + plate_pos, plate_orn = p.getBasePositionAndOrientation( + plate_id, physicsClientId=self._physics_client_id) + p.resetBasePositionAndOrientation( + plate_id, [plate_pos[0], plate_pos[1], new_plate_z], + plate_orn, physicsClientId=self._physics_client_id) - plate3_pos, plate3_orn = p.getBasePositionAndOrientation( - plate3_id, physicsClientId=self._physics_client_id) - p.resetBasePositionAndOrientation( - plate3_id, [plate3_pos[0], plate3_pos[1], new_plate3_z], - plate3_orn, - physicsClientId=self._physics_client_id) + beam_pos, beam_orn = p.getBasePositionAndOrientation( + beam_id, physicsClientId=self._physics_client_id) + p.resetBasePositionAndOrientation( + beam_id, [beam_pos[0], beam_pos[1], new_beam_z], + beam_orn, physicsClientId=self._physics_client_id) + + # Left side update + if left_dropping: + # Drop left plate + shift_plate(True, True) + # Drop left blocks + # Rise right blocks + shift_blocks(False, False) + # Rise right plate + shift_plate(False, False) + else: + shift_blocks(True, False) + shift_plate(True, False) + shift_plate(False, True) + shift_blocks(False, True) - beam2_pos, beam2_orn = p.getBasePositionAndOrientation( - beam2_id, physicsClientId=self._physics_client_id) - p.resetBasePositionAndOrientation( - beam2_id, [beam2_pos[0], beam2_pos[1], new_beam2_z], - beam2_orn, - physicsClientId=self._physics_client_id) + # Right side update - # Record the new difference self._prev_diff = diff diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index d65e60eb4..863201198 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -443,7 +443,7 @@ def step(self, action: Action, render_obs: bool = False) -> Observation: if self._held_constraint_id is None and self._fingers_closing(action): # Detect if an object is held. If so, create a grasp constraint. self._held_obj_id = self._detect_held_object() - logging.debug(f"Detected held object: {self._held_obj_id}") + # logging.debug(f"Detected held object: {self._held_obj_id}") if self._held_obj_id is not None: self._create_grasp_constraint() @@ -496,7 +496,7 @@ def _detect_held_object(self) -> Optional[int]: # A perfect score here is 1.0 (normals are unit vectors). contact_normal = point[7] score = expected_normal.dot(contact_normal) - logging.debug(f"With obj {obj_id}, score: {score}") + # logging.debug(f"With obj {obj_id}, score: {score}") assert -1.0 <= score <= 1.0 # Take absolute as object/gripper could be rotated 180 diff --git a/predicators/ground_truth_models/balance/__init__.py b/predicators/ground_truth_models/balance/__init__.py index 70853b006..d884981b5 100644 --- a/predicators/ground_truth_models/balance/__init__.py +++ b/predicators/ground_truth_models/balance/__init__.py @@ -1,10 +1,9 @@ """Ground-truth models for Balance environment and variants.""" from .nsrts import BalanceGroundTruthNSRTFactory -from .options import BalanceGroundTruthOptionFactory, \ - PyBulletBalanceGroundTruthOptionFactory +from .options import PyBulletBalanceGroundTruthOptionFactory __all__ = [ - "BalanceGroundTruthNSRTFactory", "BalanceGroundTruthOptionFactory", + "BalanceGroundTruthNSRTFactory", "PyBulletBalanceGroundTruthOptionFactory" ] diff --git a/predicators/ground_truth_models/balance/options.py b/predicators/ground_truth_models/balance/options.py index d3e7e47b9..baf63a043 100644 --- a/predicators/ground_truth_models/balance/options.py +++ b/predicators/ground_truth_models/balance/options.py @@ -18,239 +18,13 @@ from predicators.structs import Action, Array, Object, ParameterizedOption, \ ParameterizedPolicy, Predicate, State, Type - -class BalanceGroundTruthOptionFactory(GroundTruthOptionFactory): - """Ground-truth options for the (non-pybullet) blocks environment.""" - - env_cls = PyBulletBalanceEnv - - @classmethod - def get_env_names(cls) -> Set[str]: - return {"balance"} - - @classmethod - def get_options(cls, env_name: str, types: Dict[str, Type], - predicates: Dict[str, Predicate], - action_space: Box) -> Set[ParameterizedOption]: - - robot_type = types["robot"] - block_type = types["block"] - machine_type = types["machine"] - plate_type = types["plate"] - block_size = CFG.blocks_block_size - - Holding = predicates['Holding'] - On = predicates['DirectlyOn'] - OnPlate = predicates['DirectlyOnPlate'] - GripperOpen = predicates['GripperOpen'] - MachineOn = predicates['MachineOn'] - # Balanced = predicates['Balanced'] - Balanced = predicates['Balanced'].untransformed_predicate - - # def _Pick_terminal(s: State, m: Dict, o: Sequence[Object], - # p: Array) -> bool: - # robot, block = o - # return Holding.holds(s, [block]) - - Pick = utils.SingletonParameterizedOption( - # variables: [robot, object to pick] - # params: [] - "Pick", - cls._create_pick_policy(action_space), - types=[robot_type, block_type], - # terminal=_Pick_terminal, - ) - - # Probably will need to change the option parameters for this to work - # def _Stack_terminal(s: State, m: Dict, o: Sequence[Object], - # p: Array) -> bool: - # block, otherblock, _ = o - # return On.holds(s, [block, otherblock]) - - Stack = utils.SingletonParameterizedOption( - # variables: [robot, object on which to stack currently-held-object] - # params: [] - "Stack", - cls._create_stack_policy(action_space, block_size), - types=[robot_type, block_type], - # types = [block_type, block_type, robot_type], - # terminal=_Stack_terminal, - ) - - # def _PutOnPlate_terminal(s: State, m: Dict, o: Sequence[Object], - # p: Array) -> bool: - # block, _ = o - # return OnPlate.holds(s, [block]) - - PutOnPlate = utils.SingletonParameterizedOption( - # variables: [robot] - # params: [x, y] (normalized coordinates on the table surface) - "PutOnPlate", - cls._create_putonplate_policy(action_space, block_size), - types=[robot_type, plate_type], - # types=[block_type, robot_type], - params_space=Box(0, 1, (2, )), - # terminal=_PutOnPlate_terminal, - ) - - # TurnMachineOn - def _TurnMachineOn_initiable(state: State, memory: Dict, - objects: Sequence[Object], - params: Array) -> bool: - del memory, params # unused - plate1, plate2 = objects - robot = state.get_objects(robot_type)[0] - return GripperOpen.holds(state, [robot]) and\ - Balanced.holds(state, [plate1, plate2]) - - def _TurnMachineOn_terminal(state: State, memory: Dict, - objects: Sequence[Object], - params: Array) -> bool: - del memory, params # unused - machine = state.get_objects(machine_type)[0] - robot = state.get_objects(robot_type)[0] - # _, machine, _, _ = objects - return MachineOn.holds(state, [machine, robot]) - - TurnMachineOn = ParameterizedOption( - "TurnMachineOn", - types=[plate_type, plate_type], - params_space=Box(0, 1, (0, )), - policy=cls._create_turn_machine_on_policy(), - initiable=_TurnMachineOn_initiable, - terminal=_TurnMachineOn_terminal, - annotation="Turn the machine on.") - - return {Pick, Stack, PutOnPlate, TurnMachineOn} - - @classmethod - def _create_turn_machine_on_policy(cls) -> ParameterizedPolicy: - - def policy(state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> Action: - # This policy moves the robot up to be level with the button in the - # z direction and then moves forward in the y direction to press it. - del memory, params # unused - robot = [r for r in state if r.type.name == "robot"][0] - # robot = objects[0] - x = state.get(robot, "pose_x") - y = state.get(robot, "pose_y") - z = state.get(robot, "pose_z") - # robot_pos = (x, y, z) - button_pos = (cls.env_cls.button_x, cls.env_cls.button_y, - cls.env_cls.button_z) - arr = np.r_[button_pos, 1.0].astype(np.float32) - # arr = np.clip(arr, cls.env_cls.action_space.low, - # cls.env_cls.action_space.high) - return Action(arr) - # if (cls.env_cls.button_z - z)**2 < cls.env_cls._button_radius**2: - # # Move directly toward the button. - # return cls._get_move_action(state, button_pos, robot_pos) - # # Move only in the z direction. - # return cls._get_move_action(state, (x, y, cls.env_cls.button_z), - # robot_pos) - - return policy - - @classmethod - def _create_pick_policy(cls, action_space: Box) -> ParameterizedPolicy: - - def policy(state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> Action: - del memory, params # unused - _, block = objects - block_pose = np.array([ - state.get(block, "pose_x"), - state.get(block, "pose_y"), - state.get(block, "pose_z") - ]) - arr = np.r_[block_pose, 0.0].astype(np.float32) - arr = np.clip(arr, action_space.low, action_space.high) - return Action(arr) - - return policy - - @classmethod - def _create_stack_policy(cls, action_space: Box, - block_size: float) -> ParameterizedPolicy: - - def policy(state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> Action: - del memory, params # unused - _, block = objects - # _, block, _ = objects - - block_pose = np.array([ - state.get(block, "pose_x"), - state.get(block, "pose_y"), - state.get(block, "pose_z") - ]) - relative_grasp = np.array([ - 0., - 0., - block_size, - ]) - arr = np.r_[block_pose + relative_grasp, 1.0].astype(np.float32) - arr = np.clip(arr, action_space.low, action_space.high) - return Action(arr) - - return policy - - @classmethod - def _create_putonplate_policy(cls, action_space: Box, - block_size: float) -> ParameterizedPolicy: - - def policy(state: State, memory: Dict, objects: Sequence[Object], - params: Array) -> Action: - del state, memory, objects # unused - # De-normalize parameters to actual table coordinates. - x_norm, y_norm = params - x = BalanceEnv.x_lb + (BalanceEnv.x_ub - BalanceEnv.x_lb) * x_norm - y = BalanceEnv.y_lb + (BalanceEnv.y_ub - BalanceEnv.y_lb) * y_norm - z = BalanceEnv._table_height + 0.5 * block_size - arr = np.array([x, y, z, 1.0], dtype=np.float32) - arr = np.clip(arr, action_space.low, action_space.high) - return Action(arr) - - return policy - - ############################ Utility functions ############################ - - @classmethod - def _get_move_action(cls, - state: State, - target_pos: Tuple[float, float, float], - robot_pos: Tuple[float, float, float], - dtilt: float = 0.0, - dwrist: float = 0.0, - finger_status: str = "open") -> Action: - del state, finger_status # used in PyBullet subclass - # We want to move in this direction. - delta = np.subtract(target_pos, robot_pos) - # But we can only move at most max_position_vel in one step. - # Get the norm full move delta. - pos_norm = float(np.linalg.norm(delta)) - # If the norm is more than max_position_vel, rescale the delta so - # that its norm is max_position_vel. - if pos_norm > cls.env_cls.max_position_vel: - delta = cls.env_cls.max_position_vel * (delta / pos_norm) - pos_norm = cls.env_cls.max_position_vel - # Now normalize so that the action values are between -1 and 1, as - # expected by simulate and the action space. - if pos_norm > 0: - delta = delta / cls.env_cls.max_position_vel - dx, dy, dz = delta - return Action(np.array([dx, dy, dz, 0.0], dtype=np.float32)) - - @lru_cache def _get_pybullet_robot() -> SingleArmPyBulletRobot: _, pybullet_robot, _ = \ PyBulletBalanceEnv.initialize_pybullet(using_gui=False) return pybullet_robot - -class PyBulletBalanceGroundTruthOptionFactory(BalanceGroundTruthOptionFactory): +class PyBulletBalanceGroundTruthOptionFactory(GroundTruthOptionFactory): """Ground-truth options for the pybullet_balance environment.""" env_cls = PyBulletBalanceEnv @@ -355,7 +129,7 @@ def close_fingers_func(state: State, objects: Sequence[Object], # Move down to place. cls._create_blocks_move_to_above_block_option( name="MoveEndEffectorToStack", - z_func=lambda block_z: (block_z + block_size * 2), + z_func=lambda block_z: (block_z + block_size * 1.3), finger_status="closed", pybullet_robot=pybullet_robot, option_types=option_types, @@ -482,7 +256,8 @@ def _get_current_and_target_pose_and_finger_status( pybullet_robot, name, option_types, params_space, _get_current_and_target_pose_and_finger_status, cls._move_to_pose_tol, CFG.pybullet_max_vel_norm, - cls._finger_action_nudge_magnitude) + cls._finger_action_nudge_magnitude, + validate=CFG.pybullet_ik_validate) @classmethod def _create_blocks_move_to_above_table_option( @@ -596,4 +371,5 @@ def _get_move_action(cls, return get_move_end_effector_to_pose_action( pybullet_robot, current_joint_positions, current_pose, target_pose, finger_status, CFG.pybullet_max_vel_norm, - cls._finger_action_nudge_magnitude) + cls._finger_action_nudge_magnitude, + validate=CFG.pybullet_ik_validate)