diff --git a/predicators/approaches/active_sampler_learning_approach.py b/predicators/approaches/active_sampler_learning_approach.py index 5819a546ac..abaa0769d5 100644 --- a/predicators/approaches/active_sampler_learning_approach.py +++ b/predicators/approaches/active_sampler_learning_approach.py @@ -136,7 +136,7 @@ def load(self, online_learning_cycle: Optional[int]) -> None: self._nsrt_to_explorer_sampler = save_dict["nsrt_to_explorer_sampler"] self._seen_train_task_idxs = save_dict["seen_train_task_idxs"] self._train_tasks = save_dict["train_tasks"] - self._online_learning_cycle = CFG.skip_until_cycle + 1 + self._online_learning_cycle = save_dict["online_learning_cycle"] def _learn_nsrts(self, trajectories: List[LowLevelTrajectory], online_learning_cycle: Optional[int], @@ -192,6 +192,7 @@ def _learn_nsrts(self, trajectories: List[LowLevelTrajectory], # generated before reset with default init states, which # are subsequently overwritten after reset is called. "train_tasks": self._train_tasks, + "online_learning_cycle": self._online_learning_cycle, }, f)