Skip to content

Commit

Permalink
can make balance demo; adjust table heights for other pybullet envs
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Jan 12, 2025
1 parent 17fe0ea commit 3ac77dc
Show file tree
Hide file tree
Showing 7 changed files with 66 additions and 53 deletions.
2 changes: 1 addition & 1 deletion predicators/envs/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
class BlocksEnv(BaseEnv):
"""Blocks domain."""
# Parameters that aren't important enough to need to clog up settings.py
table_height: ClassVar[float] = 0.2
table_height: ClassVar[float] = 0.4
# The table x bounds are (1.1, 1.6), but the workspace is smaller.
# Make it narrow enough that blocks can be only horizontally arranged.
# Note that these boundaries are for the block positions, and that a
Expand Down
93 changes: 49 additions & 44 deletions predicators/envs/pybullet_balance.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,13 @@
"""
Making a demo video:
python predicators/main.py --approach oracle --env pybullet_balance --seed 1 \
--num_test_tasks 1 --use_gui --debug --num_train_tasks 0 \
--make_failure_videos --video_fps 20 \
--pybullet_camera_height 900 --pybullet_camera_width 900 --make_test_videos \
--sesame_task_planning_heuristic "goal_count" \
--excluded_predicates "Balanced,OnPlate" --sesame_max_skeletons_optimized 100 \
--sesame_check_expected_atoms False
"""
import logging
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, \
Expand All @@ -23,11 +33,11 @@ class PyBulletBalanceEnv(PyBulletEnv):
# Table parameters.
_table_height: ClassVar[float] = 0.4
_table2_pose: ClassVar[Pose3D] = (1.35, 0.75, _table_height/2)
_table_x, _table2_y, _table_z = _table2_pose
_table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.)
_table_mid_w = 0.1
_table_side_w = 0.3
_table_gap = 0.05
_table_x, _table2_y, _table_z = 1.35, 0.75, _table_height
_table_mid_half_extents = [0.1, _table_mid_w / 2, _table_height / 2]

# Plate
Expand All @@ -53,7 +63,7 @@ class PyBulletBalanceEnv(PyBulletEnv):
_button_radius = 0.04
_button_color_off = [1, 0, 0, 1]
_button_color_on = [0, 1, 0, 1]
button_x, button_y, button_z = _table_x, _table2_y, _table_z
button_x, button_y, button_z = _table_x, _table2_y, _table_height
button_press_threshold = 1e-3

# Workspace parameters
Expand All @@ -76,7 +86,7 @@ class PyBulletBalanceEnv(PyBulletEnv):

_camera_target: ClassVar[Pose3D] = (1.65, 0.75, 0.52)

