a few fixes to saving and loading in active sampler learning (#1593)
tomsilver authored Dec 4, 2023
1 parent a5c87b2 commit ed233bc
Showing 4 changed files with 25 additions and 5 deletions.
20 changes: 20 additions & 0 deletions predicators/approaches/active_sampler_learning_approach.py
@@ -127,6 +127,7 @@ def load(self, online_learning_cycle: Optional[int]) -> None:
        save_path = utils.get_approach_load_path_str()
        with open(f"{save_path}_{online_learning_cycle}.DATA", "rb") as f:
            save_dict = pkl.load(f)
        self._dataset = save_dict["dataset"]
        self._sampler_data = save_dict["sampler_data"]
        self._ground_op_hist = save_dict["ground_op_hist"]
        self._competence_models = save_dict["competence_models"]
@@ -154,12 +155,31 @@ def _learn_nsrts(self, trajectories: List[LowLevelTrajectory],
        # Advance the competence models.
        for competence_model in self._competence_models.values():
            competence_model.advance_cycle()
        # Sanity check that the ground op histories and sampler data are in
        # sync.
        op_to_num_ground_op_hist: Dict[str, int] = {
            n.name: 0
            for n in self._sampler_data
        }
        for ground_op, examples in self._ground_op_hist.items():
            num = len(examples)
            name = ground_op.parent.name
            assert name in op_to_num_ground_op_hist
            op_to_num_ground_op_hist[name] += num
        for op, op_sampler_data in self._sampler_data.items():
            # The only case where there should be more sampler data than ground
            # op hist is if we started out with a nontrivial dataset. That
            # dataset is not included in the ground op hist.
            num_ground_op = op_to_num_ground_op_hist[op.name]
            num_sampler = len(op_sampler_data)
            assert num_ground_op == num_sampler or \
                (num_sampler > num_ground_op and CFG.max_initial_demos > 0)
        # Save the things we need other than the NSRTs, which were already
        # saved in the above call to self._learn_nsrts()
        save_path = utils.get_approach_save_path_str()
        with open(f"{save_path}_{online_learning_cycle}.DATA", "wb") as f:
            pkl.dump(
                {
                    "dataset": self._dataset,
                    "sampler_data": self._sampler_data,
                    "ground_op_hist": self._ground_op_hist,
                    "competence_models": self._competence_models,
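
The net effect of the changes in this file is a symmetric pickle round-trip: self._dataset is now both written by the save in _learn_nsrts and restored in load. A minimal standalone sketch of that save/load pattern, using hypothetical function names and a reduced set of fields rather than the repository's exact API:

    import pickle as pkl

    def save_learning_state(path, dataset, sampler_data):
        # Dump every field that the matching load will need to restore.
        with open(path, "wb") as f:
            pkl.dump({"dataset": dataset, "sampler_data": sampler_data}, f)

    def load_learning_state(path):
        # Restore exactly the fields that were dumped; a field saved but never
        # read back (or read back but never saved) is the kind of asymmetry
        # this commit fixes.
        with open(path, "rb") as f:
            save_dict = pkl.load(f)
        return save_dict["dataset"], save_dict["sampler_data"]

The field names mirror the ones in the diff above, but only for illustration.
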
6 changes: 3 additions & 3 deletions predicators/explorers/base_explorer.py
@@ -58,13 +58,13 @@ def wrapped_termination_fn(state: State) -> bool:
            if termination_fn(state):
                logging.info("[Base Explorer] terminating due to term fn")
                return True
            if remaining_steps <= 0:
                logging.info("[Base Explorer] terminating due to max steps")
                return True
            steps_taken = self._max_steps_before_termination - remaining_steps
            actual_remaining_steps = min(
                remaining_steps,
                CFG.max_num_steps_interaction_request - steps_taken)
            if actual_remaining_steps <= 0:
                logging.info("[Base Explorer] terminating due to max steps")
                return True
            logging.info(
                "[Base Explorer] not yet terminating (remaining steps: "
                f"{actual_remaining_steps})")
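
The explorer fix replaces the bare check on remaining_steps with a check on actual_remaining_steps, so termination respects both the explorer's own budget (self._max_steps_before_termination) and the per-request cap (CFG.max_num_steps_interaction_request). A self-contained sketch of that logic, with plain arguments standing in for the class attribute and config flag:

    def should_terminate(remaining_steps: int,
                         max_steps_before_termination: int,
                         max_num_steps_interaction_request: int) -> bool:
        # Steps already spent against the explorer's own budget.
        steps_taken = max_steps_before_termination - remaining_steps
        # Whichever limit is tighter governs termination.
        actual_remaining_steps = min(
            remaining_steps,
            max_num_steps_interaction_request - steps_taken)
        return actual_remaining_steps <= 0

For example, with remaining_steps=5, max_steps_before_termination=10, and a per-request cap of 5, the old check (remaining_steps <= 0) would keep exploring even though the request's cap is exhausted, while the new check terminates.
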
2 changes: 1 addition & 1 deletion predicators/main.py
@@ -277,7 +277,7 @@ def _generate_interaction_results(
            env,
            "train",
            request.train_task_idx,
            max_num_steps=CFG.max_num_steps_interaction_request,
            max_num_steps=(CFG.max_num_steps_interaction_request + 1),
            terminate_on_goal_reached=False,
            exceptions_to_break_on={
                utils.EnvironmentFailure,
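
The +1 here presumably gives the episode one extra unit of headroom: termination is now decided by the explorer's wrapped termination function, which is evaluated once per visited state, and an episode that takes N actions visits N + 1 states. A toy rollout loop illustrating that counting (assumed structure and names, not main.py's actual code):

    from typing import Callable, List, Tuple

    def run_episode(policy: Callable[[int], int],
                    step: Callable[[int, int], int],
                    termination_fn: Callable[[int], bool],
                    initial_state: int,
                    max_num_steps: int) -> Tuple[List[int], List[int]]:
        # The termination function is checked once per state, so a cap of
        # max_num_steps permits at most max_num_steps actions; passing the
        # cap + 1 lets the explorer's own check fire before the loop's does.
        state = initial_state
        states, actions = [state], []
        while not termination_fn(state) and len(actions) < max_num_steps:
            action = policy(state)
            state = step(state, action)
            states.append(state)
            actions.append(action)
        return states, actions
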
2 changes: 1 addition & 1 deletion tests/explorers/test_online_learning.py
@@ -62,7 +62,7 @@ def act_policy3(state):
        return [request1, request2, request3]

    def learn_from_interaction_results(self, results):
        max_steps = CFG.max_num_steps_interaction_request
        max_steps = CFG.max_num_steps_interaction_request + 1
        assert len(results) == 3
        result1, result2, result3 = results
        assert isinstance(result1, InteractionResult)
