From 9caa168686342ffc358c4acc7fbd842fc5fc8aac Mon Sep 17 00:00:00 2001 From: kplers <94663929+kplers@users.noreply.github.com> Date: Mon, 2 Dec 2024 22:40:05 +0900 Subject: [PATCH] Add policy documentation links to policy_kwargs parameter (#2050) * docs: Add policy documentation links to policy_kwargs parameter * Fix missing references, update changelog --------- Co-authored-by: Antonin RAFFIN --- docs/misc/changelog.rst | 3 ++- docs/modules/a2c.rst | 6 ++++-- docs/modules/ppo.rst | 6 ++++-- stable_baselines3/a2c/a2c.py | 2 +- stable_baselines3/ddpg/ddpg.py | 2 +- stable_baselines3/dqn/dqn.py | 2 +- stable_baselines3/ppo/ppo.py | 2 +- stable_baselines3/sac/sac.py | 2 +- stable_baselines3/td3/td3.py | 2 +- 9 files changed, 16 insertions(+), 11 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 91386170e..1ccfbb5a1 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -38,6 +38,7 @@ Documentation: ^^^^^^^^^^^^^^ - Added Decisions and Dragons to resources. (@jmacglashan) - Updated PyBullet example, now compatible with Gymnasium +- Added link to policies for ``policy_kwargs`` parameter (@kplers) Release 2.4.0 (2024-11-18) -------------------------- @@ -1738,4 +1739,4 @@ And all the contributors: @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean -@brn-dev @jmacglashan +@brn-dev @jmacglashan @kplers diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 08e02223b..8f757af86 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -78,7 +78,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``: - .. code-block:: + .. code-block:: python from stable_baselines3 import A2C from stable_baselines3.common.env_util import make_vec_env @@ -88,7 +88,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments. env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv) model = A2C("MlpPolicy", env, device="cpu") model.learn(total_timesteps=25_000) - + For more information, see :ref:`Vectorized Environments `, `Issue #1245 `_ or the `Multiprocessing notebook `_. @@ -165,6 +165,8 @@ Parameters :inherited-members: +.. _a2c_policies: + A2C Policies ------------- diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 4285cfb50..8e85b9ffb 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -92,7 +92,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments. PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``: - .. code-block:: + .. code-block:: python from stable_baselines3 import PPO from stable_baselines3.common.env_util import make_vec_env @@ -102,7 +102,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments. env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv) model = PPO("MlpPolicy", env, device="cpu") model.learn(total_timesteps=25_000) - + For more information, see :ref:`Vectorized Environments `, `Issue #1245 `_ or the `Multiprocessing notebook `_. Results @@ -178,6 +178,8 @@ Parameters :inherited-members: +.. _ppo_policies: + PPO Policies ------------- diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index a125aaef6..c3388bf8b 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -48,7 +48,7 @@ class A2C(OnPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`a2c_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index d94fa1812..4af502469 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -44,7 +44,7 @@ class DDPG(TD3): :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index a3f200e59..eb85fd65b 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -53,7 +53,7 @@ class DQN(OffPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 03cbc2464..7ed1b4bbc 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -62,7 +62,7 @@ class PPO(OnPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 8cb2ae53d..14a948fbd 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -68,7 +68,7 @@ class SAC(OffPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`sac_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index affb9c9f8..2dc40b4ba 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -56,7 +56,7 @@ class TD3(OffPolicyAlgorithm): :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average the reported success rate, mean episode length, and mean reward over :param tensorboard_log: the log location for tensorboard (if None, no logging) - :param policy_kwargs: additional arguments to be passed to the policy on creation + :param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`td3_policies` :param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for debug messages :param seed: Seed for the pseudo random generators