Skip to content

Commit

Permalink
Merge branch 'master' into split-fail-focus
Browse files Browse the repository at this point in the history
  • Loading branch information
NishanthJKumar authored Oct 30, 2023
2 parents 898b48e + 00c7860 commit 91f8352
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 60 deletions.
7 changes: 5 additions & 2 deletions predicators/approaches/bilevel_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(self,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1,
bilevel_plan_without_sim: Optional[bool] = None) -> None:
bilevel_plan_without_sim: Optional[bool] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks)
if task_planning_heuristic == "default":
Expand All @@ -43,7 +44,9 @@ def __init__(self,
self._task_planning_heuristic = task_planning_heuristic
self._max_skeletons_optimized = max_skeletons_optimized
self._plan_without_sim = bilevel_plan_without_sim
self._option_model = create_option_model(CFG.option_model_name)
if option_model is None:
option_model = create_option_model(CFG.option_model_name)
self._option_model = option_model
self._num_calls = 0
self._last_plan: List[_Option] = [] # used if plan WITH sim
self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim
Expand Down
19 changes: 14 additions & 5 deletions predicators/approaches/bridge_policy_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
from predicators.approaches.oracle_approach import OracleApproach
from predicators.bridge_policies import BridgePolicyDone, create_bridge_policy
from predicators.nsrt_learning.segmentation import segment_trajectory
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import Action, BridgeDataset, DefaultState, \
from predicators.structs import NSRT, Action, BridgeDataset, DefaultState, \
DemonstrationQuery, DemonstrationResponse, InteractionRequest, \
InteractionResult, ParameterizedOption, Predicate, Query, State, Task, \
Type, _Option
Expand All @@ -69,10 +70,18 @@ def __init__(self,
action_space: Box,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized)
max_skeletons_optimized: int = -1,
nsrts: Optional[Set[NSRT]] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized,
nsrts=nsrts,
option_model=option_model)
predicates = self._get_current_predicates()
options = initial_options
nsrts = self._get_current_nsrts()
Expand Down
23 changes: 17 additions & 6 deletions predicators/approaches/oracle_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from predicators.approaches.bilevel_planning_approach import \
BilevelPlanningApproach
from predicators.ground_truth_models import get_gt_nsrts
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import NSRT, ParameterizedOption, Predicate, Task, \
Type
Expand All @@ -29,12 +30,22 @@ def __init__(self,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1,
bilevel_plan_without_sim: Optional[bool] = None) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized, bilevel_plan_without_sim)
self._nsrts = get_gt_nsrts(CFG.env, self._initial_predicates,
self._initial_options)
bilevel_plan_without_sim: Optional[bool] = None,
nsrts: Optional[Set[NSRT]] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized,
bilevel_plan_without_sim,
option_model=option_model)
if nsrts is None:
nsrts = get_gt_nsrts(CFG.env, self._initial_predicates,
self._initial_options)
self._nsrts = nsrts

@classmethod
def get_name(cls) -> str:
Expand Down
21 changes: 12 additions & 9 deletions predicators/option_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@
from __future__ import annotations

import abc
from typing import Tuple
from typing import Callable, Set, Tuple

import numpy as np

from predicators import utils
from predicators.envs import BaseEnv, create_new_env
from predicators.envs import create_new_env
from predicators.ground_truth_models import get_gt_options
from predicators.settings import CFG
from predicators.structs import DefaultState, State, _Option
from predicators.structs import Action, DefaultState, ParameterizedOption, \
State, _Option


def create_option_model(name: str) -> _OptionModelBase:
Expand All @@ -24,13 +25,15 @@ def create_option_model(name: str) -> _OptionModelBase:
env = create_new_env(CFG.env,
do_cache=False,
use_gui=CFG.option_model_use_gui)
return _OracleOptionModel(env)
options = get_gt_options(env.get_name())
return _OracleOptionModel(options, env.simulate)
if name.startswith("oracle"):
env_name = name[name.index("_") + 1:]
env = create_new_env(env_name,
do_cache=False,
use_gui=CFG.option_model_use_gui)
return _OracleOptionModel(env)
options = get_gt_options(env.get_name())
return _OracleOptionModel(options, env.simulate)
raise NotImplementedError(f"Unknown option model: {name}")


