diff --git a/predicators/explorers/active_sampler_explorer.py b/predicators/explorers/active_sampler_explorer.py index 6f23e7ebff..898f772a35 100644 --- a/predicators/explorers/active_sampler_explorer.py +++ b/predicators/explorers/active_sampler_explorer.py @@ -57,7 +57,10 @@ def __init__(self, predicates: Set[Predicate], self._last_executed_nsrt: Optional[_GroundNSRT] = None self._nsrt_to_explorer_sampler = nsrt_to_explorer_sampler self._seen_train_task_idxs = seen_train_task_idxs - self._task_plan_cache: Dict[_TaskID, List[_GroundSTRIPSOperator]] = {} + # If the plan is None, that means none can be found, e.g., due to + # timeouts or dead-ends. + self._task_plan_cache: Dict[ + _TaskID, Optional[List[_GroundSTRIPSOperator]]] = {} self._task_plan_calls_since_replan: Dict[_TaskID, int] = {} self._sorted_options = sorted(options, key=lambda o: o.name) @@ -410,17 +413,20 @@ def _score_ground_op_planning_progress( replan_task_ids = [("replan", i) for i in range(len(num_replan_tasks))] for task_id in train_task_ids + replan_task_ids: plan = self._get_task_plan_for_task(task_id, ground_op_costs) - task_plan_costs = [] - for op in plan: - op_cost = ground_op_costs.get(op, self._default_cost) - task_plan_costs.append(op_cost) - plan_costs.append(sum(task_plan_costs)) + # If no plan can be found for a task, the task is just ignored. + if plan is not None: + task_plan_costs = [] + for op in plan: + op_cost = ground_op_costs.get(op, self._default_cost) + task_plan_costs.append(op_cost) + plan_costs.append(sum(task_plan_costs)) return -sum(plan_costs) # higher scores are better def _get_task_plan_for_task( self, task_id: _TaskID, ground_op_costs: Dict[_GroundSTRIPSOperator, float] - ) -> List[_GroundSTRIPSOperator]: + ) -> Optional[List[_GroundSTRIPSOperator]]: + """Returns None if no task plan can be found.""" # Optimization: only re-plan at a certain frequency. replan_freq = CFG.active_sampler_explorer_replan_frequency if task_id not in self._task_plan_calls_since_replan or \ @@ -433,18 +439,22 @@ def _get_task_plan_for_task( "train": self._train_tasks, "replan": self._replanning_tasks, }[task_type][task_idx] - plan, _, _ = run_task_plan_once( - task, - self._nsrts, - self._predicates, - self._types, - timeout, - self._seed, - task_planning_heuristic=task_planning_heuristic, - ground_op_costs=ground_op_costs, - default_cost=self._default_cost, - max_horizon=np.inf) - self._task_plan_cache[task_id] = [n.op for n in plan] + try: + plan, _, _ = run_task_plan_once( + task, + self._nsrts, + self._predicates, + self._types, + timeout, + self._seed, + task_planning_heuristic=task_planning_heuristic, + ground_op_costs=ground_op_costs, + default_cost=self._default_cost, + max_horizon=np.inf) + self._task_plan_cache[task_id] = [n.op for n in plan] + except (PlanningFailure, PlanningTimeout): # pragma: no cover + logging.info("WARNING: task planning failed in the explorer.") + self._task_plan_cache[task_id] = None self._task_plan_calls_since_replan[task_id] += 1 return self._task_plan_cache[task_id]