Add policy documentation links to policy_kwargs parameter (#2050)
* docs: Add policy documentation links to policy_kwargs parameter

* Fix missing references, update changelog

---------

Co-authored-by: Antonin RAFFIN <[email protected]>
kplers and araffin authored Dec 2, 2024
1 parent 897d01d commit 9caa168
Showing 9 changed files with 16 additions and 11 deletions.
docs/misc/changelog.rst (3 changes: 2 additions & 1 deletion)
@@ -38,6 +38,7 @@ Documentation:
^^^^^^^^^^^^^^
- Added Decisions and Dragons to resources. (@jmacglashan)
- Updated PyBullet example, now compatible with Gymnasium
- Added link to policies for ``policy_kwargs`` parameter (@kplers)

Release 2.4.0 (2024-11-18)
--------------------------
@@ -1738,4 +1739,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
-@brn-dev @jmacglashan
+@brn-dev @jmacglashan @kplers
docs/modules/a2c.rst (6 changes: 4 additions & 2 deletions)
@@ -78,7 +78,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.

A2C is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

-.. code-block::
+.. code-block:: python
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
@@ -88,7 +88,7 @@ Train a A2C agent on ``CartPole-v1`` using 4 environments.
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = A2C("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)
For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.
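Pieced together, the example these two hunks touch is roughly the following self-contained script; the ``SubprocVecEnv`` import and the ``__main__`` guard sit outside the shown context, so their exact form here is an assumption:

    from stable_baselines3 import A2C
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    if __name__ == "__main__":
        # SubprocVecEnv runs each environment in its own process,
        # which is what improves CPU utilization over the default DummyVecEnv.
        env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
        model = A2C("MlpPolicy", env, device="cpu")
        model.learn(total_timesteps=25_000)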


@@ -165,6 +165,8 @@ Parameters
:inherited-members:


.. _a2c_policies:

A2C Policies
-------------

docs/modules/ppo.rst (6 changes: 4 additions & 2 deletions)
@@ -92,7 +92,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.

PPO is meant to be run primarily on the CPU, especially when you are not using a CNN. To improve CPU utilization, try turning off the GPU and using ``SubprocVecEnv`` instead of the default ``DummyVecEnv``:

-.. code-block::
+.. code-block:: python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
@@ -102,7 +102,7 @@ Train a PPO agent on ``CartPole-v1`` using 4 environments.
env = make_vec_env("CartPole-v1", n_envs=8, vec_env_cls=SubprocVecEnv)
model = PPO("MlpPolicy", env, device="cpu")
model.learn(total_timesteps=25_000)
For more information, see :ref:`Vectorized Environments <vec_env>`, `Issue #1245 <https://github.com/DLR-RM/stable-baselines3/issues/1245#issuecomment-1435766949>`_ or the `Multiprocessing notebook <https://colab.research.google.com/github/Stable-Baselines-Team/rl-colab-notebooks/blob/sb3/multiprocessing_rl.ipynb>`_.

Results
@@ -178,6 +178,8 @@ Parameters
:inherited-members:


.. _ppo_policies:

PPO Policies
-------------

stable_baselines3/a2c/a2c.py (2 changes: 1 addition & 1 deletion)
@@ -48,7 +48,7 @@ class A2C(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`a2c_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
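For context on the parameter documented above: ``policy_kwargs`` is forwarded to the policy class on creation, so its entries must be valid constructor arguments of that policy. A minimal sketch with illustrative values (not part of this commit):

    import torch as th

    from stable_baselines3 import A2C

    # Forwarded to ActorCriticPolicy ("MlpPolicy") when the model is created.
    policy_kwargs = dict(activation_fn=th.nn.ReLU, net_arch=[64, 64])
    model = A2C("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, device="cpu")
    model.learn(total_timesteps=10_000)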
stable_baselines3/ddpg/ddpg.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ class DDPG(TD3):
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ddpg_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
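Same pattern for DDPG; here a plain ``net_arch`` list sizes both the actor and the critic of the reused TD3 policy (illustrative values only):

    from stable_baselines3 import DDPG

    # A flat list applies to the actor and the critic networks alike.
    policy_kwargs = dict(net_arch=[400, 300])
    model = DDPG("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=5_000)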
stable_baselines3/dqn/dqn.py (2 changes: 1 addition & 1 deletion)
@@ -53,7 +53,7 @@ class DQN(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`dqn_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
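For DQN the same parameter shapes the Q-network; a short sketch with made-up layer sizes:

    from stable_baselines3 import DQN

    # net_arch sizes the hidden layers of the Q-network built by the policy.
    policy_kwargs = dict(net_arch=[256, 256])
    model = DQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=10_000)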
stable_baselines3/ppo/ppo.py (2 changes: 1 addition & 1 deletion)
@@ -62,7 +62,7 @@ class PPO(OnPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`ppo_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
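For PPO, ``net_arch`` may also be a dict that gives the policy and value networks different sizes; again a hedged sketch, the values are illustrative:

    from stable_baselines3 import PPO

    # Separate hidden layers for the policy (pi) and the value function (vf).
    policy_kwargs = dict(net_arch=dict(pi=[64, 64], vf=[128, 128]))
    model = PPO("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=10_000)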
stable_baselines3/sac/sac.py (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@ class SAC(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`sac_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
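For SAC the actor and the Q-networks are keyed as ``pi`` and ``qf``; illustrative values:

    from stable_baselines3 import SAC

    # pi sizes the actor, qf sizes the twin Q-networks of the SAC policy.
    policy_kwargs = dict(net_arch=dict(pi=[256, 256], qf=[256, 256]))
    model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=5_000)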
stable_baselines3/td3/td3.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ class TD3(OffPolicyAlgorithm):
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
-:param policy_kwargs: additional arguments to be passed to the policy on creation
+:param policy_kwargs: additional arguments to be passed to the policy on creation. See :ref:`td3_policies`
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
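And for TD3, policy-level options such as the number of critic networks can be passed the same way; a sketch with illustrative values:

    from stable_baselines3 import TD3

    # n_critics controls how many Q-networks the TD3 policy builds
    # (2 gives the usual clipped double-Q setup).
    policy_kwargs = dict(net_arch=[400, 300], n_critics=2)
    model = TD3("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs)
    model.learn(total_timesteps=5_000)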
