Commit
Fix + raise awareness of subtle bugs with active sampler exploration (#1575)

* fix subtle bugs

* yapf
NishanthJKumar authored Oct 18, 2023
1 parent fba285d commit db29a6c
Showing 1 changed file with 20 additions and 0 deletions.
predicators/explorers/active_sampler_explorer.py: 20 additions & 0 deletions
@@ -97,6 +97,7 @@ def _get_exploration_strategy(self, train_task_idx: int,
         assigned_task_horizon = CFG.horizon
         current_policy: Optional[Callable[[State], _Option]] = None
         next_practice_nsrt: Optional[_GroundNSRT] = None
+        current_task_repeat_goal: Optional[Set[GroundAtom]] = None
         using_random = False
 
         def _option_policy(state: State) -> _Option:
@@ -162,6 +163,7 @@ def generate_goals() -> Iterator[Set[GroundAtom]]:
logging.info("[Explorer] Pursuing repeat task")

def generate_goals() -> Iterator[Set[GroundAtom]]:
nonlocal current_task_repeat_goal
# Loop through seen tasks in random order. Propose
# their initial abstract states and their goals until
# one is found that is not already achieved.
@@ -171,12 +173,23 @@ def generate_goals() -> Iterator[Set[GroundAtom]]:
                         task = self._train_tasks[train_task_idx]
                         # Can only practice the task if the objects match.
                         if set(task.init) == set(state):
+                            # If we've already been trying to achieve a
+                            # particular goal, then keep trying to achieve
+                            # it.
+                            if current_task_repeat_goal is not None:
+                                current_pursuit_goal_achieved = all(
+                                    a.holds(state)
+                                    for a in current_task_repeat_goal)
+                                if not current_pursuit_goal_achieved:
+                                    yield current_task_repeat_goal
+                            # Else, figure out the next goal to plan to!
                             possible_goals = [
                                 task.goal,
                                 utils.abstract(task.init, self._predicates)
                             ]
                             for goal in possible_goals:
                                 if any(not a.holds(state) for a in goal):
+                                    current_task_repeat_goal = goal
                                     yield goal
 
             # Otherwise, practice.
@@ -212,6 +225,13 @@ def generate_goals() -> Iterator[Set[GroundAtom]]:
                 # crash in case that assumption is not met.
                 except (PlanningFailure,
                         PlanningTimeout):  # pragma: no cover
+                    logging.info(
+                        "WARNING: Planning graph is not "
+                        "fully-connected! This violates a key "
+                        "assumption of our active sampler learning "
+                        "framework; ensure you DO NOT see this message "
+                        "if you're running experiments comparing "
+                        "different active sampler learning approaches.")
                     continue
                 logging.info("[Explorer] Plan found.")
                 break
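
For readers outside the codebase, here is a minimal, self-contained sketch of the pattern the main hunk introduces: a nested generator updates a nonlocal variable so that, once a repeat-task goal has been chosen, the same goal keeps being proposed until it actually holds in the current state (before this fix, a fresh goal could be derived on every replanning call). All names below (make_goal_generator, string-valued atoms, the frozenset State) are hypothetical simplifications for illustration, not the repo's actual GroundAtom/State/Task API.

# Minimal sketch (hypothetical names, not the real predicators API) of the
# goal-persistence fix: commit to one unachieved goal and keep re-proposing it.
from typing import FrozenSet, Iterator, List, Optional, Set

Atom = str  # stand-in for GroundAtom
State = FrozenSet[Atom]  # stand-in for State; "a.holds(state)" becomes "a in state"


def make_goal_generator(candidate_goals: List[Set[Atom]]):
    current_repeat_goal: Optional[Set[Atom]] = None

    def generate_goals(state: State) -> Iterator[Set[Atom]]:
        nonlocal current_repeat_goal
        # If we already committed to a goal and it is not yet achieved,
        # keep proposing it instead of silently switching goals.
        if current_repeat_goal is not None:
            achieved = all(a in state for a in current_repeat_goal)
            if not achieved:
                yield current_repeat_goal
        # Else (or additionally), propose any candidate goal that is not
        # already satisfied, and remember it for the next call.
        for goal in candidate_goals:
            if any(a not in state for a in goal):
                current_repeat_goal = goal
                yield goal

    return generate_goals


# Usage: repeated calls keep proposing the same unachieved goal.
gen = make_goal_generator([{"OnTable(block0)"}, {"Holding(block1)"}])
state: State = frozenset({"Holding(block1)"})
print(next(gen(state)))  # {'OnTable(block0)'}
print(next(gen(state)))  # still {'OnTable(block0)'}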