_obj_mass: ClassVar[float] = 0.05
_block_mass: ClassVar[float] = 0.5
_block_size = CFG.balance_block_size
_num_blocks_train = CFG.balance_num_blocks_train
_num_blocks_test = CFG.balance_num_blocks_test
Expand Down Expand Up @@ -376,7 +386,7 @@ def initialize_pybullet(
half_extents = (block_size / 2.0, block_size / 2.0,
block_size / 2.0)
block_ids.append(
create_pybullet_block(color, half_extents, cls._obj_mass,
create_pybullet_block(color, half_extents, cls._block_mass,
cls._obj_friction,
physics_client_id=physics_client_id))
bodies["block_ids"] = block_ids
Expand Down Expand Up @@ -492,12 +502,11 @@ def _update_balance_beam(self, state: State) -> None:
# Count how many blocks are on each plate by comparing x to midpoint_x.
left_count = 0
right_count = 0
midpoint_x = 1.1
midpoint_y = self._table2_y

block_objs = state.get_objects(self._block_type)
left_count = self.count_num_blocks(state, self._plate1)
right_count = self.count_num_blocks(state, self._plate3)
logging.debug(f"Left count: {left_count}, Right count: {right_count}")

diff = left_count - right_count
if diff == self._prev_diff:
Expand All @@ -507,6 +516,36 @@ def _update_balance_beam(self, state: State) -> None:
shift_per_block = 0.01
shift_amount = diff * shift_per_block

# Update blocks
block_id_map = {obj: obj.id for obj in block_objs} # Object -> int

for block_obj in block_objs:
# Skip out-of-view blocks
old_bz = state.get(block_obj, "pose_z")
if old_bz < 0 or self._held_obj_id == block_obj.id:
continue
by = state.get(block_obj, "pose_y")
block_id = block_id_map[block_obj]

# If block is left of midpoint => shift with plate1
# If block is right of midpoint => shift with plate3
if by < midpoint_y:
new_bz = old_bz - shift_amount
else:
new_bz = old_bz + shift_amount
# logging.debug(f"Current holding block: {self._held_obj_id}, shifting block {block_obj.id}")

# Update in PyBullet
block_pos, block_orn = p.getBasePositionAndOrientation(
block_id, physicsClientId=self._physics_client_id)
new_block_pos = [block_pos[0], block_pos[1], new_bz]
p.resetBasePositionAndOrientation(
block_id,
new_block_pos,
block_orn,
physicsClientId=self._physics_client_id)

# Update plates
new_plate1_z = self._plate1_pose[2] - shift_amount
new_beam1_z = self._beam1_pose[2] - shift_amount
new_plate3_z = self._plate3_pose[2] + shift_amount
Expand Down Expand Up @@ -550,35 +589,6 @@ def _update_balance_beam(self, state: State) -> None:
beam2_orn,
physicsClientId=self._physics_client_id)

# --- ADDED: Shift the blocks resting on each plate ---
block_id_map = {obj: obj.id for obj in block_objs} # Object -> int

for block_obj in block_objs:
# Skip out-of-view blocks
old_bz = state.get(block_obj, "pose_z")
if old_bz < 0:
continue
bx = state.get(block_obj, "pose_x")
by = state.get(block_obj, "pose_y")
block_id = block_id_map[block_obj]

# If block is left of midpoint => shift with plate1
# If block is right of midpoint => shift with plate3
if bx < midpoint_x:
new_bz = old_bz - shift_amount
else:
new_bz = old_bz + shift_amount

# Update in PyBullet
block_pos, block_orn = p.getBasePositionAndOrientation(
block_id, physicsClientId=self._physics_client_id)
new_block_pos = [block_pos[0], block_pos[1], new_bz]
p.resetBasePositionAndOrientation(
block_id,
new_block_pos,
block_orn,
physicsClientId=self._physics_client_id)
# --- ADDED end ---

# Record the new difference
self._prev_diff = diff
Expand Down Expand Up @@ -802,7 +812,7 @@ def _DirectlyOnPlate_holds(self, state: State,
block, table = objects
y = state.get(block, "pose_y")
z = state.get(block, "pose_z")
table_z = state.get(table, "pose_z")
table_z = state.get(table, "pose_z") + self._plate_height/2
desired_z = table_z + self._block_size * 0.5

if (state.get(block, "held") < self.held_tol) and \
Expand Down Expand Up @@ -857,12 +867,6 @@ def _generate_test_tasks(self) -> List[EnvironmentTask]:
possible_num_blocks=self._num_blocks_test,
rng=self._test_rng)


def _make_tasks(self, num_tasks: int, possible_num_blocks: List[int],
rng: np.random.Generator) -> List[EnvironmentTask]:
tasks = super()._make_tasks(num_tasks, possible_num_blocks, rng)
return self._add_pybullet_state_to_tasks(tasks)

def _load_task_from_json(self, json_file: Path) -> EnvironmentTask:
task = super()._load_task_from_json(json_file)
return self._add_pybullet_state_to_tasks([task])[0]
Expand Down Expand Up @@ -935,7 +939,8 @@ def _sample_state_from_piles(self, piles: List[List[Object]],
pile_i, pile_j = pile_idx
x, y = pile_to_xy[pile_i]
# Example: 0.2 + 0.045 * 0.5
z = self._table_height + self._block_size * (0.5 + pile_j)
z = self._plate_z + self._plate_height + \
self._block_size * (0.5 + pile_j)
r, g, b = rng.uniform(size=3)
if "clear" in self._block_type.feature_names:
# [pose_x, pose_y, pose_z, held, color_r, color_g, color_b,
Expand Down Expand Up @@ -1016,7 +1021,7 @@ def _table_xy_is_clear(self, x: float, y: float,
import time

# Make a task
CFG.seed = 0
CFG.seed = 1
CFG.num_train_tasks = 0
CFG.num_test_tasks = 1
env = PyBulletBalanceEnv(use_gui=True)
Expand Down
3 changes: 2 additions & 1 deletion predicators/envs/pybullet_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class PyBulletBlocksEnv(PyBulletEnv, BlocksEnv):
# Parameters that aren't important enough to need to clog up settings.py

# Table parameters.
_table_pose: ClassVar[Pose3D] = (1.35, 0.75, 0.0)
table_height: ClassVar[float] = 0.4
_table_pose: ClassVar[Pose3D] = (1.35, 0.75, table_height/2)
_table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.)

def __init__(self, use_gui: bool = True) -> None:
Expand Down
4 changes: 2 additions & 2 deletions predicators/envs/pybullet_cover.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@ class PyBulletCoverEnv(PyBulletEnv, CoverEnv):
# Parameters that aren't important enough to need to clog up settings.py

# Table parameters.
_table_pose: ClassVar[Pose3D] = (1.35, 0.75, 0.0)
_table_height: ClassVar[float] = 0.4
_table_pose: ClassVar[Pose3D] = (1.35, 0.75, _table_height/2)
_table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.)

# Object parameters.
_obj_len_hgt: ClassVar[float] = 0.045
_max_obj_width: ClassVar[float] = 0.07 # highest width normalized to this

# Dimension and workspace parameters.
_table_height: ClassVar[float] = 0.2
y_lb: ClassVar[float] = 0.4
y_ub: ClassVar[float] = 1.1
robot_init_x: ClassVar[float] = CoverEnv.workspace_x
Expand Down
5 changes: 4 additions & 1 deletion predicators/envs/pybullet_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Contains useful common code.
"""

import logging
import abc
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Tuple, cast

Expand Down Expand Up @@ -441,6 +442,7 @@ def step(self, action: Action, render_obs: bool = False) -> Observation:
if self._held_constraint_id is None and self._fingers_closing(action):
# Detect if an object is held. If so, create a grasp constraint.
self._held_obj_id = self._detect_held_object()
logging.debug(f"Detected held object: {self._held_obj_id}")
if self._held_obj_id is not None:
self._create_grasp_constraint()

Expand Down Expand Up @@ -493,6 +495,7 @@ def _detect_held_object(self) -> Optional[int]:
# A perfect score here is 1.0 (normals are unit vectors).
contact_normal = point[7]
score = expected_normal.dot(contact_normal)
logging.debug(f"With obj {obj_id}, score: {score}")
assert -1.0 <= score <= 1.0

# Take absolute as object/gripper could be rotated 180
Expand Down Expand Up @@ -570,7 +573,7 @@ def _add_pybullet_state_to_tasks(
# observation, would it work without the reset_state?
# Attempt 2: First reset it.
self._current_observation = init
self._reset_state(init)
# self._reset_state(init)
# Cast _current_observation from type State to PybulletState
joint_positions = self._pybullet_robot.get_joints()
self._current_observation = utils.PyBulletState(
Expand Down
8 changes: 4 additions & 4 deletions predicators/ground_truth_models/balance/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
cls._create_blocks_move_to_above_block_option(
name="MoveEndEffectorToStack",
z_func=lambda block_z:
(block_z + block_size + cls._offset_z),
(block_z + block_size * 2),
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand Down Expand Up @@ -386,15 +386,15 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move to above the table at the (x, y) where we will place.
cls._create_blocks_move_to_above_table_option(
name="MoveEndEffectorToPrePutOnPlate",
z=lambda _: cls.env_cls.z_ub,
z=cls.env_cls.z_ub,
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
params_space=params_space),
# Move down to place.
cls._create_blocks_move_to_above_table_option(
name="MoveEndEffectorToPutOnPlate",
z=lambda _: cls.env_cls.z_ub - 0.2,
z=cls.env_cls.z_ub - 0.2,
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand All @@ -407,7 +407,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move back up.
cls._create_blocks_move_to_above_table_option(
name="MoveEndEffectorBackUp",
z=lambda _: cls.env_cls.z_ub,
z=cls.env_cls.z_ub,
finger_status="open",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand Down
4 changes: 4 additions & 0 deletions predicators/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ class GlobalSettings:
"pybullet_blocks": {
"fetch": (0.7071, 0.0, -0.7071, 0.0),
"panda": (0.7071, 0.7071, 0.0, 0.0),
},
"pybullet_balance": {
"fetch": (0.7071, 0.0, -0.7071, 0.0),
"panda": (0.7071, 0.7071, 0.0, 0.0),
}
})
pybullet_ik_validate = True
Expand Down

0 comments on commit 3ac77dc

Please sign in to comment.