diff --git a/predicators/envs/blocks.py b/predicators/envs/blocks.py index bf22ae7ff..99ff3fde1 100644 --- a/predicators/envs/blocks.py +++ b/predicators/envs/blocks.py @@ -29,7 +29,7 @@ class BlocksEnv(BaseEnv): """Blocks domain.""" # Parameters that aren't important enough to need to clog up settings.py - table_height: ClassVar[float] = 0.2 + table_height: ClassVar[float] = 0.4 # The table x bounds are (1.1, 1.6), but the workspace is smaller. # Make it narrow enough that blocks can be only horizontally arranged. # Note that these boundaries are for the block positions, and that a diff --git a/predicators/envs/pybullet_balance.py b/predicators/envs/pybullet_balance.py index 1ec3f67fc..a4c76972f 100644 --- a/predicators/envs/pybullet_balance.py +++ b/predicators/envs/pybullet_balance.py @@ -1,3 +1,13 @@ +""" +Making a demo video: +python predicators/main.py --approach oracle --env pybullet_balance --seed 1 \ +--num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \ +--make_failure_videos --video_fps 20 \ +--pybullet_camera_height 900 --pybullet_camera_width 900 --make_test_videos \ +--sesame_task_planning_heuristic "goal_count" \ +--excluded_predicates "Balanced,OnPlate" --sesame_max_skeletons_optimized 100 \ +--sesame_check_expected_atoms False +""" import logging from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, \ @@ -23,11 +33,11 @@ class PyBulletBalanceEnv(PyBulletEnv): # Table parameters. _table_height: ClassVar[float] = 0.4 _table2_pose: ClassVar[Pose3D] = (1.35, 0.75, _table_height/2) + _table_x, _table2_y, _table_z = _table2_pose _table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.) _table_mid_w = 0.1 _table_side_w = 0.3 _table_gap = 0.05 - _table_x, _table2_y, _table_z = 1.35, 0.75, _table_height _table_mid_half_extents = [0.1, _table_mid_w / 2, _table_height / 2] # Plate @@ -53,7 +63,7 @@ class PyBulletBalanceEnv(PyBulletEnv): _button_radius = 0.04 _button_color_off = [1, 0, 0, 1] _button_color_on = [0, 1, 0, 1] - button_x, button_y, button_z = _table_x, _table2_y, _table_z + button_x, button_y, button_z = _table_x, _table2_y, _table_height button_press_threshold = 1e-3 # Workspace parameters @@ -76,7 +86,7 @@ class PyBulletBalanceEnv(PyBulletEnv): _camera_target: ClassVar[Pose3D] = (1.65, 0.75, 0.52) - _obj_mass: ClassVar[float] = 0.05 + _block_mass: ClassVar[float] = 0.5 _block_size = CFG.balance_block_size _num_blocks_train = CFG.balance_num_blocks_train _num_blocks_test = CFG.balance_num_blocks_test @@ -376,7 +386,7 @@ def initialize_pybullet( half_extents = (block_size / 2.0, block_size / 2.0, block_size / 2.0) block_ids.append( - create_pybullet_block(color, half_extents, cls._obj_mass, + create_pybullet_block(color, half_extents, cls._block_mass, cls._obj_friction, physics_client_id=physics_client_id)) bodies["block_ids"] = block_ids @@ -492,12 +502,11 @@ def _update_balance_beam(self, state: State) -> None: # Count how many blocks are on each plate by comparing x to midpoint_x. left_count = 0 right_count = 0 - midpoint_x = 1.1 + midpoint_y = self._table2_y block_objs = state.get_objects(self._block_type) left_count = self.count_num_blocks(state, self._plate1) right_count = self.count_num_blocks(state, self._plate3) - logging.debug(f"Left count: {left_count}, Right count: {right_count}") diff = left_count - right_count if diff == self._prev_diff: @@ -507,6 +516,36 @@ def _update_balance_beam(self, state: State) -> None: shift_per_block = 0.01 shift_amount = diff * shift_per_block + # Update blocks + block_id_map = {obj: obj.id for obj in block_objs} # Object -> int + + for block_obj in block_objs: + # Skip out-of-view blocks + old_bz = state.get(block_obj, "pose_z") + if old_bz < 0 or self._held_obj_id == block_obj.id: + continue + by = state.get(block_obj, "pose_y") + block_id = block_id_map[block_obj] + + # If block is left of midpoint => shift with plate1 + # If block is right of midpoint => shift with plate3 + if by < midpoint_y: + new_bz = old_bz - shift_amount + else: + new_bz = old_bz + shift_amount + # logging.debug(f"Current holding block: {self._held_obj_id}, shifting block {block_obj.id}") + + # Update in PyBullet + block_pos, block_orn = p.getBasePositionAndOrientation( + block_id, physicsClientId=self._physics_client_id) + new_block_pos = [block_pos[0], block_pos[1], new_bz] + p.resetBasePositionAndOrientation( + block_id, + new_block_pos, + block_orn, + physicsClientId=self._physics_client_id) + + # Update plates new_plate1_z = self._plate1_pose[2] - shift_amount new_beam1_z = self._beam1_pose[2] - shift_amount new_plate3_z = self._plate3_pose[2] + shift_amount @@ -550,35 +589,6 @@ def _update_balance_beam(self, state: State) -> None: beam2_orn, physicsClientId=self._physics_client_id) - # --- ADDED: Shift the blocks resting on each plate --- - block_id_map = {obj: obj.id for obj in block_objs} # Object -> int - - for block_obj in block_objs: - # Skip out-of-view blocks - old_bz = state.get(block_obj, "pose_z") - if old_bz < 0: - continue - bx = state.get(block_obj, "pose_x") - by = state.get(block_obj, "pose_y") - block_id = block_id_map[block_obj] - - # If block is left of midpoint => shift with plate1 - # If block is right of midpoint => shift with plate3 - if bx < midpoint_x: - new_bz = old_bz - shift_amount - else: - new_bz = old_bz + shift_amount - - # Update in PyBullet - block_pos, block_orn = p.getBasePositionAndOrientation( - block_id, physicsClientId=self._physics_client_id) - new_block_pos = [block_pos[0], block_pos[1], new_bz] - p.resetBasePositionAndOrientation( - block_id, - new_block_pos, - block_orn, - physicsClientId=self._physics_client_id) - # --- ADDED end --- # Record the new difference self._prev_diff = diff @@ -802,7 +812,7 @@ def _DirectlyOnPlate_holds(self, state: State, block, table = objects y = state.get(block, "pose_y") z = state.get(block, "pose_z") - table_z = state.get(table, "pose_z") + table_z = state.get(table, "pose_z") + self._plate_height/2 desired_z = table_z + self._block_size * 0.5 if (state.get(block, "held") < self.held_tol) and \ @@ -857,12 +867,6 @@ def _generate_test_tasks(self) -> List[EnvironmentTask]: possible_num_blocks=self._num_blocks_test, rng=self._test_rng) - - def _make_tasks(self, num_tasks: int, possible_num_blocks: List[int], - rng: np.random.Generator) -> List[EnvironmentTask]: - tasks = super()._make_tasks(num_tasks, possible_num_blocks, rng) - return self._add_pybullet_state_to_tasks(tasks) - def _load_task_from_json(self, json_file: Path) -> EnvironmentTask: task = super()._load_task_from_json(json_file) return self._add_pybullet_state_to_tasks([task])[0] @@ -935,7 +939,8 @@ def _sample_state_from_piles(self, piles: List[List[Object]], pile_i, pile_j = pile_idx x, y = pile_to_xy[pile_i] # Example: 0.2 + 0.045 * 0.5 - z = self._table_height + self._block_size * (0.5 + pile_j) + z = self._plate_z + self._plate_height + \ + self._block_size * (0.5 + pile_j) r, g, b = rng.uniform(size=3) if "clear" in self._block_type.feature_names: # [pose_x, pose_y, pose_z, held, color_r, color_g, color_b, @@ -1016,7 +1021,7 @@ def _table_xy_is_clear(self, x: float, y: float, import time # Make a task - CFG.seed = 0 + CFG.seed = 1 CFG.num_train_tasks = 0 CFG.num_test_tasks = 1 env = PyBulletBalanceEnv(use_gui=True) diff --git a/predicators/envs/pybullet_blocks.py b/predicators/envs/pybullet_blocks.py index 390176df2..8c01bed79 100644 --- a/predicators/envs/pybullet_blocks.py +++ b/predicators/envs/pybullet_blocks.py @@ -22,7 +22,8 @@ class PyBulletBlocksEnv(PyBulletEnv, BlocksEnv): # Parameters that aren't important enough to need to clog up settings.py # Table parameters. - _table_pose: ClassVar[Pose3D] = (1.35, 0.75, 0.0) + table_height: ClassVar[float] = 0.4 + _table_pose: ClassVar[Pose3D] = (1.35, 0.75, table_height/2) _table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.) def __init__(self, use_gui: bool = True) -> None: diff --git a/predicators/envs/pybullet_cover.py b/predicators/envs/pybullet_cover.py index 7ffdaae2d..bea5df1a1 100644 --- a/predicators/envs/pybullet_cover.py +++ b/predicators/envs/pybullet_cover.py @@ -20,7 +20,8 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv): # Parameters that aren't important enough to need to clog up settings.py # Table parameters. - _table_pose: ClassVar[Pose3D] = (1.35, 0.75, 0.0) + _table_height: ClassVar[float] = 0.4 + _table_pose: ClassVar[Pose3D] = (1.35, 0.75, _table_height/2) _table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.) # Object parameters. @@ -28,7 +29,6 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv): _max_obj_width: ClassVar[float] = 0.07 # highest width normalized to this # Dimension and workspace parameters. - _table_height: ClassVar[float] = 0.2 y_lb: ClassVar[float] = 0.4 y_ub: ClassVar[float] = 1.1 robot_init_x: ClassVar[float] = CoverEnv.workspace_x diff --git a/predicators/envs/pybullet_env.py b/predicators/envs/pybullet_env.py index df7b0893e..4bf0c227b 100644 --- a/predicators/envs/pybullet_env.py +++ b/predicators/envs/pybullet_env.py @@ -3,6 +3,7 @@ Contains useful common code. """ +import logging import abc from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, cast @@ -441,6 +442,7 @@ def step(self, action: Action, render_obs: bool = False) -> Observation: if self._held_constraint_id is None and self._fingers_closing(action): # Detect if an object is held. If so, create a grasp constraint. self._held_obj_id = self._detect_held_object() + logging.debug(f"Detected held object: {self._held_obj_id}") if self._held_obj_id is not None: self._create_grasp_constraint() @@ -493,6 +495,7 @@ def _detect_held_object(self) -> Optional[int]: # A perfect score here is 1.0 (normals are unit vectors). contact_normal = point[7] score = expected_normal.dot(contact_normal) + logging.debug(f"With obj {obj_id}, score: {score}") assert -1.0 <= score <= 1.0 # Take absolute as object/gripper could be rotated 180 @@ -570,7 +573,7 @@ def _add_pybullet_state_to_tasks( # observation, would it work without the reset_state? # Attempt 2: First reset it. self._current_observation = init - self._reset_state(init) + # self._reset_state(init) # Cast _current_observation from type State to PybulletState joint_positions = self._pybullet_robot.get_joints() self._current_observation = utils.PyBulletState( diff --git a/predicators/ground_truth_models/balance/options.py b/predicators/ground_truth_models/balance/options.py index 5ff5aaf10..ac85a2853 100644 --- a/predicators/ground_truth_models/balance/options.py +++ b/predicators/ground_truth_models/balance/options.py @@ -355,7 +355,7 @@ def close_fingers_func(state: State, objects: Sequence[Object], cls._create_blocks_move_to_above_block_option( name="MoveEndEffectorToStack", z_func=lambda block_z: - (block_z + block_size + cls._offset_z), + (block_z + block_size * 2), finger_status="closed", pybullet_robot=pybullet_robot, option_types=option_types, @@ -386,7 +386,7 @@ def close_fingers_func(state: State, objects: Sequence[Object], # Move to above the table at the (x, y) where we will place. cls._create_blocks_move_to_above_table_option( name="MoveEndEffectorToPrePutOnPlate", - z=lambda _: cls.env_cls.z_ub, + z=cls.env_cls.z_ub, finger_status="closed", pybullet_robot=pybullet_robot, option_types=option_types, @@ -394,7 +394,7 @@ def close_fingers_func(state: State, objects: Sequence[Object], # Move down to place. cls._create_blocks_move_to_above_table_option( name="MoveEndEffectorToPutOnPlate", - z=lambda _: cls.env_cls.z_ub - 0.2, + z=cls.env_cls.z_ub - 0.2, finger_status="closed", pybullet_robot=pybullet_robot, option_types=option_types, @@ -407,7 +407,7 @@ def close_fingers_func(state: State, objects: Sequence[Object], # Move back up. cls._create_blocks_move_to_above_table_option( name="MoveEndEffectorBackUp", - z=lambda _: cls.env_cls.z_ub, + z=cls.env_cls.z_ub, finger_status="open", pybullet_robot=pybullet_robot, option_types=option_types, diff --git a/predicators/settings.py b/predicators/settings.py index 240e30753..a881c6731 100644 --- a/predicators/settings.py +++ b/predicators/settings.py @@ -173,6 +173,10 @@ class GlobalSettings: "pybullet_blocks": { "fetch": (0.7071, 0.0, -0.7071, 0.0), "panda": (0.7071, 0.7071, 0.0, 0.0), + }, + "pybullet_balance": { + "fetch": (0.7071, 0.0, -0.7071, 0.0), + "panda": (0.7071, 0.7071, 0.0, 0.0), } }) pybullet_ik_validate = True