Skip to content

Commit

Permalink
allow third party users to define their own oracle NSRTs
Browse files Browse the repository at this point in the history
  • Loading branch information
tomsilver committed Oct 28, 2023
1 parent 3090591 commit d2c5ad2
Show file tree
Hide file tree
Showing 5 changed files with 137 additions and 22 deletions.
7 changes: 5 additions & 2 deletions predicators/approaches/bilevel_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(self,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1,
bilevel_plan_without_sim: Optional[bool] = None) -> None:
bilevel_plan_without_sim: Optional[bool] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks)
if task_planning_heuristic == "default":
Expand All @@ -43,7 +44,9 @@ def __init__(self,
self._task_planning_heuristic = task_planning_heuristic
self._max_skeletons_optimized = max_skeletons_optimized
self._plan_without_sim = bilevel_plan_without_sim
self._option_model = create_option_model(CFG.option_model_name)
if option_model is None:
option_model = create_option_model(CFG.option_model_name)
self._option_model = option_model
self._num_calls = 0
self._last_plan: List[_Option] = [] # used if plan WITH sim
self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim
Expand Down
19 changes: 14 additions & 5 deletions predicators/approaches/bridge_policy_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
from predicators.approaches.oracle_approach import OracleApproach
from predicators.bridge_policies import BridgePolicyDone, create_bridge_policy
from predicators.nsrt_learning.segmentation import segment_trajectory
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import Action, BridgeDataset, DefaultState, \
from predicators.structs import NSRT, Action, BridgeDataset, DefaultState, \
DemonstrationQuery, DemonstrationResponse, InteractionRequest, \
InteractionResult, ParameterizedOption, Predicate, Query, State, Task, \
Type, _Option
Expand All @@ -69,10 +70,18 @@ def __init__(self,
action_space: Box,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized)
max_skeletons_optimized: int = -1,
nsrts: Optional[Set[NSRT]] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized,
nsrts=nsrts,
option_model=option_model)
predicates = self._get_current_predicates()
options = initial_options
nsrts = self._get_current_nsrts()
Expand Down
23 changes: 17 additions & 6 deletions predicators/approaches/oracle_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from predicators.approaches.bilevel_planning_approach import \
BilevelPlanningApproach
from predicators.ground_truth_models import get_gt_nsrts
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import NSRT, ParameterizedOption, Predicate, Task, \
Type
Expand All @@ -29,12 +30,22 @@ def __init__(self,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1,
bilevel_plan_without_sim: Optional[bool] = None) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized, bilevel_plan_without_sim)
self._nsrts = get_gt_nsrts(CFG.env, self._initial_predicates,
self._initial_options)
bilevel_plan_without_sim: Optional[bool] = None,
nsrts: Optional[Set[NSRT]] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized,
bilevel_plan_without_sim,
option_model=option_model)
if nsrts is None:
nsrts = get_gt_nsrts(CFG.env, self._initial_predicates,
self._initial_options)
self._nsrts = nsrts

@classmethod
def get_name(cls) -> str:
Expand Down
21 changes: 12 additions & 9 deletions predicators/option_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from __future__ import annotations

import abc
from typing import Tuple
from typing import Callable, Set, Tuple

import numpy as np

from predicators import utils
from predicators.envs import BaseEnv, create_new_env
from predicators.envs import create_new_env
from predicators.ground_truth_models import get_gt_options
from predicators.settings import CFG
from predicators.structs import DefaultState, State, _Option
from predicators.structs import Action, DefaultState, ParameterizedOption, \
State, _Option


def create_option_model(name: str) -> _OptionModelBase:
Expand All @@ -24,13 +25,15 @@ def create_option_model(name: str) -> _OptionModelBase:
env = create_new_env(CFG.env,
do_cache=False,
use_gui=CFG.option_model_use_gui)
return _OracleOptionModel(env)
options = get_gt_options(env.get_name())
return _OracleOptionModel(options, env.simulate)
if name.startswith("oracle"):
env_name = name[name.index("_") + 1:]
env = create_new_env(env_name,
do_cache=False,
use_gui=CFG.option_model_use_gui)
return _OracleOptionModel(env)
options = get_gt_options(env.get_name())
return _OracleOptionModel(options, env.simulate)
raise NotImplementedError(f"Unknown option model: {name}")


