diff --git a/predicators/approaches/active_sampler_learning_approach.py b/predicators/approaches/active_sampler_learning_approach.py
index 1c0f711e15..5819a546ac 100644
--- a/predicators/approaches/active_sampler_learning_approach.py
+++ b/predicators/approaches/active_sampler_learning_approach.py
@@ -127,6 +127,7 @@ def load(self, online_learning_cycle: Optional[int]) -> None:
         save_path = utils.get_approach_load_path_str()
         with open(f"{save_path}_{online_learning_cycle}.DATA", "rb") as f:
             save_dict = pkl.load(f)
+        self._dataset = save_dict["dataset"]
         self._sampler_data = save_dict["sampler_data"]
         self._ground_op_hist = save_dict["ground_op_hist"]
         self._competence_models = save_dict["competence_models"]
@@ -154,12 +155,31 @@ def _learn_nsrts(self, trajectories: List[LowLevelTrajectory],
         # Advance the competence models.
         for competence_model in self._competence_models.values():
             competence_model.advance_cycle()
+        # Sanity check that the ground op histories and sampler data are in sync.
+        op_to_num_ground_op_hist: Dict[str, int] = {
+            n.name: 0
+            for n in self._sampler_data
+        }
+        for ground_op, examples in self._ground_op_hist.items():
+            num = len(examples)
+            name = ground_op.parent.name
+            assert name in op_to_num_ground_op_hist
+            op_to_num_ground_op_hist[name] += num
+        for op, op_sampler_data in self._sampler_data.items():
+            # The only case where there should be more sampler data than ground
+            # op hist is if we started out with a nontrivial dataset. That
+            # dataset is not included in the ground op hist.
+            num_ground_op = op_to_num_ground_op_hist[op.name]
+            num_sampler = len(op_sampler_data)
+            assert num_ground_op == num_sampler or \
+                (num_sampler > num_ground_op and CFG.max_initial_demos > 0)
         # Save the things we need other than the NSRTs, which were already
         # saved in the above call to self._learn_nsrts()
         save_path = utils.get_approach_save_path_str()
         with open(f"{save_path}_{online_learning_cycle}.DATA", "wb") as f:
             pkl.dump(
                 {
+                    "dataset": self._dataset,
                     "sampler_data": self._sampler_data,
                     "ground_op_hist": self._ground_op_hist,
                     "competence_models": self._competence_models,
diff --git a/predicators/explorers/base_explorer.py b/predicators/explorers/base_explorer.py
index 0cddc746cf..2efd7d9e4b 100644
--- a/predicators/explorers/base_explorer.py
+++ b/predicators/explorers/base_explorer.py
@@ -58,13 +58,13 @@ def wrapped_termination_fn(state: State) -> bool:
             if termination_fn(state):
                 logging.info("[Base Explorer] terminating due to term fn")
                 return True
-            if remaining_steps <= 0:
-                logging.info("[Base Explorer] terminating due to max steps")
-                return True
             steps_taken = self._max_steps_before_termination - remaining_steps
             actual_remaining_steps = min(
                 remaining_steps,
                 CFG.max_num_steps_interaction_request - steps_taken)
+            if actual_remaining_steps <= 0:
+                logging.info("[Base Explorer] terminating due to max steps")
+                return True
             logging.info(
                 "[Base Explorer] not yet terminating (remaining steps: "
                 f"{actual_remaining_steps})")
diff --git a/predicators/main.py b/predicators/main.py
index 91620dc73c..36739cb809 100644
--- a/predicators/main.py
+++ b/predicators/main.py
@@ -277,7 +277,7 @@ def _generate_interaction_results(
             env,
             "train",
             request.train_task_idx,
-            max_num_steps=CFG.max_num_steps_interaction_request,
+            max_num_steps=(CFG.max_num_steps_interaction_request + 1),
             terminate_on_goal_reached=False,
             exceptions_to_break_on={
                 utils.EnvironmentFailure,
diff --git a/tests/explorers/test_online_learning.py b/tests/explorers/test_online_learning.py
index 4bae62b2be..be8db3a5ea 100644
--- a/tests/explorers/test_online_learning.py
+++ b/tests/explorers/test_online_learning.py
@@ -62,7 +62,7 @@ def act_policy3(state):
         return [request1, request2, request3]
 
     def learn_from_interaction_results(self, results):
-        max_steps = CFG.max_num_steps_interaction_request
+        max_steps = CFG.max_num_steps_interaction_request + 1
         assert len(results) == 3
         result1, result2, result3 = results
         assert isinstance(result1, InteractionResult)