Skip to content

Commit

Permalink
Only convert to numpy if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Jul 12, 2024
1 parent bee9cbe commit fb1a9f7
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion stable_baselines3/common/prioritized_replay_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def update_priorities(self, leaf_nodes_indices: np.ndarray, td_errors: th.Tensor
"""
# Update beta schedule
self._current_progress_remaining = progress_remaining
td_errors = td_errors.detach().cpu().numpy().flatten()
if isinstance(td_errors, th.Tensor):
td_errors = td_errors.detach().cpu().numpy().flatten()

for leaf_node_idx, td_error in zip(leaf_nodes_indices, td_errors):
# Proportional prioritization priority = (abs(td_error) + eps) ^ alpha
Expand Down

0 comments on commit fb1a9f7

Please sign in to comment.