Fix log_interval handling in OffPolicyAlgorithm (DLR-RM#1708)
tobiabir committed Oct 6, 2023
1 parent c6c660e commit 8712997
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion stable_baselines3/common/base_class.py
@@ -523,7 +523,7 @@ def learn(
         :param total_timesteps: The total number of samples (env steps) to train on
         :param callback: callback(s) called at every step with state of the algorithm.
-        :param log_interval: The number of episodes before logging.
+        :param log_interval: The number of rounds (environment interactions + agent updates) between logging.
         :param tb_log_name: the name of the run for TensorBoard logging
         :param reset_num_timesteps: whether or not to reset the current timestep number (used in logging)
         :param progress_bar: Display a progress bar using tqdm and rich.
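After this commit, `log_interval` counts training rounds (environment interaction plus agent updates) rather than completed episodes. A minimal usage sketch of the new semantics; the algorithm, environment, and values below are illustrative choices, not part of the commit:

from stable_baselines3 import SAC

# Illustrative setup: SAC keeps its default train_freq=1, so each round of
# learn() is one environment step followed by the scheduled gradient updates.
model = SAC("MlpPolicy", "Pendulum-v1", verbose=1)

# With the new behaviour, log_interval=100 dumps logs every 100 rounds of the
# learn() loop instead of every 100 finished episodes.
model.learn(total_timesteps=10_000, log_interval=100)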
13 changes: 7 additions & 6 deletions stable_baselines3/common/off_policy_algorithm.py
@@ -309,6 +309,8 @@ def learn(
         reset_num_timesteps: bool = True,
         progress_bar: bool = False,
     ) -> SelfOffPolicyAlgorithm:
+        iteration = 0
+
         total_timesteps, callback = self._setup_learn(
             total_timesteps,
             callback,
@@ -327,12 +329,16 @@
                 callback=callback,
                 learning_starts=self.learning_starts,
                 replay_buffer=self.replay_buffer,
-                log_interval=log_interval,
             )

             if rollout.continue_training is False:
                 break

+            iteration += 1
+
+            if log_interval is not None and iteration % log_interval == 0:
+                self._dump_logs()
+
             if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts:
                 # If no `gradient_steps` is specified,
                 # do as many gradients steps as steps performed during the rollout
@@ -502,7 +508,6 @@ def collect_rollouts(
         replay_buffer: ReplayBuffer,
         action_noise: Optional[ActionNoise] = None,
         learning_starts: int = 0,
-        log_interval: Optional[int] = None,
     ) -> RolloutReturn:
         """
         Collect experiences and store them into a ``ReplayBuffer``.
@@ -520,7 +525,6 @@
             in addition to the stochastic policy for SAC.
         :param learning_starts: Number of steps before learning for the warm-up phase.
         :param replay_buffer:
-        :param log_interval: Log data every ``log_interval`` episodes
         :return:
         """
         # Switch to eval mode (this affects batch norm / dropout)
@@ -583,9 +587,6 @@
                         kwargs = dict(indices=[idx]) if env.num_envs > 1 else {}
                         action_noise.reset(**kwargs)

-                    # Log training infos
-                    if log_interval is not None and self._episode_num % log_interval == 0:
-                        self._dump_logs()
         callback.on_rollout_end()

         return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
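Taken together, the hunks above move log dumping out of `collect_rollouts` (which no longer accepts `log_interval`) and into an iteration counter in `learn`. A standalone toy sketch of the resulting control flow; the function and variable names below are simplified stand-ins, not SB3 code:

from typing import Optional

def run_training_loop(total_timesteps: int, log_interval: Optional[int] = 4) -> None:
    """Toy stand-in for the new OffPolicyAlgorithm.learn loop."""
    num_timesteps = 0
    iteration = 0
    while num_timesteps < total_timesteps:
        num_timesteps += 1  # stands in for collect_rollouts() gathering one step
        iteration += 1
        if log_interval is not None and iteration % log_interval == 0:
            print(f"[iteration {iteration}] dump logs")  # plays the role of self._dump_logs()
        # gradient updates (self.train in SB3) would happen here

run_training_loop(total_timesteps=12)  # logs at iterations 4, 8 and 12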
