Skip to content

Commit

Permalink
experimental setup
Browse files Browse the repository at this point in the history
  • Loading branch information
barci2 committed Nov 19, 2023
1 parent e6be2c0 commit 27515e9
Show file tree
Hide file tree
Showing 4 changed files with 719 additions and 0 deletions.
130 changes: 130 additions & 0 deletions experiments/nsrts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@

from typing import ClassVar, Dict, Sequence, Set, Type

import numpy as np
from experiments.shelves2d import Shelves2DEnv
from predicators.ground_truth_models import GroundTruthNSRTFactory
from predicators.structs import NSRT, Action, Array, GroundAtom, NSRTSampler, Object, ParameterizedOption, Predicate, State, Variable

__all__ = ["Shelves2DGroundTruthNSRTFactory"]

class Shelves2DGroundTruthNSRTFactory(GroundTruthNSRTFactory):
"""Ground-truth NSRTs for the blocks environment."""

margin: ClassVar[float] = 0.0001

@classmethod
def get_env_names(cls) -> Set[str]:
return {"shelves2d"}

@staticmethod
def get_nsrts(env_name: str, types: Dict[str, Type],
predicates: Dict[str, Predicate],
options: Dict[str, ParameterizedOption]) -> Set[NSRT]:

margin = Shelves2DGroundTruthNSRTFactory.margin

# Types
box_type = types["box"]
shelf_type = types["shelf"]
bundle_type = types["bundle"]
cover_type = types["cover"]

# Prediactes
GoesInto = predicates["GoesInto"]
CoverAtStart = predicates["CoverAtStart"]
Bundles = predicates["Bundles"]
In = predicates["In"]
CoverFor = predicates["CoverFor"]
CoversTop = predicates["CoversFront"]
CoversBottom = predicates["CoversBack"]

# Options
Move = options["Move"]

nsrts = set()

# MoveCoverToTop
cover = Variable("?cover", cover_type)
bundle = Variable("?bundle", bundle_type)
nsrts.add(NSRT(
"MoveCoverToTop",
[cover, bundle],
{CoverFor([cover, bundle])},
{CoversTop([cover, bundle])},
{CoverAtStart([cover])},
set(),
Move,
set(),
_MoveCover_sampler_helper(True, margin=margin)
))

# MoveCoverToBottom
cover = Variable("?cover", cover_type)
bundle = Variable("?bundle", bundle_type)
nsrts.add(NSRT(
"MoveCoverToBottom",
[cover, bundle],
{CoverFor([cover, bundle])},
{CoversBottom([cover, bundle])},
{CoverAtStart([cover])},
set(),
Move,
set(),
_MoveCover_sampler_helper(False, margin=margin)
))

# InsertBox
box = Variable("?box", box_type)
shelf = Variable("?shelf", shelf_type)
bundle = Variable("?bundle", bundle_type)
cover = Variable("?cover", cover_type)

def InsertBox_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, objects: Sequence[Object]) -> Array:
box, shelf, bundle, cover = objects

box_x, box_y, box_w, box_h = Shelves2DEnv.get_shape_data(state, box)
shelf_x, shelf_y, shelf_w, shelf_h = Shelves2DEnv.get_shape_data(state, shelf)

grasp_x, grasp_y = rng.uniform([box_x + margin, box_y + margin], [box_x + box_w - margin, box_y + box_h - margin])
dx = rng.uniform(shelf_x + margin - box_x, shelf_x + shelf_w - margin - box_w - box_x)

dy_range = [shelf_y + margin - box_h - box_y, shelf_y + shelf_h - margin - box_y]
if CoversBottom([cover, bundle]) in goal:
dy_range[0] = shelf_y + margin - box_y
if CoversTop([cover, bundle]) in goal:
dy_range[1] = shelf_y + shelf_h - margin - box_h - box_y

dy = rng.uniform(*dy_range)
action = np.array([grasp_x, grasp_y, dx, dy])
return action

nsrts.add(NSRT(
"InsertBox",
[box, shelf, bundle, cover],
{GoesInto([box, shelf]), Bundles([bundle, shelf]), CoverFor([cover, bundle]), CoverAtStart([cover])},
{In([box, shelf])},
set(),
set(),
Move,
set(),
InsertBox_sampler
))

return nsrts

def _MoveCover_sampler_helper(move_to_top: bool, margin = 0.0001) -> NSRTSampler:
def _MoveCover_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, objects: Sequence[Object]) -> Action:
cover, bundle = objects

cover_x, cover_y, cover_w, cover_h = Shelves2DEnv.get_shape_data(state, cover)
bundle_x, bundle_y, bundle_w, bundle_h = Shelves2DEnv.get_shape_data(state, bundle)

grasp_x, grasp_y = rng.uniform([cover_x, cover_y], [cover_x + cover_w, cover_y + cover_h])

dy = bundle_y - cover_y + (bundle_h + margin if move_to_top else -cover_h - margin)
dx = bundle_x - cover_x

action = np.array([grasp_x, grasp_y, dx, dy])
return action
return _MoveCover_sampler
52 changes: 52 additions & 0 deletions experiments/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Ground truth options for the shelves2d environment"""

import gym
from typing import Dict, Sequence, Set, Type, cast

import numpy as np
from experiments.shelves2d import Shelves2DEnv
from predicators.ground_truth_models import GroundTruthOptionFactory
from predicators.structs import Action, Array, Object, ParameterizedOption, Predicate, State
from predicators.utils import SingletonParameterizedOption

class Shelves2DGroundTruthOptionFactory(GroundTruthOptionFactory):
"""Ground truth options for the shelves2d environment"""

@classmethod
def get_env_names(cls) -> Set[str]:
return set(["shelves2d"])

@classmethod
def get_options(cls, env_name: str, types: Dict[str, Type],
predicates: Dict[str, Predicate],
action_space: gym.spaces.Box) -> Set[ParameterizedOption]:

return set([
ParameterizedOption(
"Move",
[],
action_space,
cls.move,
cls.initiable,
cls.terminal
)
])

@classmethod
def initiable(cls, state: State, data: Dict, objects: Sequence[Object], arr: Array) -> bool:
data["start_state"] = state
return True

@classmethod
def terminal(cls, state: State, data: Dict, objects: Sequence[Object], arr: Array) -> bool:
start_state = cast(State, data["start_state"])
assert state.data.keys() == start_state.data.keys()
for obj in state.data:
if not np.allclose(state.data[obj], start_state.data[obj]):
return True
return False

@classmethod
def move(cls, state: State, data: Dict, objects: Sequence[Object], arr: Array) -> Action:
return Action(arr)

Loading

0 comments on commit 27515e9

Please sign in to comment.