Skip to content

Commit

Permalink
refactor _get_expected_finger_normals to base pybullet env
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Dec 27, 2024
1 parent f5bda14 commit 87e3749
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 71 deletions.
15 changes: 0 additions & 15 deletions predicators/envs/pybullet_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,21 +527,6 @@ def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
def _get_object_ids_for_held_check(self) -> List[int]:
return sorted(self._block_id_to_block)

def _get_expected_finger_normals(self) -> Dict[int, Array]:
if CFG.pybullet_robot == "panda":
# gripper rotated 90deg so parallel to x-axis
normal = np.array([1., 0., 0.], dtype=np.float32)
elif CFG.pybullet_robot == "fetch":
# gripper parallel to y-axis
normal = np.array([0., 1., 0.], dtype=np.float32)
else: # pragma: no cover
# Shouldn't happen unless we introduce a new robot.
raise ValueError(f"Unknown robot {CFG.pybullet_robot}")

return {
self._pybullet_robot.left_finger_id: normal,
self._pybullet_robot.right_finger_id: -1 * normal,
}

def _force_grasp_object(self, block: Object) -> None:
block_to_block_id = {b: i for i, b in self._block_id_to_block.items()}
Expand Down
15 changes: 0 additions & 15 deletions predicators/envs/pybullet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,6 @@ def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
def _get_object_ids_for_held_check(self) -> List[int]:
return sorted(self._block_id_to_block)

def _get_expected_finger_normals(self) -> Dict[int, Array]:
if CFG.pybullet_robot == "panda":
# gripper rotated 90deg so parallel to x-axis
normal = np.array([1., 0., 0.], dtype=np.float32)
elif CFG.pybullet_robot == "fetch":
# gripper parallel to y-axis
normal = np.array([0., 1., 0.], dtype=np.float32)
else: # pragma: no cover
# Shouldn't happen unless we introduce a new robot.
raise ValueError(f"Unknown robot {CFG.pybullet_robot}")

return {
self._pybullet_robot.left_finger_id: normal,
self._pybullet_robot.right_finger_id: -1 * normal,
}

def _force_grasp_object(self, block: Object) -> None:
block_to_block_id = {b: i for i, b in self._block_id_to_block.items()}
Expand Down
13 changes: 0 additions & 13 deletions predicators/envs/pybullet_coffee.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,19 +699,6 @@ def _get_object_ids_for_held_check(self) -> List[int]:
return [self._jug_id, self._plug_id]
return [self._jug_id]

def _get_expected_finger_normals(self) -> Dict[int, Array]:
if CFG.pybullet_robot == "fetch":
# gripper parallel to y-axis
normal = np.array([0., 1., 0.], dtype=np.float32)
else: # pragma: no cover
# Shouldn't happen unless we introduce a new robot.
raise ValueError(f"Unknown robot {CFG.pybullet_robot}")

return {
self._pybullet_robot.left_finger_id: normal,
self._pybullet_robot.right_finger_id: -1 * normal,
}

def _state_to_gripper_orn(self, state: State) -> Quaternion:
wrist = state.get(self._robot, "wrist")
tilt = state.get(self._robot, "tilt")
Expand Down
7 changes: 0 additions & 7 deletions predicators/envs/pybullet_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,6 @@ def _get_state(self) -> State:
def _get_object_ids_for_held_check(self) -> List[int]:
return sorted(self._block_id_to_block)

def _get_expected_finger_normals(self) -> Dict[int, Array]:
# Both fetch and panda have grippers parallel to x-axis
return {
self._pybullet_robot.left_finger_id: np.array([1., 0., 0.]),
self._pybullet_robot.right_finger_id: np.array([-1., 0., 0.]),
}

@classmethod
def get_name(cls) -> str:
return "pybullet_cover"
Expand Down
22 changes: 14 additions & 8 deletions predicators/envs/pybullet_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,21 @@ def _get_object_ids_for_held_check(self) -> List[int]:
held."""
raise NotImplementedError("Override me!")

@abc.abstractmethod
def _get_expected_finger_normals(self) -> Dict[int, Array]:
"""Get the expected finger normals, used in detect_held_object(), as a
mapping from finger link index to a unit-length normal vector.
This is environment-specific because it depends on the end
effector's orientation when grasping.
"""
raise NotImplementedError("Override me!")
if CFG.pybullet_robot == "panda":
# gripper rotated 90deg so parallel to x-axis
normal = np.array([1., 0., 0.], dtype=np.float32)
elif CFG.pybullet_robot == "fetch":
# gripper parallel to y-axis
normal = np.array([0., 1., 0.], dtype=np.float32)
else: # pragma: no cover
# Shouldn't happen unless we introduce a new robot.
raise ValueError(f"Unknown robot {CFG.pybullet_robot}")

return {
self._pybullet_robot.left_finger_id: normal,
self._pybullet_robot.right_finger_id: -1 * normal,
}

@classmethod
def _fingers_state_to_joint(cls, pybullet_robot: SingleArmPyBulletRobot,
Expand Down
13 changes: 0 additions & 13 deletions predicators/envs/pybullet_grow.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,19 +269,6 @@ def _get_object_ids_for_held_check(self) -> List[int]:
assert self._red_jug_id is not None and self._blue_jug_id is not None
return [self._red_jug_id, self._blue_jug_id]

def _get_expected_finger_normals(self) -> Dict[int, Array]:
"""For the default fetch robot in predicators.
We assume a certain orientation where the left_finger_id is in
+y direction, the right_finger_id is in -y.
"""
normal_left = np.array([0., 1., 0.], dtype=np.float32)
normal_right = np.array([0., -1., 0.], dtype=np.float32)
return {
self._pybullet_robot.left_finger_id: normal_left,
self._pybullet_robot.right_finger_id: normal_right,
}

# -------------------------------------------------------------------------
# Setting or updating the environment’s state.

Expand Down

0 comments on commit 87e3749

Please sign in to comment.