diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index a97aca24f..41cce35ad 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -247,7 +247,8 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor """ # Update beta schedule self._current_progress_remaining = progress_remaining - td_errors = td_errors.detach().cpu().numpy().flatten() + if isinstance(td_errors, th.Tensor): + td_errors = td_errors.detach().cpu().numpy().flatten() for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha