Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
yichao-liang committed Jan 11, 2025
1 parent 63d202b commit 17fe0ea
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 55 deletions.
71 changes: 37 additions & 34 deletions predicators/envs/pybullet_balance.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
"""A PyBullet version of Blocks."""

import logging
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, \
Union

import numpy as np
import pybullet as p
from PIL import Image

from predicators import utils
from predicators.envs.pybullet_env import PyBulletEnv, create_pybullet_block
from predicators.pybullet_helpers.geometry import Pose, Pose3D, Quaternion
from predicators.pybullet_helpers.robots import SingleArmPyBulletRobot
from predicators.settings import CFG
from predicators.structs import Action, Array, EnvironmentTask, NSPredicate, \
Object, Predicate, State, Type
Object, Predicate, State, Type, ConceptPredicate, GroundAtom
from predicators.utils import RawState, VLMQuery


Expand All @@ -24,54 +21,58 @@ class PyBulletBalanceEnv(PyBulletEnv):
# Parameters that aren't important enough to need to clog up settings.py

# Table parameters.
_plate_height: ClassVar[float] = 0.01
_table_height: ClassVar[float] = 0.4
_table2_pose: ClassVar[Pose3D] = (1.35, 0.75, _table_height/2)
_table_orientation: ClassVar[Quaternion] = (0., 0., 0., 1.)
_table_mid_w = 0.1
_table_side_w = 0.3
_table_gap = 0.05
_table_x, _table2_y, _table_z = 1.35, 0.75, _table_height
_table_mid_half_extents = [0.1, _table_mid_w / 2, _table_height / 2]

# Plate
_plate_height: ClassVar[float] = 0.01
_plate_z = _table_height - _plate_height
_plate3_pose: ClassVar[Pose3D] = (_table_x, _table2_y + _table_mid_w / 2 +
_table_side_w / 2 + _table_gap, _plate_z)
_plate1_pose: ClassVar[Pose3D] = (_table_x, _table2_y - _table_mid_w / 2 -
_table_side_w / 2 - _table_gap, _plate_z)
_plate1_pose: ClassVar[Pose3D] = (_table_x,
_table2_y - _table_mid_w / 2 - _table_side_w / 2 - _table_gap,
_plate_z)
_plate3_pose: ClassVar[Pose3D] = (_table_x,
_table2_y + _table_mid_w / 2 + _table_side_w / 2 + _table_gap,
_plate_z)
_plate_half_extents = (0.25, _table_side_w / 2, _plate_height)
# Under plate beams
_beam1_pose: ClassVar[Pose3D] = (_table_x,
_table2_y - _table_mid_w / 2 - _table_gap / 2,
_plate_z - 4 * _plate_height)
(_plate1_pose[1] + _table2_pose[1]) / 2,
_plate_z - 4 * _plate_height)
_beam2_pose: ClassVar[Pose3D] = (_table_x,
_table2_y + _table_mid_w / 2 + _table_gap / 2,
_plate_z - 4 * _plate_height)
_table_mid_half_extents = [0.1, _table_mid_w / 2,
_table_height / 2] # depth, w, h
_plate_half_extents = (0.25, _table_side_w / 2, _plate_height)
(_plate3_pose[1] + _table2_pose[1]) / 2,
_plate_z - 4 * _plate_height)
_beam_half_extents = [0.01, 0.15, _plate_height]

# Button
# Button on table
_button_radius = 0.04
_button_color_off = [1, 0, 0, 1]
_button_color_on = [0, 1, 0, 1]
button_x, button_y, button_z = _table_x, _table2_y, _table_z
button_press_threshold = 1e-3

# Workspace parameters
x_lb: ClassVar[float] = 1.325
x_ub: ClassVar[float] = 1.375
y_lb: ClassVar[float] = 0.4
y_ub: ClassVar[float] = 1.1
z_lb: ClassVar[float] = _table_height
z_ub: ClassVar[float] = 0.75 + _table_height/2
y_plate1_ub: ClassVar[float] = _plate1_pose[1] + _table_side_w / 2 - 0.1
y_plate3_lb: ClassVar[float] = _plate3_pose[1] - _table_side_w / 2 + 0.1
pick_z: ClassVar[float] = 0.7

