-
Notifications
You must be signed in to change notification settings - Fork 19
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
719 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
|
||
from typing import ClassVar, Dict, Sequence, Set, Type | ||
|
||
import numpy as np | ||
from experiments.shelves2d import Shelves2DEnv | ||
from predicators.ground_truth_models import GroundTruthNSRTFactory | ||
from predicators.structs import NSRT, Action, Array, GroundAtom, NSRTSampler, Object, ParameterizedOption, Predicate, State, Variable | ||
|
||
__all__ = ["Shelves2DGroundTruthNSRTFactory"] | ||
|
||
class Shelves2DGroundTruthNSRTFactory(GroundTruthNSRTFactory): | ||
"""Ground-truth NSRTs for the blocks environment.""" | ||
|
||
margin: ClassVar[float] = 0.0001 | ||
|
||
@classmethod | ||
def get_env_names(cls) -> Set[str]: | ||
return {"shelves2d"} | ||
|
||
@staticmethod | ||
def get_nsrts(env_name: str, types: Dict[str, Type], | ||
predicates: Dict[str, Predicate], | ||
options: Dict[str, ParameterizedOption]) -> Set[NSRT]: | ||
|
||
margin = Shelves2DGroundTruthNSRTFactory.margin | ||
|
||
# Types | ||
box_type = types["box"] | ||
shelf_type = types["shelf"] | ||
bundle_type = types["bundle"] | ||
cover_type = types["cover"] | ||
|
||
# Prediactes | ||
GoesInto = predicates["GoesInto"] | ||
CoverAtStart = predicates["CoverAtStart"] | ||
Bundles = predicates["Bundles"] | ||
In = predicates["In"] | ||
CoverFor = predicates["CoverFor"] | ||
CoversTop = predicates["CoversFront"] | ||
CoversBottom = predicates["CoversBack"] | ||
|
||
# Options | ||
Move = options["Move"] | ||
|
||
nsrts = set() | ||
|
||
# MoveCoverToTop | ||
cover = Variable("?cover", cover_type) | ||
bundle = Variable("?bundle", bundle_type) | ||
nsrts.add(NSRT( | ||
"MoveCoverToTop", | ||
[cover, bundle], | ||
{CoverFor([cover, bundle])}, | ||
{CoversTop([cover, bundle])}, | ||
{CoverAtStart([cover])}, | ||
set(), | ||
Move, | ||
set(), | ||
_MoveCover_sampler_helper(True, margin=margin) | ||
)) | ||
|
||
# MoveCoverToBottom | ||
cover = Variable("?cover", cover_type) | ||
bundle = Variable("?bundle", bundle_type) | ||
nsrts.add(NSRT( | ||
"MoveCoverToBottom", | ||
[cover, bundle], | ||
{CoverFor([cover, bundle])}, | ||
{CoversBottom([cover, bundle])}, | ||
{CoverAtStart([cover])}, | ||
set(), | ||
Move, | ||
set(), | ||
_MoveCover_sampler_helper(False, margin=margin) | ||
)) | ||
|
||
# InsertBox | ||
box = Variable("?box", box_type) | ||
shelf = Variable("?shelf", shelf_type) | ||
bundle = Variable("?bundle", bundle_type) | ||
cover = Variable("?cover", cover_type) | ||
|
||
def InsertBox_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, objects: Sequence[Object]) -> Array: | ||
box, shelf, bundle, cover = objects | ||
|
||
box_x, box_y, box_w, box_h = Shelves2DEnv.get_shape_data(state, box) | ||
shelf_x, shelf_y, shelf_w, shelf_h = Shelves2DEnv.get_shape_data(state, shelf) | ||
|
||
grasp_x, grasp_y = rng.uniform([box_x + margin, box_y + margin], [box_x + box_w - margin, box_y + box_h - margin]) | ||
dx = rng.uniform(shelf_x + margin - box_x, shelf_x + shelf_w - margin - box_w - box_x) | ||
|
||
dy_range = [shelf_y + margin - box_h - box_y, shelf_y + shelf_h - margin - box_y] | ||
if CoversBottom([cover, bundle]) in goal: | ||
dy_range[0] = shelf_y + margin - box_y | ||
if CoversTop([cover, bundle]) in goal: | ||
dy_range[1] = shelf_y + shelf_h - margin - box_h - box_y | ||
|
||
dy = rng.uniform(*dy_range) | ||
action = np.array([grasp_x, grasp_y, dx, dy]) | ||
return action | ||
|
||
nsrts.add(NSRT( | ||
"InsertBox", | ||
[box, shelf, bundle, cover], | ||
{GoesInto([box, shelf]), Bundles([bundle, shelf]), CoverFor([cover, bundle]), CoverAtStart([cover])}, | ||
{In([box, shelf])}, | ||
set(), | ||
set(), | ||
Move, | ||
set(), | ||
InsertBox_sampler | ||
)) | ||
|
||
return nsrts | ||
|
||
def _MoveCover_sampler_helper(move_to_top: bool, margin = 0.0001) -> NSRTSampler: | ||
def _MoveCover_sampler(state: State, goal: Set[GroundAtom], rng: np.random.Generator, objects: Sequence[Object]) -> Action: | ||
cover, bundle = objects | ||
|
||
cover_x, cover_y, cover_w, cover_h = Shelves2DEnv.get_shape_data(state, cover) | ||
bundle_x, bundle_y, bundle_w, bundle_h = Shelves2DEnv.get_shape_data(state, bundle) | ||
|
||
grasp_x, grasp_y = rng.uniform([cover_x, cover_y], [cover_x + cover_w, cover_y + cover_h]) | ||
|
||
dy = bundle_y - cover_y + (bundle_h + margin if move_to_top else -cover_h - margin) | ||
dx = bundle_x - cover_x | ||
|
||
action = np.array([grasp_x, grasp_y, dx, dy]) | ||
return action | ||
return _MoveCover_sampler |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
"""Ground truth options for the shelves2d environment""" | ||
|
||
import gym | ||
from typing import Dict, Sequence, Set, Type, cast | ||
|
||
import numpy as np | ||
from experiments.shelves2d import Shelves2DEnv | ||
from predicators.ground_truth_models import GroundTruthOptionFactory | ||
from predicators.structs import Action, Array, Object, ParameterizedOption, Predicate, State | ||
from predicators.utils import SingletonParameterizedOption | ||
|
||
class Shelves2DGroundTruthOptionFactory(GroundTruthOptionFactory): | ||
"""Ground truth options for the shelves2d environment""" | ||
|
||
@classmethod | ||
def get_env_names(cls) -> Set[str]: | ||
return set(["shelves2d"]) | ||
|
||
@classmethod | ||
def get_options(cls, env_name: str, types: Dict[str, Type], | ||
predicates: Dict[str, Predicate], | ||
action_space: gym.spaces.Box) -> Set[ParameterizedOption]: | ||
|
||
return set([ | ||
ParameterizedOption( | ||
"Move", | ||
[], | ||
action_space, | ||
cls.move, | ||
cls.initiable, | ||
cls.terminal | ||
) | ||
]) | ||
|
||
@classmethod | ||
def initiable(cls, state: State, data: Dict, objects: Sequence[Object], arr: Array) -> bool: | ||
data["start_state"] = state | ||
return True | ||
|
||
@classmethod | ||
def terminal(cls, state: State, data: Dict, objects: Sequence[Object], arr: Array) -> bool: | ||
start_state = cast(State, data["start_state"]) | ||
assert state.data.keys() == start_state.data.keys() | ||
for obj in state.data: | ||
if not np.allclose(state.data[obj], start_state.data[obj]): | ||
return True | ||
return False | ||
|
||
@classmethod | ||
def move(cls, state: State, data: Dict, objects: Sequence[Object], arr: Array) -> Action: | ||
return Action(arr) | ||
|
Oops, something went wrong.