Skip to content

Commit

Permalink
Merge branch 'master' into clustering-reverse-engineering
Browse files Browse the repository at this point in the history
  • Loading branch information
ashayathalye authored Oct 31, 2023
2 parents 1d64baa + 00c7860 commit 215d307
Show file tree
Hide file tree
Showing 33 changed files with 2,007 additions and 140 deletions.
8 changes: 6 additions & 2 deletions predicators/approaches/bilevel_planning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(self,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1,
bilevel_plan_without_sim: Optional[bool] = None) -> None:
bilevel_plan_without_sim: Optional[bool] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks)
if task_planning_heuristic == "default":
Expand All @@ -43,7 +44,9 @@ def __init__(self,
self._task_planning_heuristic = task_planning_heuristic
self._max_skeletons_optimized = max_skeletons_optimized
self._plan_without_sim = bilevel_plan_without_sim
self._option_model = create_option_model(CFG.option_model_name)
if option_model is None:
option_model = create_option_model(CFG.option_model_name)
self._option_model = option_model
self._num_calls = 0
self._last_plan: List[_Option] = [] # used if plan WITH sim
self._last_nsrt_plan: List[_GroundNSRT] = [] # plan WITHOUT sim
Expand Down Expand Up @@ -131,6 +134,7 @@ def _run_task_plan(
timeout,
seed,
task_planning_heuristic=self._task_planning_heuristic,
max_horizon=float(CFG.horizon),
**kwargs)
except PlanningFailure as e:
raise ApproachFailure(e.args[0], e.info)
Expand Down
19 changes: 14 additions & 5 deletions predicators/approaches/bridge_policy_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@
from predicators.approaches.oracle_approach import OracleApproach
from predicators.bridge_policies import BridgePolicyDone, create_bridge_policy
from predicators.nsrt_learning.segmentation import segment_trajectory
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import Action, BridgeDataset, DefaultState, \
from predicators.structs import NSRT, Action, BridgeDataset, DefaultState, \
DemonstrationQuery, DemonstrationResponse, InteractionRequest, \
InteractionResult, ParameterizedOption, Predicate, Query, State, Task, \
Type, _Option
Expand All @@ -69,10 +70,18 @@ def __init__(self,
action_space: Box,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized)
max_skeletons_optimized: int = -1,
nsrts: Optional[Set[NSRT]] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized,
nsrts=nsrts,
option_model=option_model)
predicates = self._get_current_predicates()
options = initial_options
nsrts = self._get_current_nsrts()
Expand Down
8 changes: 8 additions & 0 deletions predicators/approaches/online_nsrt_learning_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,14 @@ def get_interaction_requests(self) -> List[InteractionRequest]:
# is randomly selected.
explorer = self._create_explorer()

# NOTE: this is admittedly awkward, but we have to reset the saved
# plan info here so that if the execution monitor is ever used while
# exploring and collecting more data, it doesn't mistakenly try to
# monitor execution against a previously saved plan.
self._last_nsrt_plan = []
self._last_atoms_seq = []
self._last_plan = []

# Create the interaction requests.
requests = []
for _ in range(CFG.online_nsrt_learning_requests_per_cycle):
Expand Down
23 changes: 17 additions & 6 deletions predicators/approaches/oracle_approach.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from predicators.approaches.bilevel_planning_approach import \
BilevelPlanningApproach
from predicators.ground_truth_models import get_gt_nsrts
from predicators.option_model import _OptionModelBase
from predicators.settings import CFG
from predicators.structs import NSRT, ParameterizedOption, Predicate, Task, \
Type
Expand All @@ -29,12 +30,22 @@ def __init__(self,
train_tasks: List[Task],
task_planning_heuristic: str = "default",
max_skeletons_optimized: int = -1,
bilevel_plan_without_sim: Optional[bool] = None) -> None:
super().__init__(initial_predicates, initial_options, types,
action_space, train_tasks, task_planning_heuristic,
max_skeletons_optimized, bilevel_plan_without_sim)
self._nsrts = get_gt_nsrts(CFG.env, self._initial_predicates,
self._initial_options)
bilevel_plan_without_sim: Optional[bool] = None,
nsrts: Optional[Set[NSRT]] = None,
option_model: Optional[_OptionModelBase] = None) -> None:
super().__init__(initial_predicates,
initial_options,
types,
action_space,
train_tasks,
task_planning_heuristic,
max_skeletons_optimized,
bilevel_plan_without_sim,
option_model=option_model)
if nsrts is None:
nsrts = get_gt_nsrts(CFG.env, self._initial_predicates,
self._initial_options)
self._nsrts = nsrts

@classmethod
def get_name(cls) -> str:
Expand Down
6 changes: 5 additions & 1 deletion predicators/cogman.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def step(self, observation: Observation) -> Optional[Action]:
self._exec_monitor.reset(task)
self._exec_monitor.update_approach_info(
self._approach.get_execution_monitoring_info())
assert not self._exec_monitor.step(state)
# The approach is only reset when the override policy is
# None, so the assertion below is only valid in that
# case.
if self._override_policy is None:
assert not self._exec_monitor.step(state)
assert self._current_policy is not None
act = self._current_policy(state)
self._exec_monitor.update_approach_info(
Expand Down
Loading

0 comments on commit 215d307

Please sign in to comment.