# Robot parameters
robot_init_x: ClassVar[float] = (x_lb + x_ub) / 2
robot_init_y: ClassVar[float] = (y_lb + y_ub) / 2
robot_init_z: ClassVar[float] = pick_z
robot_init_z: ClassVar[float] = z_ub - 0.1
held_tol: ClassVar[float] = 0.5
pick_tol: ClassVar[float] = 0.0001
on_tol: ClassVar[float] = 0.01
collision_padding: ClassVar[float] = 2.0
max_position_vel: ClassVar[float] = 2.5
max_angular_vel: ClassVar[float] = np.pi / 4
max_finger_vel: ClassVar[float] = 1.0

_camera_target: ClassVar[Pose3D] = (1.65, 0.75, 0.52)

Expand All @@ -81,15 +82,6 @@ class PyBulletBalanceEnv(PyBulletEnv):
_num_blocks_test = CFG.balance_num_blocks_test

def __init__(self, use_gui: bool = True) -> None:
super().__init__(use_gui)

# Static objects (always exist no matter the settings).
self._robot = Object("robby", self._robot_type)
self._plate1 = Object("plate1", self._plate_type)
# self._table2 = Object("table2", self._plate_type)
self._plate3 = Object("plate3", self._plate_type)
self._machine = Object("mac", self._machine_type)

