Merge branch 'master' into gymnasium-1.0.0a1
pseudo-rnd-thoughts authored Oct 8, 2024
2 parents 99571a6 + 56c153f commit 2f62460
Showing 7 changed files with 74 additions and 4 deletions.
12 changes: 11 additions & 1 deletion docs/misc/changelog.rst
@@ -3,7 +3,7 @@
Changelog
==========

Release 2.4.0a8 (WIP)
Release 2.4.0a10 (WIP)
--------------------------

.. note::
@@ -13,13 +13,21 @@ Release 2.4.0a8 (WIP)
To suppress the warning, simply save the model again.
You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_

.. warning::

Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
and PyTorch < 2.0.
We highly recommend upgrading to Python >= 3.9 and PyTorch >= 2.0.


Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp``, useful for adding normalization layers, like in DroQ or CrossQ (see the first sketch after this list)
- Enabled logging of ``np.ndarray`` values as histograms in ``TensorBoardOutputFormat`` (see GH#1634 and the second sketch after this list) (@iwishwasaneagle)
- Updated env checker to warn users when a multi-dimensional array is used to define ``MultiDiscrete`` spaces
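
A minimal sketch of the new ``create_mlp`` arguments (not part of this commit). It assumes that both arguments take lists of ``torch.nn`` module classes and that ``create_mlp`` instantiates each class with the matching layer dimension, as done for the normalization layers in DroQ/CrossQ:

.. code-block::

  import torch as th
  from torch import nn

  from stable_baselines3.common.torch_layers import create_mlp

  # Hypothetical usage: add a LayerNorm after every linear layer of a two-layer MLP
  layers = create_mlp(
      input_dim=8,
      output_dim=2,
      net_arch=[64, 64],
      activation_fn=nn.ReLU,
      post_linear_modules=[nn.LayerNorm],  # assumed to be instantiated with each layer's output dim
  )
  mlp = nn.Sequential(*layers)  # create_mlp returns a list of layers
  print(mlp(th.zeros(1, 8)).shape)  # torch.Size([1, 2])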
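
A second sketch exercising the histogram logging through the SB3 logger (also not part of this commit; the log directory name is arbitrary):

.. code-block::

  import numpy as np

  from stable_baselines3.common.logger import configure

  # The TensorBoard writer records NumPy arrays as histograms (per the changelog entry above)
  logger = configure("./tb_log", ["tensorboard"])
  logger.record("example/values", np.random.randn(1000))
  logger.dump(step=0)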

Bug Fixes:
^^^^^^^^^^
@@ -52,12 +60,14 @@ Others:
- Fixed various typos (@cschindlbeck)
- Removed unnecessary SDE noise resampling in the PPO update (@brn-dev)
- Updated PyTorch version on CI to 2.3.1
- Added a warning to recommend using the CPU with on-policy algorithms (A2C/PPO) and ``MlpPolicy``

Bug Fixes:
^^^^^^^^^^

Documentation:
^^^^^^^^^^^^^^
- Updated PPO doc to recommend using CPU with ``MlpPolicy``

Release 2.3.2 (2024-04-27)
--------------------------
17 changes: 17 additions & 0 deletions docs/modules/ppo.rst
@@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
vec_env.render("human")

.. note::

  PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

  .. code-block::

    from stable_baselines3 import PPO
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    if __name__ == "__main__":
        env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
        model = PPO("MlpPolicy", env, device="cpu")
        model.learn(total_timesteps=25_000)

  For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

Results
-------

8 changes: 8 additions & 0 deletions stable_baselines3/common/env_checker.py
@@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
)

if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
    warnings.warn(
        f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
        "which is currently not supported by Stable-Baselines3. "
        "Please convert it to a 1D array using a wrapper: "
        "https://github.com/DLR-RM/stable-baselines3/issues/1836."
    )

if isinstance(observation_space, spaces.Tuple):
    warnings.warn(
        "The observation space is a Tuple, "
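
The warning above points users to issue #1836, which recommends flattening such spaces with a wrapper. A hypothetical sketch of one (not part of this commit or of SB3):

.. code-block::

  import gymnasium as gym
  import numpy as np
  from gymnasium import spaces


  class FlattenMultiDiscreteObs(gym.ObservationWrapper):
      """Convert a multi-dimensional MultiDiscrete observation space into a 1D one."""

      def __init__(self, env: gym.Env) -> None:
          super().__init__(env)
          assert isinstance(env.observation_space, spaces.MultiDiscrete)
          # e.g. nvec [[4, 4], [2, 3]] becomes [4, 4, 2, 3]
          self.observation_space = spaces.MultiDiscrete(env.observation_space.nvec.flatten())

      def observation(self, observation: np.ndarray) -> np.ndarray:
          return np.asarray(observation).flatten()
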
23 changes: 23 additions & 0 deletions stable_baselines3/common/on_policy_algorithm.py
@@ -1,5 +1,6 @@
import sys
import time
import warnings
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
@@ -135,6 +136,28 @@ def _setup_model(self) -> None:
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
)
self.policy = self.policy.to(self.device)
# Warn when not using CPU with MlpPolicy
self._maybe_recommend_cpu()

def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
    """
    Recommend using the CPU only when running A2C/PPO with MlpPolicy.

    :param mlp_class_name: The name of the class for the default MlpPolicy.
    """
    policy_class_name = self.policy_class.__name__
    if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
        warnings.warn(
            f"You are trying to run {self.__class__.__name__} on the GPU, "
            "but it is primarily intended to run on the CPU when not using a CNN policy "
            f"(you are using {policy_class_name} which should be a MlpPolicy). "
            "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
            "for more info. "
            "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU. "
            "Note: The model will train, but the GPU utilization will be poor and "
            "the training might take longer than on CPU.",
            UserWarning,
        )

def collect_rollouts(
self,
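
As the new warning suggests, users who want to stay on the CPU can avoid it by passing the device explicitly; a minimal usage sketch:

.. code-block::

  from stable_baselines3 import PPO

  # Requesting the CPU explicitly means the recommendation warning is never emitted
  model = PPO("MlpPolicy", "CartPole-v1", device="cpu")
  model.learn(total_timesteps=10_000)
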
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
@@ -1 +1 @@
2.4.0a8
2.4.0a10
2 changes: 2 additions & 0 deletions tests/test_envs.py
@@ -123,6 +123,8 @@ def patched_step(_action):
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
# Non zero start index
spaces.Discrete(3, start=-1),
# 2D MultiDiscrete
spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
# Non zero start index inside a Dict
14 changes: 12 additions & 2 deletions tests/test_run.py
@@ -1,6 +1,7 @@
import gymnasium as gym
import numpy as np
import pytest
import torch as th

from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.env_util import make_vec_env
@@ -211,8 +212,11 @@ def test_warn_dqn_multi_env():


def test_ppo_warnings():
"""Test that PPO warns and errors correctly on
problematic rollout buffer sizes"""
"""
Test that PPO warns and errors correctly on
problematic rollout buffer sizes,
and recommend using CPU.
"""

# Only 1 step: advantage normalization will return NaN
with pytest.raises(AssertionError):
@@ -234,3 +238,9 @@ def test_ppo_warnings():
loss = model.logger.name_to_value["train/loss"]
assert loss > 0
assert not np.isnan(loss) # check not nan (since nan does not equal nan)

with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"):
    model = PPO("MlpPolicy", "Pendulum-v1")
    # Pretend to be on the GPU
    model.device = th.device("cuda")
    model._maybe_recommend_cpu()
