diff --git a/predicators/approaches/bilevel_planning_approach.py b/predicators/approaches/bilevel_planning_approach.py index 66b701bf31..41d596a591 100644 --- a/predicators/approaches/bilevel_planning_approach.py +++ b/predicators/approaches/bilevel_planning_approach.py @@ -31,7 +31,8 @@ def __init__(self, train_tasks: List[Task], task_planning_heuristic: str = "default", max_skeletons_optimized: int = -1, - bilevel_plan_without_sim: Optional[bool] = None) -> None: + bilevel_plan_without_sim: Optional[bool] = None, + option_model: Optional[_OptionModelBase] = None) -> None: super().__init__(initial_predicates, initial_options, types, action_space, train_tasks) if task_planning_heuristic == "default": @@ -43,7 +44,9 @@ def __init__(self, self._task_planning_heuristic = task_planning_heuristic self._max_skeletons_optimized = max_skeletons_optimized self._plan_without_sim = bilevel_plan_without_sim - self._option_model = create_option_model(CFG.option_model_name) + if option_model is None: + option_model = create_option_model(CFG.option_model_name) + self._option_model = option_model self._num_calls = 0 self._last_plan: List[_Option] = [] # used if plan WITH sim self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim diff --git a/predicators/approaches/bridge_policy_approach.py b/predicators/approaches/bridge_policy_approach.py index d257bbddb3..46810f9110 100644 --- a/predicators/approaches/bridge_policy_approach.py +++ b/predicators/approaches/bridge_policy_approach.py @@ -51,8 +51,9 @@ from predicators.approaches.oracle_approach import OracleApproach from predicators.bridge_policies import BridgePolicyDone, create_bridge_policy from predicators.nsrt_learning.segmentation import segment_trajectory +from predicators.option_model import _OptionModelBase from predicators.settings import CFG -from predicators.structs import Action, BridgeDataset, DefaultState, \ +from predicators.structs import NSRT, Action, BridgeDataset, DefaultState, \ DemonstrationQuery, DemonstrationResponse, InteractionRequest, \ InteractionResult, ParameterizedOption, Predicate, Query, State, Task, \ Type, _Option @@ -69,10 +70,18 @@ def __init__(self, action_space: Box, train_tasks: List[Task], task_planning_heuristic: str = "default", - max_skeletons_optimized: int = -1) -> None: - super().__init__(initial_predicates, initial_options, types, - action_space, train_tasks, task_planning_heuristic, - max_skeletons_optimized) + max_skeletons_optimized: int = -1, + nsrts: Optional[Set[NSRT]] = None, + option_model: Optional[_OptionModelBase] = None) -> None: + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + nsrts=nsrts, + option_model=option_model) predicates = self._get_current_predicates() options = initial_options nsrts = self._get_current_nsrts() diff --git a/predicators/approaches/oracle_approach.py b/predicators/approaches/oracle_approach.py index c68f4faab5..ec2dc2907b 100644 --- a/predicators/approaches/oracle_approach.py +++ b/predicators/approaches/oracle_approach.py @@ -13,6 +13,7 @@ from predicators.approaches.bilevel_planning_approach import \ BilevelPlanningApproach from predicators.ground_truth_models import get_gt_nsrts +from predicators.option_model import _OptionModelBase from predicators.settings import CFG from predicators.structs import NSRT, ParameterizedOption, Predicate, Task, \ Type @@ -29,12 +30,22 @@ def __init__(self, train_tasks: List[Task], task_planning_heuristic: str = "default", max_skeletons_optimized: int = -1, - bilevel_plan_without_sim: Optional[bool] = None) -> None: - super().__init__(initial_predicates, initial_options, types, - action_space, train_tasks, task_planning_heuristic, - max_skeletons_optimized, bilevel_plan_without_sim) - self._nsrts = get_gt_nsrts(CFG.env, self._initial_predicates, - self._initial_options) + bilevel_plan_without_sim: Optional[bool] = None, + nsrts: Optional[Set[NSRT]] = None, + option_model: Optional[_OptionModelBase] = None) -> None: + super().__init__(initial_predicates, + initial_options, + types, + action_space, + train_tasks, + task_planning_heuristic, + max_skeletons_optimized, + bilevel_plan_without_sim, + option_model=option_model) + if nsrts is None: + nsrts = get_gt_nsrts(CFG.env, self._initial_predicates, + self._initial_options) + self._nsrts = nsrts @classmethod def get_name(cls) -> str: diff --git a/predicators/option_model.py b/predicators/option_model.py index 5e79f0d3d9..45d7ed244a 100644 --- a/predicators/option_model.py +++ b/predicators/option_model.py @@ -7,15 +7,16 @@ from __future__ import annotations import abc -from typing import Tuple +from typing import Callable, Set, Tuple import numpy as np from predicators import utils -from predicators.envs import BaseEnv, create_new_env +from predicators.envs import create_new_env from predicators.ground_truth_models import get_gt_options from predicators.settings import CFG -from predicators.structs import DefaultState, State, _Option +from predicators.structs import Action, DefaultState, ParameterizedOption, \ + State, _Option def create_option_model(name: str) -> _OptionModelBase: @@ -24,13 +25,15 @@ def create_option_model(name: str) -> _OptionModelBase: env = create_new_env(CFG.env, do_cache=False, use_gui=CFG.option_model_use_gui) - return _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + return _OracleOptionModel(options, env.simulate) if name.startswith("oracle"): env_name = name[name.index("_") + 1:] env = create_new_env(env_name, do_cache=False, use_gui=CFG.option_model_use_gui) - return _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + return _OracleOptionModel(options, env.simulate) raise NotImplementedError(f"Unknown option model: {name}") @@ -55,11 +58,11 @@ class _OracleOptionModel(_OptionModelBase): Runs options through this simulator to figure out the next state. """ - def __init__(self, env: BaseEnv) -> None: + def __init__(self, options: Set[ParameterizedOption], + simulator: Callable[[State, Action], State]) -> None: super().__init__() - gt_options = get_gt_options(env.get_name()) - self._name_to_parameterized_option = {o.name: o for o in gt_options} - self._simulator = env.simulate + self._name_to_parameterized_option = {o.name: o for o in options} + self._simulator = simulator def get_next_state_and_num_actions(self, state: State, option: _Option) -> Tuple[State, int]: diff --git a/tests/approaches/test_oracle_approach.py b/tests/approaches/test_oracle_approach.py index dccbbfcf69..13c8ec3454 100644 --- a/tests/approaches/test_oracle_approach.py +++ b/tests/approaches/test_oracle_approach.py @@ -4,6 +4,7 @@ import numpy as np import pytest +from gym.spaces import Box import predicators.envs.pddl_env from predicators import utils @@ -718,3 +719,91 @@ def test_playroom_get_gt_nsrts(): state, train_task.goal, rng) movedoortodoor_action = movedoortodoor_option.policy(state) assert env.action_space.contains(movedoortodoor_action.arr) + + +def test_external_oracle_approach(): + """Test that it's possible for an external user of predicators to define + their own environment and NSRTs and use the oracle approach.""" + + utils.reset_config({"num_train_tasks": 2, "num_test_tasks": 2}) + + class _ExternalBlocksEnv(BlocksEnv): + """To make sure that the test doesn't pass without using the new NSRTs, + reverse the action space.""" + + @classmethod + def get_name(cls) -> str: + return "external_blocks" + + @property + def action_space(self) -> Box: + original_space = super().action_space + return Box(original_space.low[::-1], + original_space.high[::-1], + dtype=np.float32) + + def simulate(self, state, action): + # Need to rewrite these lines here to avoid assertion in simulate + # that uses action_space. + x, y, z, fingers = action.arr[::-1] + # Infer which transition function to follow + if fingers < 0.5: + return self._transition_pick(state, x, y, z) + if z < self.table_height + self._block_size: + return self._transition_putontable(state, x, y, z) + return self._transition_stack(state, x, y, z) + + env = _ExternalBlocksEnv() + + # Create external options by modifying blocks options. + options = set() + old_option_to_new_option = {} + + def _reverse_policy(original_policy): + + def new_policy(state, memory, objects, params): + action = original_policy(state, memory, objects, params) + return Action(action.arr[::-1]) + + return new_policy + + original_options = get_gt_options("blocks") + for option in original_options: + new_policy = _reverse_policy(option.policy) + new_option = ParameterizedOption(f"external_{option.name}", + option.types, option.params_space, + new_policy, option.initiable, + option.terminal) + options.add(new_option) + old_option_to_new_option[option] = new_option + + # Create the option model. + option_model = _OracleOptionModel(options, env.simulate) + + # Create external NSRTs by just modifying blocks NSRTs. + nsrts = set() + for nsrt in get_gt_nsrts("blocks", env.predicates, original_options): + nsrt_option = old_option_to_new_option[nsrt.option] + sampler = nsrt._sampler # pylint: disable=protected-access + new_nsrt = NSRT(f"external_{nsrt.name}", nsrt.parameters, + nsrt.preconditions, nsrt.add_effects, + nsrt.delete_effects, nsrt.ignore_effects, nsrt_option, + nsrt.option_vars, sampler) + nsrts.add(new_nsrt) + + # Create oracle approach. + train_tasks = [t.task for t in env.get_train_tasks()] + approach = OracleApproach(env.predicates, + options, + env.types, + env.action_space, + train_tasks, + nsrts=nsrts, + option_model=option_model) + + # Get a policy for the first task. + task = train_tasks[0] + policy = approach.solve(task, timeout=500) + + # Verify the policy. + assert _policy_solves_task(policy, task, env.simulate) diff --git a/tests/explorers/test_active_sampler_explorer.py b/tests/explorers/test_active_sampler_explorer.py index a53f16aa29..91d20f7ec4 100644 --- a/tests/explorers/test_active_sampler_explorer.py +++ b/tests/explorers/test_active_sampler_explorer.py @@ -25,9 +25,9 @@ def test_active_sampler_explorer(): "sampler_learner": "oracle", }) env = RegionalBumpyCoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] ground_op_hist = {} competence_models = {} @@ -168,9 +168,9 @@ def test_active_sampler_explorer(): "active_sampler_explore_task_strategy": "success_rate", }) env = RegionalBumpyCoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] ground_op_hist = {} competence_models = {} diff --git a/tests/explorers/test_base_explorer.py b/tests/explorers/test_base_explorer.py index f627c68adf..305b69ffa8 100644 --- a/tests/explorers/test_base_explorer.py +++ b/tests/explorers/test_base_explorer.py @@ -12,9 +12,9 @@ def test_create_explorer(): """Tests for create_explorer.""" utils.reset_config({"env": "cover"}) env = CoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] # Greedy lookahead explorer. state_score_fn = lambda _1, _2: 0.0 diff --git a/tests/explorers/test_exploit_bilevel_planning_explorer.py b/tests/explorers/test_exploit_bilevel_planning_explorer.py index f91b0e3d80..234cf1314c 100644 --- a/tests/explorers/test_exploit_bilevel_planning_explorer.py +++ b/tests/explorers/test_exploit_bilevel_planning_explorer.py @@ -15,9 +15,9 @@ def test_exploit_bilevel_planning_explorer(): "explorer": "exploit_planning", }) env = CoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] explorer = create_explorer("exploit_planning", env.predicates, get_gt_options(env.get_name()), env.types, diff --git a/tests/explorers/test_glib_explorer.py b/tests/explorers/test_glib_explorer.py index 01affe6fc8..89a70d5072 100644 --- a/tests/explorers/test_glib_explorer.py +++ b/tests/explorers/test_glib_explorer.py @@ -17,9 +17,9 @@ def test_glib_explorer(target_predicate): "cover_initial_holding_prob": 0.0, }) env = CoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] # For testing purposes, score everything except target predicate low. score_fn = lambda atoms: target_predicate in str(atoms) @@ -85,9 +85,9 @@ def test_glib_explorer_failure_cases(): "explorer": "glib", }) env = CoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] score_fn = lambda _: 0.0 task_idx = 0 diff --git a/tests/explorers/test_greedy_lookahead_explorer.py b/tests/explorers/test_greedy_lookahead_explorer.py index 387234ce20..2c0340b9ee 100644 --- a/tests/explorers/test_greedy_lookahead_explorer.py +++ b/tests/explorers/test_greedy_lookahead_explorer.py @@ -18,9 +18,9 @@ def test_greedy_lookahead_explorer(target_predicate): "cover_initial_holding_prob": 0.0, }) env = CoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] # For testing purposes, score everything except target predicate low. score_fn = lambda atoms, _: target_predicate in str(atoms) @@ -63,9 +63,9 @@ def test_greedy_lookahead_explorer_failure_cases(): "explorer": "greedy_lookahead", }) env = CoverEnv() - nsrts = get_gt_nsrts(env.get_name(), env.predicates, - get_gt_options(env.get_name())) - option_model = _OracleOptionModel(env) + options = get_gt_options(env.get_name()) + nsrts = get_gt_nsrts(env.get_name(), env.predicates, options) + option_model = _OracleOptionModel(options, env.simulate) train_tasks = [t.task for t in env.get_train_tasks()] state_score_fn = lambda _1, _2: 0.0 task_idx = 0 diff --git a/tests/test_option_model.py b/tests/test_option_model.py index c8e168571c..60a436f4d8 100644 --- a/tests/test_option_model.py +++ b/tests/test_option_model.py @@ -53,9 +53,9 @@ def simulate(state, action): class _MockOracleOptionModel(_OracleOptionModel): - def __init__(self, env) -> None: # pylint: disable=super-init-not-called + def __init__(self, simulator) -> None: # pylint: disable=super-init-not-called self._name_to_parameterized_option = {"Pick": parameterized_option} - self._simulator = env.simulate + self._simulator = simulator # Mock option. parameterized_option = ParameterizedOption("Pick", [], params_space, @@ -76,7 +76,7 @@ def __init__(self, env) -> None: # pylint: disable=super-init-not-called obj4: [8, 9, 10], obj9: [11, 12, 13] }) - model = _MockOracleOptionModel(env) + model = _MockOracleOptionModel(env.simulate) next_state, num_act = model.get_next_state_and_num_actions(state, option1) assert num_act == 5 # Test that the option's memory has not been updated. @@ -143,15 +143,15 @@ def simulate(state, action): class _MockOracleOptionModel(_OracleOptionModel): - def __init__(self, env) -> None: # pylint: disable=super-init-not-called + def __init__(self, simulator) -> None: # pylint: disable=super-init-not-called self._name_to_parameterized_option = { "InfiniteLearnedOption": infinite_param_opt } - self._simulator = env.simulate + self._simulator = simulator infinite_option = infinite_param_opt.ground([], params1) env = _NoopMockEnv() - model = _MockOracleOptionModel(env) + model = _MockOracleOptionModel(env.simulate) next_state, num_act = model.get_next_state_and_num_actions( state, infinite_option) assert next_state.allclose(state) @@ -163,7 +163,7 @@ def __init__(self, env) -> None: # pylint: disable=super-init-not-called "max_num_steps_option_rollout": 5, }) - model = _MockOracleOptionModel(env) + model = _MockOracleOptionModel(env.simulate) _, num_act = model.get_next_state_and_num_actions(state, infinite_option) assert num_act == 5 diff --git a/tests/test_planning.py b/tests/test_planning.py index f3a4a3508f..3262d18d98 100644 --- a/tests/test_planning.py +++ b/tests/test_planning.py @@ -410,14 +410,8 @@ def simulate(state, action): next_state[robby][1] = 1 return next_state - class _MockOracleOptionModel(_OracleOptionModel): - - def __init__(self, env) -> None: # pylint: disable=super-init-not-called - self._name_to_parameterized_option = {o.name: o for o in options} - self._simulator = env.simulate - env = _MockEnv() - option_model = _MockOracleOptionModel(env) + option_model = _OracleOptionModel(options, env.simulate) # Check that sesame_plan is deterministic, over both NSRTs and objects. plan1 = [ (act.name, act.objects) for act in sesame_plan(