Expand All @@ -55,11 +58,11 @@ class _OracleOptionModel(_OptionModelBase):
Runs options through this simulator to figure out the next state.
"""

def __init__(self, env: BaseEnv) -> None:
def __init__(self, options: Set[ParameterizedOption],
simulator: Callable[[State, Action], State]) -> None:
super().__init__()
gt_options = get_gt_options(env.get_name())
self._name_to_parameterized_option = {o.name: o for o in gt_options}
self._simulator = env.simulate
self._name_to_parameterized_option = {o.name: o for o in options}
self._simulator = simulator

def get_next_state_and_num_actions(self, state: State,
option: _Option) -> Tuple[State, int]:
Expand Down
89 changes: 89 additions & 0 deletions tests/approaches/test_oracle_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pytest
from gym.spaces import Box

import predicators.envs.pddl_env
from predicators import utils
Expand Down Expand Up @@ -718,3 +719,91 @@ def test_playroom_get_gt_nsrts():
state, train_task.goal, rng)
movedoortodoor_action = movedoortodoor_option.policy(state)
assert env.action_space.contains(movedoortodoor_action.arr)


def test_external_oracle_approach():
    """Check that the oracle approach works with externally supplied models.

    An outside user of predicators should be able to bring their own
    environment, options, NSRTs and option model, hand them to
    OracleApproach, and get back a policy that solves a task.
    """

    utils.reset_config({"num_train_tasks": 2, "num_test_tasks": 2})

    class _ExternalBlocksEnv(BlocksEnv):
        """Blocks environment whose action vector is reversed.

        Reversing the action space ensures the test cannot pass unless
        the externally defined options and NSRTs are really being used.
        """

        @classmethod
        def get_name(cls) -> str:
            return "external_blocks"

        @property
        def action_space(self) -> Box:
            base_space = super().action_space
            flipped_low = base_space.low[::-1]
            flipped_high = base_space.high[::-1]
            return Box(flipped_low, flipped_high, dtype=np.float32)

        def simulate(self, state, action):
            # The parent's simulate() asserts against action_space, so the
            # dispatch logic is restated here on the un-reversed action.
            x, y, z, fingers = action.arr[::-1]
            # Infer which transition function to follow.
            if fingers < 0.5:
                transition = self._transition_pick
            elif z < self.table_height + self._block_size:
                transition = self._transition_putontable
            else:
                transition = self._transition_stack
            return transition(state, x, y, z)

    env = _ExternalBlocksEnv()

    def _reverse_policy(original_policy):
        """Wrap a policy so that the action it emits is reversed."""

        def new_policy(state, memory, objects, params):
            original_action = original_policy(state, memory, objects, params)
            return Action(original_action.arr[::-1])

        return new_policy

    # Build the external options by wrapping each ground-truth blocks
    # option, remembering which original each new option came from.
    original_options = get_gt_options("blocks")
    old_option_to_new_option = {
        option: ParameterizedOption(f"external_{option.name}", option.types,
                                    option.params_space,
                                    _reverse_policy(option.policy),
                                    option.initiable, option.terminal)
        for option in original_options
    }
    options = set(old_option_to_new_option.values())

    # The option model simulates the external options in the external env.
    option_model = _OracleOptionModel(options, env.simulate)

    # Build the external NSRTs: copies of the blocks NSRTs that point at
    # the corresponding external option.
    nsrts = {
        NSRT(
            f"external_{nsrt.name}",
            nsrt.parameters,
            nsrt.preconditions,
            nsrt.add_effects,
            nsrt.delete_effects,
            nsrt.ignore_effects,
            old_option_to_new_option[nsrt.option],
            nsrt.option_vars,
            nsrt._sampler,  # pylint: disable=protected-access
        )
        for nsrt in get_gt_nsrts("blocks", env.predicates, original_options)
    }

    # Construct the oracle approach entirely from the external pieces.
    train_tasks = [env_task.task for env_task in env.get_train_tasks()]
    approach = OracleApproach(env.predicates,
                              options,
                              env.types,
                              env.action_space,
                              train_tasks,
                              nsrts=nsrts,
                              option_model=option_model)

    # Plan for the first training task and verify the resulting policy.
    first_task = train_tasks[0]
    policy = approach.solve(first_task, timeout=500)
    assert _policy_solves_task(policy, first_task, env.simulate)
12 changes: 6 additions & 6 deletions tests/explorers/test_active_sampler_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ def test_active_sampler_explorer():
"sampler_learner": "oracle",
})
env = RegionalBumpyCoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
ground_op_hist = {}
competence_models = {}
Expand Down Expand Up @@ -168,9 +168,9 @@ def test_active_sampler_explorer():
"active_sampler_explore_task_strategy": "success_rate",
})
env = RegionalBumpyCoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
ground_op_hist = {}
competence_models = {}
Expand Down
6 changes: 3 additions & 3 deletions tests/explorers/test_base_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ def test_create_explorer():
"""Tests for create_explorer."""
utils.reset_config({"env": "cover"})
env = CoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
# Greedy lookahead explorer.
state_score_fn = lambda _1, _2: 0.0
Expand Down
6 changes: 3 additions & 3 deletions tests/explorers/test_exploit_bilevel_planning_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ def test_exploit_bilevel_planning_explorer():
"explorer": "exploit_planning",
})
env = CoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
explorer = create_explorer("exploit_planning", env.predicates,
get_gt_options(env.get_name()), env.types,
Expand Down
12 changes: 6 additions & 6 deletions tests/explorers/test_glib_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ def test_glib_explorer(target_predicate):
"cover_initial_holding_prob": 0.0,
})
env = CoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
# For testing purposes, score everything except target predicate low.
score_fn = lambda atoms: target_predicate in str(atoms)
Expand Down Expand Up @@ -85,9 +85,9 @@ def test_glib_explorer_failure_cases():
"explorer": "glib",
})
env = CoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
score_fn = lambda _: 0.0
task_idx = 0
Expand Down
12 changes: 6 additions & 6 deletions tests/explorers/test_greedy_lookahead_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ def test_greedy_lookahead_explorer(target_predicate):
"cover_initial_holding_prob": 0.0,
})
env = CoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
# For testing purposes, score everything except target predicate low.
score_fn = lambda atoms, _: target_predicate in str(atoms)
Expand Down Expand Up @@ -63,9 +63,9 @@ def test_greedy_lookahead_explorer_failure_cases():
"explorer": "greedy_lookahead",
})
env = CoverEnv()
nsrts = get_gt_nsrts(env.get_name(), env.predicates,
get_gt_options(env.get_name()))
option_model = _OracleOptionModel(env)
options = get_gt_options(env.get_name())
nsrts = get_gt_nsrts(env.get_name(), env.predicates, options)
option_model = _OracleOptionModel(options, env.simulate)
train_tasks = [t.task for t in env.get_train_tasks()]
state_score_fn = lambda _1, _2: 0.0
task_idx = 0
Expand Down
14 changes: 7 additions & 7 deletions tests/test_option_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ def simulate(state, action):

class _MockOracleOptionModel(_OracleOptionModel):

def __init__(self, env) -> None: # pylint: disable=super-init-not-called
def __init__(self, simulator) -> None: # pylint: disable=super-init-not-called
self._name_to_parameterized_option = {"Pick": parameterized_option}
self._simulator = env.simulate
self._simulator = simulator

# Mock option.
parameterized_option = ParameterizedOption("Pick", [], params_space,
Expand All @@ -76,7 +76,7 @@ def __init__(self, env) -> None: # pylint: disable=super-init-not-called
obj4: [8, 9, 10],
obj9: [11, 12, 13]
})
model = _MockOracleOptionModel(env)
model = _MockOracleOptionModel(env.simulate)
next_state, num_act = model.get_next_state_and_num_actions(state, option1)
assert num_act == 5
# Test that the option's memory has not been updated.
Expand Down Expand Up @@ -143,15 +143,15 @@ def simulate(state, action):

class _MockOracleOptionModel(_OracleOptionModel):

def __init__(self, env) -> None: # pylint: disable=super-init-not-called
def __init__(self, simulator) -> None: # pylint: disable=super-init-not-called
self._name_to_parameterized_option = {
"InfiniteLearnedOption": infinite_param_opt
}
self._simulator = env.simulate
self._simulator = simulator

infinite_option = infinite_param_opt.ground([], params1)
env = _NoopMockEnv()
model = _MockOracleOptionModel(env)
model = _MockOracleOptionModel(env.simulate)
next_state, num_act = model.get_next_state_and_num_actions(
state, infinite_option)
assert next_state.allclose(state)
Expand All @@ -163,7 +163,7 @@ def __init__(self, env) -> None: # pylint: disable=super-init-not-called
"max_num_steps_option_rollout": 5,
})

model = _MockOracleOptionModel(env)
model = _MockOracleOptionModel(env.simulate)
_, num_act = model.get_next_state_and_num_actions(state, infinite_option)
assert num_act == 5

Expand Down
Loading

0 comments on commit 91f8352

Please sign in to comment.