# Types
bbox_features = ["bbox_left", "bbox_right", "bbox_upper", "bbox_lower"]
self._block_type = Type("block", [
Expand All @@ -103,7 +95,18 @@ def __init__(self, use_gui: bool = True) -> None:
"plate",
["pose_z"] + (bbox_features if CFG.env_include_bbox_features else [])
)
# Predicates
self._machine_type = Type("machine", ["is_on"] + (bbox_features if
CFG.env_include_bbox_features else []))

# Static objects (always exist no matter the settings).
self._robot = Object("robby", self._robot_type)
self._plate1 = Object("plate1", self._plate_type)
# self._table2 = Object("table2", self._plate_type)
self._plate3 = Object("plate3", self._plate_type)
self._machine = Object("mac", self._machine_type)

super().__init__(use_gui)

# Predicates
self._DirectlyOn = Predicate(
"DirectlyOn", [self._block_type, self._block_type],
Expand Down Expand Up @@ -895,7 +898,7 @@ def _make_tasks(self, num_tasks: int, possible_num_blocks: List[int],
# break
# if idx == 0:
tasks.append(EnvironmentTask(init_state, goal))
return tasks
return self._add_pybullet_state_to_tasks(tasks)

def _sample_initial_piles(self, num_blocks: int,
rng: np.random.Generator) -> List[List[Object]]:
Expand Down
8 changes: 2 additions & 6 deletions predicators/ground_truth_models/balance/nsrts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np

from predicators.envs.balance import BalanceEnv
from predicators.envs.pybullet_balance import PyBulletBalanceEnv
from predicators.ground_truth_models import GroundTruthNSRTFactory
from predicators.structs import NSRT, Array, GroundAtom, LiftedAtom, Object, \
Expand All @@ -17,16 +16,13 @@ class BalanceGroundTruthNSRTFactory(GroundTruthNSRTFactory):

@classmethod
def get_env_names(cls) -> Set[str]:
return {"blocks", "pybullet_balance", "blocks_clear"}
return {"pybullet_balance"}

@staticmethod
def get_nsrts(env_name: str, types: Dict[str, Type],
predicates: Dict[str, Predicate],
options: Dict[str, ParameterizedOption]) -> Set[NSRT]:
if env_name == "pybullet_balance":
env_cls = PyBulletBalanceEnv
else:
env_cls = BalanceEnv
env_cls = PyBulletBalanceEnv

# Types
block_type = types["block"]
Expand Down
24 changes: 9 additions & 15 deletions predicators/ground_truth_models/balance/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from gym.spaces import Box

from predicators import utils
from predicators.envs.balance import BalanceEnv
from predicators.envs.pybullet_balance import PyBulletBalanceEnv
from predicators.ground_truth_models import GroundTruthOptionFactory
from predicators.pybullet_helpers.controllers import \
Expand All @@ -23,7 +22,7 @@
class BalanceGroundTruthOptionFactory(GroundTruthOptionFactory):
"""Ground-truth options for the (non-pybullet) blocks environment."""

env_cls = BalanceEnv
env_cls = PyBulletBalanceEnv

@classmethod
def get_env_names(cls) -> Set[str]:
Expand Down Expand Up @@ -208,7 +207,7 @@ def policy(state: State, memory: Dict, objects: Sequence[Object],
x_norm, y_norm = params
x = BalanceEnv.x_lb + (BalanceEnv.x_ub - BalanceEnv.x_lb) * x_norm
y = BalanceEnv.y_lb + (BalanceEnv.y_ub - BalanceEnv.y_lb) * y_norm
z = BalanceEnv.table_height + 0.5 * block_size
z = BalanceEnv._table_height + 0.5 * block_size
arr = np.array([x, y, z, 1.0], dtype=np.float32)
arr = np.clip(arr, action_space.low, action_space.high)
return Action(arr)
Expand Down Expand Up @@ -276,9 +275,6 @@ def get_options(cls, env_name: str, types: Dict[str, Type],
plate_type = types["plate"]
block_size = CFG.blocks_block_size

Holding = predicates['Holding']
On = predicates['DirectlyOn']
OnPlate = predicates['DirectlyOnPlate']
GripperOpen = predicates['GripperOpen']
MachineOn = predicates['MachineOn']
Balanced = predicates['Balanced'].untransformed_predicate
Expand Down Expand Up @@ -311,7 +307,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move to far above the block which we will grasp.
cls._create_blocks_move_to_above_block_option(
name="MoveEndEffectorToPreGrasp",
z_func=lambda _: PyBulletBalanceEnv.pick_z,
z_func=lambda _: cls.env_cls.z_ub,
finger_status="open",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand All @@ -332,7 +328,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move back up.
cls._create_blocks_move_to_above_block_option(
name="MoveEndEffectorBackUp",
z_func=lambda _: PyBulletBalanceEnv.pick_z,
z_func=lambda _: cls.env_cls.z_ub,
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand All @@ -350,7 +346,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move to above the block on which we will stack.
cls._create_blocks_move_to_above_block_option(
name="MoveEndEffectorToPreStack",
z_func=lambda _: PyBulletBalanceEnv.pick_z,
z_func=lambda _: cls.env_cls.z_ub,
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand All @@ -372,7 +368,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move back up.
cls._create_blocks_move_to_above_block_option(
name="MoveEndEffectorBackUp",
z_func=lambda _: PyBulletBalanceEnv.pick_z,
z_func=lambda _: cls.env_cls.z_ub,
finger_status="open",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand All @@ -384,23 +380,21 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# PutOnPlate
option_types = [robot_type, plate_type]
params_space = Box(0, 1, (2, ))
place_z = PyBulletBalanceEnv.table_height + \
block_size / 2 + cls._offset_z
PutOnPlate = utils.LinearChainParameterizedOption(
"PutOnPlate",
[
# Move to above the table at the (x, y) where we will place.
cls._create_blocks_move_to_above_table_option(
name="MoveEndEffectorToPrePutOnPlate",
z=PyBulletBalanceEnv.pick_z,
z=lambda _: cls.env_cls.z_ub,
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
params_space=params_space),
# Move down to place.
cls._create_blocks_move_to_above_table_option(
name="MoveEndEffectorToPutOnPlate",
z=place_z,
z=lambda _: cls.env_cls.z_ub - 0.2,
finger_status="closed",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand All @@ -413,7 +407,7 @@ def close_fingers_func(state: State, objects: Sequence[Object],
# Move back up.
cls._create_blocks_move_to_above_table_option(
name="MoveEndEffectorBackUp",
z=PyBulletBalanceEnv.pick_z,
z=lambda _: cls.env_cls.z_ub,
finger_status="open",
pybullet_robot=pybullet_robot,
option_types=option_types,
Expand Down

0 comments on commit 17fe0ea

Please sign in to comment.