diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index cc417a9d3..af83d2302 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -3,7 +3,7 @@
 Changelog
 ==========
 
-Release 2.4.0a8 (WIP)
+Release 2.4.0a10 (WIP)
 --------------------------
 
 .. note::
@@ -13,6 +13,13 @@ Release 2.4.0a8 (WIP)
   To suppress the warning, simply save the model again.
   You can find more info in `PR #1963 <https://github.com/DLR-RM/stable-baselines3/pull/1963>`_
 
+.. warning::
+
+  Stable-Baselines3 (SB3) v2.4.0 will be the last one supporting Python 3.8 (end of life in October 2024)
+  and PyTorch < 2.0.
+  We highly recommend you upgrade to Python >= 3.9 and PyTorch >= 2.0.
+
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
 
@@ -20,6 +27,7 @@ New Features:
 ^^^^^^^^^^^^^
 - Added support for ``pre_linear_modules`` and ``post_linear_modules`` in ``create_mlp`` (useful for adding normalization layers, like in DroQ or CrossQ)
 - Enabled np.ndarray logging for TensorBoardOutputFormat as histogram (see GH#1634) (@iwishwasaneagle)
+- Updated env checker to warn users when a multi-dimensional array is used to define ``MultiDiscrete`` spaces
 
 Bug Fixes:
 ^^^^^^^^^^
@@ -52,12 +60,14 @@ Others:
 - Fixed various typos (@cschindlbeck)
 - Remove unnecessary SDE noise resampling in PPO update (@brn-dev)
 - Updated PyTorch version on CI to 2.3.1
+- Added a warning to recommend using CPU with on-policy algorithms (A2C/PPO) and ``MlpPolicy``
 
 Bug Fixes:
 ^^^^^^^^^^
 
 Documentation:
 ^^^^^^^^^^^^^^
+- Updated PPO doc to recommend using CPU with ``MlpPolicy``
 
 Release 2.3.2 (2024-04-27)
 --------------------------
diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst
index b5e667241..4285cfb50 100644
--- a/docs/modules/ppo.rst
+++ b/docs/modules/ppo.rst
@@ -88,6 +88,23 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
         vec_env.render("human")
 
 
+.. note::
+
+  PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:
+
+  .. code-block::
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.env_util import make_vec_env
+    from stable_baselines3.common.vec_env import SubprocVecEnv
+
+    if __name__ == "__main__":
+        env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
+        model = PPO("MlpPolicy", env, device="cpu")
+        model.learn(total_timesteps=25_000)
+
+  For more information, see :ref:`Vectorized Environments `, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245>`_ or the `Multiprocessing notebook `_.
+
 Results
 -------
 
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py
index 090d609ba..e47dd123a 100644
--- a/stable_baselines3/common/env_checker.py
+++ b/stable_baselines3/common/env_checker.py
@@ -98,6 +98,14 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
                 "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is."
             )
 
+    if isinstance(observation_space, spaces.MultiDiscrete) and len(observation_space.nvec.shape) > 1:
+        warnings.warn(
+            f"The MultiDiscrete observation space uses a multidimensional array {observation_space.nvec} "
+            "which is currently not supported by Stable-Baselines3. "
+            "Please convert it to a 1D array using a wrapper: "
+            "https://github.com/DLR-RM/stable-baselines3/issues/1836."
+        )
+
     if isinstance(observation_space, spaces.Tuple):
         warnings.warn(
             "The observation space is a Tuple, "
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 262453721..dc885242e 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -1,5 +1,6 @@
 import sys
 import time
+import warnings
 from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union
 
 import numpy as np
@@ -135,6 +136,28 @@ def _setup_model(self) -> None:
             self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
         )
         self.policy = self.policy.to(self.device)
+        # Warn when not using CPU with MlpPolicy
+        self._maybe_recommend_cpu()
+
+    def _maybe_recommend_cpu(self, mlp_class_name: str = "ActorCriticPolicy") -> None:
+        """
+        Recommend using the CPU when running A2C/PPO with ``MlpPolicy``.
+
+        :param mlp_class_name: The name of the class for the default MlpPolicy.
+        """
+        policy_class_name = self.policy_class.__name__
+        if self.device != th.device("cpu") and policy_class_name == mlp_class_name:
+            warnings.warn(
+                f"You are trying to run {self.__class__.__name__} on the GPU, "
+                "but it is primarily intended to run on the CPU when not using a CNN policy "
+                f"(you are using {policy_class_name}, which corresponds to MlpPolicy). "
+                "See https://github.com/DLR-RM/stable-baselines3/issues/1245 "
+                "for more info. "
+                "You can pass `device='cpu'` or `export CUDA_VISIBLE_DEVICES=` to force using the CPU. "
+                "Note: The model will train, but the GPU utilization will be poor and "
+                "the training might take longer than on CPU.",
+                UserWarning,
+            )
 
     def collect_rollouts(
         self,
diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt
index ee717ba15..852a32b3f 100644
--- a/stable_baselines3/version.txt
+++ b/stable_baselines3/version.txt
@@ -1 +1 @@
-2.4.0a8
+2.4.0a10
diff --git a/tests/test_envs.py b/tests/test_envs.py
index 9a61eeef0..2fbce120c 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -123,6 +123,8 @@ def patched_step(_action):
         spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
         # Non zero start index
         spaces.Discrete(3, start=-1),
+        # 2D MultiDiscrete
+        spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
         # Non zero start index (MultiDiscrete)
         spaces.MultiDiscrete([4, 4], start=[1, 0]),
         # Non zero start index inside a Dict
diff --git a/tests/test_run.py b/tests/test_run.py
index 31c7b956e..4acabb692 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -1,6 +1,7 @@
 import gymnasium as gym
 import numpy as np
 import pytest
+import torch as th
 
 from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
 from stable_baselines3.common.env_util import make_vec_env
@@ -211,8 +212,11 @@ def test_warn_dqn_multi_env():
 
 
 def test_ppo_warnings():
-    """Test that PPO warns and errors correctly on
-    problematic rollout buffer sizes"""
+    """
+    Test that PPO warns and errors correctly on
+    problematic rollout buffer sizes,
+    and that it recommends using the CPU.
+ """ # Only 1 step: advantage normalization will return NaN with pytest.raises(AssertionError): @@ -234,3 +238,9 @@ def test_ppo_warnings(): loss = model.logger.name_to_value["train/loss"] assert loss > 0 assert not np.isnan(loss) # check not nan (since nan does not equal nan) + + with pytest.warns(UserWarning, match="You are trying to run PPO on the GPU"): + model = PPO("MlpPolicy", "Pendulum-v1") + # Pretend to be on the GPU + model.device = th.device("cuda") + model._maybe_recommend_cpu()