Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Jan 13, 2025
1 parent 1a1b02f commit 74c2116
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 348 deletions.
222 changes: 109 additions & 113 deletions predicators/envs/pybullet_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ class PyBulletBalanceEnv(PyBulletEnv):
_table_mid_half_extents = [0.1, _table_mid_w / 2, _table_height / 2]

# Plate
_plate_height: ClassVar[float] = 0.01
_plate_z = _table_height - _plate_height
_plate_height: ClassVar[float] = 0.02
_plate_z = _table_height - _plate_height * 3
_plate1_pose: ClassVar[Pose3D] = (_table_x, _table2_y - _table_mid_w / 2 -
_table_side_w / 2 - _table_gap, _plate_z)
_plate3_pose: ClassVar[Pose3D] = (_table_x, _table2_y + _table_mid_w / 2 +
Expand Down Expand Up @@ -83,7 +83,7 @@ class PyBulletBalanceEnv(PyBulletEnv):

_camera_target: ClassVar[Pose3D] = (1.65, 0.75, 0.52)

_block_mass: ClassVar[float] = 0.5
_block_mass: ClassVar[float] = 1
_block_size = CFG.balance_block_size
_num_blocks_train = CFG.balance_num_blocks_train
_num_blocks_test = CFG.balance_num_blocks_test
Expand Down Expand Up @@ -222,7 +222,7 @@ def initialize_pybullet(
plate3_id = create_pybullet_block(
(.9, .9, .9, 1),
cls._plate_half_extents,
0.0,
1.0,
1.0,
cls._plate3_pose,
cls._table_orientation,
Expand All @@ -232,7 +232,7 @@ def initialize_pybullet(
plate1_id = create_pybullet_block(
(.9, .9, .9, 1),
cls._plate_half_extents,
0.0,
1.0,
1.0,
cls._plate1_pose,
cls._table_orientation,
Expand All @@ -243,7 +243,7 @@ def initialize_pybullet(
beam1_id = create_pybullet_block(
(0.9, 0.9, 0.9, 1),
cls._beam_half_extents,
0.0,
1.0,
1.0,
cls._beam1_pose,
cls._table_orientation,
Expand All @@ -252,19 +252,21 @@ def initialize_pybullet(
beam2_id = create_pybullet_block(
(0.9, 0.9, 0.9, 1),
cls._beam_half_extents,
0.0,
1.0,
1.0,
cls._beam2_pose,
cls._table_orientation,
physics_client_id,
)
bodies["beam_ids"] = [beam1_id, beam2_id]
cls.fix_plates_and_beams_in_place(physics_client_id, table2_id, plate1_id,
plate3_id, beam1_id, beam2_id)


button_id = create_pybullet_block(
cls._button_color_off,
[cls._button_radius] * 3,
0.0,
1.0,
1.0,
(cls.button_x, cls.button_y, cls.button_z),
cls._table_orientation,
Expand Down Expand Up @@ -293,34 +295,34 @@ def initialize_pybullet(

return physics_client_id, pybullet_robot, bodies

# @staticmethod
# def fix_plates_and_beams_in_place(physics_client_id, table_id, plate1_id,
# plate3_id, beam1_id, beam2_id):
# # Doesn't work for some reason
# for child_id in [plate1_id, plate3_id, beam1_id, beam2_id]:
# parent_pos, parent_orn = p.getBasePositionAndOrientation(table_id,
# physicsClientId=physics_client_id)
# child_pos, child_orn = p.getBasePositionAndOrientation(child_id,
# physicsClientId=physics_client_id)
# rel_pos, rel_orn = p.multiplyTransforms(
# p.invertTransform(parent_pos, parent_orn)[0],
# p.invertTransform(parent_pos, parent_orn)[1],
# child_pos,
# child_orn
# )
# p.createConstraint(
# parentBodyUniqueId=table_id,
# parentLinkIndex=-1,
# childBodyUniqueId=child_id,
# childLinkIndex=-1,
# jointType=p.JOINT_FIXED,
# jointAxis=(0, 0, 0),
# parentFramePosition=rel_pos,
# parentFrameOrientation=rel_orn,
# childFramePosition=(0, 0, 0),
# childFrameOrientation=(0, 0, 0),
# physicsClientId=physics_client_id
# )
@staticmethod
def fix_plates_and_beams_in_place(physics_client_id, table_id, plate1_id,
plate3_id, beam1_id, beam2_id):
# Doesn't work for some reason
for child_id in [plate1_id, plate3_id, beam1_id, beam2_id]:
parent_pos, parent_orn = p.getBasePositionAndOrientation(table_id,
physicsClientId=physics_client_id)
child_pos, child_orn = p.getBasePositionAndOrientation(child_id,
physicsClientId=physics_client_id)
rel_pos, rel_orn = p.multiplyTransforms(
p.invertTransform(parent_pos, parent_orn)[0],
p.invertTransform(parent_pos, parent_orn)[1],
child_pos,
child_orn
)
p.createConstraint(
parentBodyUniqueId=table_id,
parentLinkIndex=-1,
childBodyUniqueId=child_id,
childLinkIndex=-1,
jointType=p.JOINT_FIXED,
jointAxis=(0, 0, 0),
parentFramePosition=rel_pos,
parentFrameOrientation=rel_orn,
childFramePosition=(0, 0, 0),
childFrameOrientation=(0, 0, 0),
physicsClientId=physics_client_id
)

def _store_pybullet_bodies(self, pybullet_bodies: Dict[str, Any]) -> None:
self._plate1.id = pybullet_bodies["table_ids"][0]
Expand Down Expand Up @@ -400,6 +402,9 @@ def step(self, action: Action, render_obs: bool = False) -> State:
state = super().step(action, render_obs=render_obs)

self._update_balance_beam(state)
self.fix_plates_and_beams_in_place(self._physics_client_id, self._table_id,
self._plate1.id, self._plate3.id,
self._beam_ids[0], self._beam_ids[1])

# Turn machine on
if self._PressingButton_holds(state, [self._robot, self._machine]):
Expand Down Expand Up @@ -481,6 +486,9 @@ def _reset_state(self, state: State) -> None:
self._prev_diff = 0
# Also do one beam update to make sure the initial positions match
self._update_balance_beam(state)
self.fix_plates_and_beams_in_place(self._physics_client_id, self._table_id,
self._plate1.id, self._plate3.id,
self._beam_ids[0], self._beam_ids[1])

# Update the button color
if self._MachineOn_holds(state, [self._machine, self._robot]):
Expand All @@ -503,95 +511,83 @@ def _reset_state(self, state: State) -> None:

def _update_balance_beam(self, state: State) -> None:
"""Shift the plates, beams, *and blocks on them* to simulate a
balance."""

# Count how many blocks are on each plate by comparing x to midpoint_x.
left_count = 0
right_count = 0
midpoint_y = self._table2_y

block_objs = state.get_objects(self._block_type)
balance, ensuring rising sides move blocks first then plate,
and dropping sides move plate first then blocks."""
left_count = self.count_num_blocks(state, self._plate1)
right_count = self.count_num_blocks(state, self._plate3)

diff = left_count - right_count
if diff == self._prev_diff:
return # No change in distribution, no need to reset positions
return

# Compute plate/beam shifts
shift_per_block = 0.01
shift_amount = diff * shift_per_block

# Update blocks
block_id_map = {obj: obj.id for obj in block_objs} # Object -> int

for block_obj in block_objs:
# Skip out-of-view blocks
old_bz = state.get(block_obj, "pose_z")
if old_bz < 0 or self._held_obj_id == block_obj.id:
continue
by = state.get(block_obj, "pose_y")
block_id = block_id_map[block_obj]

# If block is left of midpoint => shift with plate1
# If block is right of midpoint => shift with plate3
if by < midpoint_y:
new_bz = old_bz - shift_amount
shift_amount = abs(diff) * shift_per_block
block_objs = state.get_objects(self._block_type)
left_dropping = diff > 0
right_dropping = diff < 0

def shift_blocks(is_left: bool, dropping: bool):
"""Shift blocks for one side, dropping or rising."""
sign = -1 if dropping else 1
midpoint_y = self._table2_y
for block_obj in block_objs:
# Skip out-of-view or held
if state.get(block_obj, "pose_z") < 0 or \
self._held_obj_id == block_obj.id:
continue
by = state.get(block_obj, "pose_y")
belongs_to_side = (by < midpoint_y) if is_left else (by > midpoint_y)
if belongs_to_side:
old_z = state.get(block_obj, "pose_z")
padding = 0
new_z = old_z + (sign * shift_amount) + (sign * padding)
block_pos, block_orn = p.getBasePositionAndOrientation(
block_obj.id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
block_obj.id, [block_pos[0], block_pos[1], new_z],
block_orn, physicsClientId=self._physics_client_id)

def shift_plate(is_left: bool, dropping: bool):
"""Shift plate & beam, dropping or rising."""
sign = -1 if dropping else 1
if is_left:
plate_id, beam_id = self._plate1.id, self._beam_ids[0]
base_plate_z, base_beam_z = self._plate1_pose[2], self._beam1_pose[2]
else:
new_bz = old_bz + shift_amount
# logging.debug(f"Current holding block: {self._held_obj_id}, shifting block {block_obj.id}")
plate_id, beam_id = self._plate3.id, self._beam_ids[1]
base_plate_z, base_beam_z = self._plate3_pose[2], self._beam2_pose[2]

# Update in PyBullet
block_pos, block_orn = p.getBasePositionAndOrientation(
block_id, physicsClientId=self._physics_client_id)
new_block_pos = [block_pos[0], block_pos[1], new_bz]
p.resetBasePositionAndOrientation(
block_id,
new_block_pos,
block_orn,
physicsClientId=self._physics_client_id)

# Update plates
new_plate1_z = self._plate1_pose[2] - shift_amount
new_beam1_z = self._beam1_pose[2] - shift_amount
new_plate3_z = self._plate3_pose[2] + shift_amount
new_beam2_z = self._beam2_pose[2] + shift_amount

# Reset the base positions of the plates and beams
plate1_id = self._plate1.id
beam1_id = self._beam_ids[0]
plate3_id = self._plate3.id
beam2_id = self._beam_ids[1]
new_plate_z = base_plate_z + (sign * shift_amount)
new_beam_z = base_beam_z + (sign * shift_amount)

plate1_pos, plate1_orn = p.getBasePositionAndOrientation(
plate1_id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
plate1_id, [plate1_pos[0], plate1_pos[1], new_plate1_z],
plate1_orn,
physicsClientId=self._physics_client_id)

beam1_pos, beam1_orn = p.getBasePositionAndOrientation(
beam1_id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
beam1_id, [beam1_pos[0], beam1_pos[1], new_beam1_z],
beam1_orn,
physicsClientId=self._physics_client_id)
plate_pos, plate_orn = p.getBasePositionAndOrientation(
plate_id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
plate_id, [plate_pos[0], plate_pos[1], new_plate_z],
plate_orn, physicsClientId=self._physics_client_id)

plate3_pos, plate3_orn = p.getBasePositionAndOrientation(
plate3_id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
plate3_id, [plate3_pos[0], plate3_pos[1], new_plate3_z],
plate3_orn,
physicsClientId=self._physics_client_id)
beam_pos, beam_orn = p.getBasePositionAndOrientation(
beam_id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
beam_id, [beam_pos[0], beam_pos[1], new_beam_z],
beam_orn, physicsClientId=self._physics_client_id)

# Left side update
if left_dropping:
# Drop left plate
shift_plate(True, True)
# Drop left blocks
# Rise right blocks
shift_blocks(False, False)
# Rise right plate
shift_plate(False, False)
else:
shift_blocks(True, False)
shift_plate(True, False)
shift_plate(False, True)
shift_blocks(False, True)

beam2_pos, beam2_orn = p.getBasePositionAndOrientation(
beam2_id, physicsClientId=self._physics_client_id)
p.resetBasePositionAndOrientation(
beam2_id, [beam2_pos[0], beam2_pos[1], new_beam2_z],
beam2_orn,
physicsClientId=self._physics_client_id)
# Right side update

# Record the new difference
self._prev_diff = diff


Expand Down
4 changes: 2 additions & 2 deletions predicators/envs/pybullet_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def step(self, action: Action, render_obs: bool = False) -> Observation:
if self._held_constraint_id is None and self._fingers_closing(action):
# Detect if an object is held. If so, create a grasp constraint.
self._held_obj_id = self._detect_held_object()
logging.debug(f"Detected held object: {self._held_obj_id}")
# logging.debug(f"Detected held object: {self._held_obj_id}")
if self._held_obj_id is not None:
self._create_grasp_constraint()

Expand Down Expand Up @@ -496,7 +496,7 @@ def _detect_held_object(self) -> Optional[int]:
# A perfect score here is 1.0 (normals are unit vectors).
contact_normal = point[7]
score = expected_normal.dot(contact_normal)
logging.debug(f"With obj {obj_id}, score: {score}")
# logging.debug(f"With obj {obj_id}, score: {score}")
assert -1.0 <= score <= 1.0

# Take absolute as object/gripper could be rotated 180
Expand Down
5 changes: 2 additions & 3 deletions predicators/ground_truth_models/balance/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
"""Ground-truth models for Balance environment and variants."""

from .nsrts import BalanceGroundTruthNSRTFactory
from .options import BalanceGroundTruthOptionFactory, \
PyBulletBalanceGroundTruthOptionFactory
from .options import PyBulletBalanceGroundTruthOptionFactory

__all__ = [
"BalanceGroundTruthNSRTFactory", "BalanceGroundTruthOptionFactory",
"BalanceGroundTruthNSRTFactory",
"PyBulletBalanceGroundTruthOptionFactory"
]
Loading

0 comments on commit 74c2116

Please sign in to comment.