From 3d59b5c86b0d8d61ee4a68cb2ae8743fd178670b Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Thu, 24 Oct 2024 22:20:05 +0900 Subject: [PATCH 1/2] Use uv on GitHub CI for faster download and update changelog (#2026) * Use uv on GitHub CI for faster download and update changelog * Fix new mypy issues --- .github/workflows/ci.yml | 11 +++++++---- docs/guide/sb3_contrib.rst | 1 + docs/misc/changelog.rst | 7 +++++++ stable_baselines3/common/utils.py | 4 ++-- tests/test_utils.py | 2 +- 5 files changed, 18 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 822e0cb3f..cb9055266 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,18 +31,21 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + # Use uv for faster downloads + pip install uv # cpu version of pytorch - pip install torch==2.3.1 --index-url https://download.pytorch.org/whl/cpu + # See https://github.com/astral-sh/uv/issues/1497 + uv pip install --system torch==2.3.1+cpu --index https://download.pytorch.org/whl/cpu # Install Atari Roms - pip install autorom + uv pip install --system autorom wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64 base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz AutoROM --accept-license --source-file Roms.tar.gz - pip install .[extra_no_roms,tests,docs] + uv pip install --system .[extra_no_roms,tests,docs] # Use headless version - pip install opencv-python-headless + uv pip install --system opencv-python-headless - name: Lint with ruff run: | make lint diff --git a/docs/guide/sb3_contrib.rst b/docs/guide/sb3_contrib.rst index 445832c59..8ec912e15 100644 --- a/docs/guide/sb3_contrib.rst +++ b/docs/guide/sb3_contrib.rst @@ -42,6 +42,7 @@ See documentation for the full list of included features. - `PPO with recurrent policy (RecurrentPPO aka PPO LSTM) `_ - `Truncated Quantile Critics (TQC)`_ - `Trust Region Policy Optimization (TRPO) `_ +- `Batch Normalization in Deep Reinforcement Learning (CrossQ) `_ **Gym Wrappers**: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index af83d2302..2c0974ac2 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -6,6 +6,8 @@ Changelog Release 2.4.0a10 (WIP) -------------------------- +**New algorithm: CrossQ in SB3 Contrib** + .. note:: DQN (and QR-DQN) models saved with SB3 < 2.4.0 will show a warning about @@ -43,6 +45,10 @@ Bug Fixes: `SB3-Contrib`_ ^^^^^^^^^^^^^^ +- Added ``CrossQ`` algorithm, from "Batch Normalization in Deep Reinforcement Learning" paper (@danielpalen) +- Added ``BatchRenorm`` PyTorch layer used in ``CrossQ`` (@danielpalen) +- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger) +- Fixed loading QRDQN changes `target_update_interval` (@jak3122) `RL Zoo`_ ^^^^^^^^^ @@ -61,6 +67,7 @@ Others: - Remove unnecessary SDE noise resampling in PPO update (@brn-dev) - Updated PyTorch version on CI to 2.3.1 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` +- Switched to uv to download packages faster on GitHub CI Bug Fixes: ^^^^^^^^^^ diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index bcde1cfa0..4e9fbc2db 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -46,7 +46,7 @@ def set_random_seed(seed: int, using_cuda: bool = False) -> None: # From stable baselines -def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: +def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> float: """ Computes fraction of variance that ypred explains about y. Returns 1 - Var[y-ypred] / Var[y] @@ -62,7 +62,7 @@ def explained_variance(y_pred: np.ndarray, y_true: np.ndarray) -> np.ndarray: """ assert y_true.ndim == 1 and y_pred.ndim == 1 var_y = np.var(y_true) - return np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y + return np.nan if var_y == 0 else float(1 - np.var(y_true - y_pred) / var_y) def update_learning_rate(optimizer: th.optim.Optimizer, learning_rate: float) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cc8b7e9f..81f134168 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -177,7 +177,7 @@ def test_custom_vec_env(tmp_path): @pytest.mark.parametrize("direct_policy", [False, True]) -def test_evaluate_policy(direct_policy: bool): +def test_evaluate_policy(direct_policy): model = A2C("MlpPolicy", "Pendulum-v1", seed=0) n_steps_per_episode, n_eval_episodes = 200, 2 From dd3d0acf154dec2b8a9a92fcc5fb83e4a05eaf72 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Tue, 29 Oct 2024 12:23:13 +0100 Subject: [PATCH 2/2] Update readme and clarify planned features (#2030) * Update readme and clarify planned features * Fix rtd python version * Fix pip version for rtd * Update rtd ubuntu and mambaforge * Add upper bound for gymnasium * [ci skip] Update readme --- .readthedocs.yml | 4 ++-- CONTRIBUTING.md | 2 +- README.md | 32 +++++++++++++++++++++----------- docs/conda_env.yml | 12 ++++++------ docs/guide/algos.rst | 1 + docs/index.rst | 4 +++- docs/misc/changelog.rst | 2 ++ 7 files changed, 36 insertions(+), 21 deletions(-) diff --git a/.readthedocs.yml b/.readthedocs.yml index dbb2fad03..26f0c883b 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -16,6 +16,6 @@ conda: environment: docs/conda_env.yml build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "mambaforge-22.9" + python: "mambaforge-23.11" diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d295269a9..cc5d1075b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -6,7 +6,7 @@ into two categories: - Create an issue about your intended feature, and we shall discuss the design and implementation. Once we agree that the plan looks good, go ahead and implement it. 2. You want to implement a feature or bug-fix for an outstanding issue - - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/issues + - Look at the outstanding issues here: https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted - Pick an issue or feature and comment on the task that you want to work on this feature. - If you need more context on a particular issue, please ask, and we shall provide. diff --git a/README.md b/README.md index 52634e486..5d25781d9 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg) -[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://gitlab.com/araffin/stable-baselines3/-/commits/master) +[![CI](https://github.com/DLR-RM/stable-baselines3/workflows/CI/badge.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) +[![Documentation Status](https://readthedocs.org/projects/stable-baselines/badge/?version=master)](https://stable-baselines3.readthedocs.io/en/master/?badge=master) [![coverage report](https://gitlab.com/araffin/stable-baselines3/badges/master/coverage.svg)](https://github.com/DLR-RM/stable-baselines3/actions/workflows/ci.yml) [![codestyle](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) @@ -22,6 +22,8 @@ These algorithms will make it easier for the research community and industry to **The performance of each algorithm was tested** (see *Results* section in their respective page), you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselines3/issues/48) and [#49](https://github.com/DLR-RM/stable-baselines3/issues/49) for more details. +We also provide detailed logs and reports on the [OpenRL Benchmark](https://wandb.ai/openrlbenchmark/sb3) platform. + | **Features** | **Stable-Baselines3** | | --------------------------- | ----------------------| @@ -41,7 +43,13 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin ### Planned features -Please take a look at the [Roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) and [Milestones](https://github.com/DLR-RM/stable-baselines3/milestones). +Since most of the features from the [original roadmap](https://github.com/DLR-RM/stable-baselines3/issues/1) have been implemented, there are no major changes planned for SB3, it is now *stable*. +If you want to contribute, you can search in the issues for the ones where [help is welcomed](https://github.com/DLR-RM/stable-baselines3/labels/help%20wanted) and the other [proposed enhancements](https://github.com/DLR-RM/stable-baselines3/labels/enhancement). + +While SB3 development is now focused on bug fixes and maintenance (doc update, user experience, ...), there is more active development going on in the associated repositories: +- newer algorithms are regularly added to the [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) repository +- faster variants are developed in the [SBX (SB3 + Jax)](https://github.com/araffin/sbx) repository +- the training framework for SB3, the RL Zoo, has an active [roadmap](https://github.com/DLR-RM/rl-baselines3-zoo/issues/299) ## Migration guide: from Stable-Baselines (SB2) to Stable-Baselines3 (SB3) @@ -79,7 +87,7 @@ Documentation: https://rl-baselines3-zoo.readthedocs.io/en/master/ We implement experimental features in a separate contrib repository: [SB3-Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) -This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). +This allows SB3 to maintain a stable and compact core, while still providing the latest features, like Recurrent PPO (PPO LSTM), CrossQ, Truncated Quantile Critics (TQC), Quantile Regression DQN (QR-DQN) or PPO with invalid action masking (Maskable PPO). Documentation is available online: [https://sb3-contrib.readthedocs.io/](https://sb3-contrib.readthedocs.io/) @@ -97,17 +105,16 @@ It provides a minimal number of features compared to SB3 but can be much faster ### Prerequisites Stable Baselines3 requires Python 3.8+. -#### Windows 10 +#### Windows To install stable-baselines on Windows, please look at the [documentation](https://stable-baselines3.readthedocs.io/en/master/guide/install.html#prerequisites). ### Install using pip Install the Stable Baselines3 package: +```sh +pip install 'stable-baselines3[extra]' ``` -pip install stable-baselines3[extra] -``` -**Note:** Some shells such as Zsh require quotation marks around brackets, i.e. `pip install 'stable-baselines3[extra]'` ([More Info](https://stackoverflow.com/a/30539963)). This includes an optional dependencies like Tensorboard, OpenCV or `ale-py` to train on atari games. If you do not need those, you can use: ```sh @@ -177,6 +184,7 @@ All the following examples can be executed online using Google Colab notebooks: | ------------------- | ------------------ | ------------------ | ------------------ | ------------------- | ------------------ | --------------------------------- | | ARS[1](#f1) | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | A2C | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | +| CrossQ[1](#f1) | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DDPG | :x: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | | DQN | :x: | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | | HER | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | @@ -191,7 +199,7 @@ All the following examples can be executed online using Google Colab notebooks: 1: Implemented in [SB3 Contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib) GitHub repository. -Actions `gym.spaces`: +Actions `gymnasium.spaces`: * `Box`: A N-dimensional box that contains every point in the action space. * `Discrete`: A list of possible actions, where each timestep only one of the actions can be used. * `MultiDiscrete`: A list of possible actions, where each timestep only one action of each discrete set can be used. @@ -218,9 +226,9 @@ To run a single test: python3 -m pytest -v -k 'test_check_env_dict_action' ``` -You can also do a static type check using `pytype` and `mypy`: +You can also do a static type check using `mypy`: ```sh -pip install pytype mypy +pip install mypy make type ``` @@ -252,6 +260,8 @@ To cite this repository in publications: } ``` +Note: If you need to refer to a specific version of SB3, you can also use the [Zenodo DOI](https://doi.org/10.5281/zenodo.8123988). + ## Maintainers Stable-Baselines3 is currently maintained by [Ashley Hill](https://github.com/hill-a) (aka @hill-a), [Antonin Raffin](https://araffin.github.io/) (aka [@araffin](https://github.com/araffin)), [Maximilian Ernestus](https://github.com/ernestum) (aka @ernestum), [Adam Gleave](https://github.com/adamgleave) (@AdamGleave), [Anssi Kanervisto](https://github.com/Miffyli) (@Miffyli) and [Quentin Gallouédec](https://gallouedec.com/) (@qgallouedec). diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 53fecf278..e025a57e1 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -1,18 +1,18 @@ name: root channels: - pytorch - - defaults + - conda-forge dependencies: - cpuonly=1.0=0 - - pip=22.3.1 - - python=3.8 - - pytorch=1.13.0=py3.8_cpu_0 + - pip=24.2 + - python=3.11 + - pytorch=2.5.0=py3.11_cpu_0 - pip: - - gymnasium + - gymnasium>=0.28.1,<0.30 - cloudpickle - opencv-python-headless - pandas - - numpy + - numpy>=1.20,<2.0 - matplotlib - sphinx>=5,<8 - sphinx_rtd_theme>=1.3.0 diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index d5e7ae1d2..db03ba292 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -10,6 +10,7 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` =================== =========== ============ ================= =============== ================ ARS [#f1]_ ✔️ ✔️ ❌ ❌ ✔️ A2C ✔️ ✔️ ✔️ ✔️ ✔️ +CrossQ [#f1]_ ✔️ ❌ ❌ ❌ ✔️ DDPG ✔️ ❌ ❌ ❌ ✔️ DQN ❌ ✔️ ❌ ❌ ✔️ HER ✔️ ✔️ ❌ ❌ ✔️ diff --git a/docs/index.rst b/docs/index.rst index c8a70a94b..d74120c41 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -113,12 +113,14 @@ To cite this project in publications: url = {http://jmlr.org/papers/v22/20-1364.html} } +Note: If you need to refer to a specific version of SB3, you can also use the `Zenodo DOI `_. + Contributing ------------ To any interested in making the rl baselines better, there are still some improvements that need to be done. -You can check issues in the `repo `_. +You can check issues in the `repository `_. If you want to contribute, please read `CONTRIBUTING.md `_ first. diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 2c0974ac2..b32cd7ce1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -68,6 +68,7 @@ Others: - Updated PyTorch version on CI to 2.3.1 - Added a warning to recommend using CPU with on policy algorithms (A2C/PPO) and ``MlpPolicy`` - Switched to uv to download packages faster on GitHub CI +- Updated dependencies for read the doc Bug Fixes: ^^^^^^^^^^ @@ -75,6 +76,7 @@ Bug Fixes: Documentation: ^^^^^^^^^^^^^^ - Updated PPO doc to recommend using CPU with ``MlpPolicy`` +- Clarified documentation about planned features and citing software Release 2.3.2 (2024-04-27) --------------------------