From fb1a9f7df3926494218535af422959d4756f8571 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 12 Jul 2024 13:41:34 +0200 Subject: [PATCH] Only convert to numpy if needed --- stable_baselines3/common/prioritized_replay_buffer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/stable_baselines3/common/prioritized_replay_buffer.py b/stable_baselines3/common/prioritized_replay_buffer.py index a97aca24f..41cce35ad 100644 --- a/stable_baselines3/common/prioritized_replay_buffer.py +++ b/stable_baselines3/common/prioritized_replay_buffer.py @@ -247,7 +247,8 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor """ # Update beta schedule self._current_progress_remaining = progress_remaining - td_errors = td_errors.detach().cpu().numpy().flatten() + if isinstance(td_errors, th.Tensor): + td_errors = td_errors.detach().cpu().numpy().flatten() for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors): # Proportional prioritization priority = (abs(td_error) + eps) ^ alpha