Expand All @@ -55,11 +58,11 @@ class _OracleOptionModel(_OptionModelBase):
Runs options through this simulator to figure out the next state.
"""

def __init__(self, env: BaseEnv) -> None:
def __init__(self, options: Set[ParameterizedOption],
simulator: Callable[[State, Action], State]) -> None:
super().__init__()
gt_options = get_gt_options(env.get_name())
self._name_to_parameterized_option = {o.name: o for o in gt_options}
self._simulator = env.simulate
self._name_to_parameterized_option = {o.name: o for o in options}
self._simulator = simulator

def get_next_state_and_num_actions(self, state: State,
option: _Option) -> Tuple[State, int]:
Expand Down
89 changes: 89 additions & 0 deletions tests/approaches/test_oracle_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pytest
from gym.spaces import Box

import predicators.envs.pddl_env
from predicators import utils
Expand Down Expand Up @@ -718,3 +719,91 @@ def test_playroom_get_gt_nsrts():
state, train_task.goal, rng)
movedoortodoor_action = movedoortodoor_option.policy(state)
assert env.action_space.contains(movedoortodoor_action.arr)


def test_external_oracle_approach():
"""Test that it's possible for an external user of predicators to define
their own environment and NSRTs and use the oracle approach."""

utils.reset_config({"num_train_tasks": 2, "num_test_tasks": 2})

class _ExternalBlocksEnv(BlocksEnv):
"""To make sure that the test doesn't pass without using the new NSRTs,
reverse the action space."""

@classmethod
def get_name(cls) -> str:
return "external_blocks"

@property
def action_space(self) -> Box:
original_space = super().action_space
return Box(original_space.low[::-1],
original_space.high[::-1],
dtype=original_space.dtype)

def simulate(self, state, action):
# Need to rewrite these lines here to avoid assertion in simulate
# that uses action_space.
x, y, z, fingers = action.arr[::-1]
# Infer which transition function to follow
if fingers < 0.5:
return self._transition_pick(state, x, y, z)
if z < self.table_height + self._block_size:
return self._transition_putontable(state, x, y, z)
return self._transition_stack(state, x, y, z)

env = _ExternalBlocksEnv()

# Create external options by modifying blocks options.
options = set()
old_option_to_new_option = {}

def _reverse_policy(original_policy):

def new_policy(state, memory, objects, params):
action = original_policy(state, memory, objects, params)
return Action(action.arr[::-1])

return new_policy

original_options = get_gt_options("blocks")
for option in original_options:
new_policy = _reverse_policy(option.policy)
new_option = ParameterizedOption(f"external_{option.name}",
option.types, option.params_space,
new_policy, option.initiable,
option.terminal)
options.add(new_option)
old_option_to_new_option[option] = new_option

# Create the option model.
option_model = _OracleOptionModel(options, env.simulate)

# Create external NSRTs by just modifying blocks NSRTs.
nsrts = set()
for nsrt in get_gt_nsrts("blocks", env.predicates, original_options):
nsrt_option = old_option_to_new_option[nsrt.option]
sampler = nsrt._sampler # pylint: disable=protected-access
new_nsrt = NSRT(f"external_{nsrt.name}", nsrt.parameters,
nsrt.preconditions, nsrt.add_effects,
nsrt.delete_effects, nsrt.ignore_effects, nsrt_option,
nsrt.option_vars, sampler)
nsrts.add(new_nsrt)

# Create oracle approach.
train_tasks = [t.task for t in env.get_train_tasks()]
approach = OracleApproach(env.predicates,
options,
env.types,
env.action_space,
train_tasks,
nsrts=nsrts,
option_model=option_model)

# Get a policy for the first task.
task = train_tasks[0]
policy = approach.solve(task, timeout=500)

# Verify the policy.
assert _policy_solves_task(policy, task, env.simulate)

0 comments on commit d2c5ad2

Please sign in to comment.