Skip to content

Commit

Permalink
Add rollout_buffer_class parameter to on-policy algorithms (#1720)
Browse files Browse the repository at this point in the history
* Add rollout_buffer_class and rollout_buffer_kwargs parameters to OnPolicyAlgorithm

* Add rollout_buffer_class and rollout_buffer_kwargs to PPO.

* Add rollout_buffer_class and rollout_buffer_kwargs to A2C.

* Make use of the rollout buffer kwargs.

* Update version

* Add test and update doc

---------

Co-authored-by: Antonin Raffin <[email protected]>
  • Loading branch information
ernestum and araffin authored Oct 27, 2023
1 parent f56dded commit 69afefc
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 8 deletions.
2 changes: 1 addition & 1 deletion docs/guide/vec_envs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API:
Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``).

- the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options,
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discared after each call to ``reset()``).
you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``).

- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``,
``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
Expand Down
11 changes: 7 additions & 4 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
Changelog
==========

Release 2.2.0a8 (WIP)
Release 2.2.0a9 (WIP)
--------------------------
**Support for options at reset, bug fixes and better error messages**

Breaking Changes:
^^^^^^^^^^^^^^^^^
Expand All @@ -16,6 +17,8 @@ New Features:
- Improved error message of the ``env_checker`` for env wrongly detected as GoalEnv (``compute_reward()`` is defined)
- Improved error message when mixing Gym API with VecEnv API (see GH#1694)
- Add support for setting ``options`` at reset with VecEnv via the ``set_options()`` method. Same as seeds logic, options are reset at the end of an episode (@ReHoss)
- Added ``rollout_buffer_class`` and ``rollout_buffer_kwargs`` arguments to on-policy algorithms (A2C and PPO)


Bug Fixes:
^^^^^^^^^^
Expand All @@ -36,9 +39,9 @@ Bug Fixes:
`RL Zoo`_
^^^^^^^^^

`SBX`_
^^^^^^^^^
- Added ``DDPG`` and ``TD3``
`SBX`_ (SB3 + Jax)
^^^^^^^^^^^^^^^^^^
- Added ``DDPG`` and ``TD3`` algorithms

Deprecations:
^^^^^^^^^^^^^
Expand Down
7 changes: 7 additions & 0 deletions stable_baselines3/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
Expand Down Expand Up @@ -41,6 +42,8 @@ class A2C(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param normalize_advantage: Whether to normalize or not the advantage
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
Expand Down Expand Up @@ -75,6 +78,8 @@ def __init__(
use_rms_prop: bool = True,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
normalize_advantage: bool = False,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
Expand All @@ -96,6 +101,8 @@ def __init__(
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
Expand Down
15 changes: 13 additions & 2 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation.
:param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
the reported success rate, mean episode length, and mean reward over
:param tensorboard_log: the log location for tensorboard (if None, no logging)
Expand Down Expand Up @@ -68,6 +70,8 @@ def __init__(
max_grad_norm: float,
use_sde: bool,
sde_sample_freq: int,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
monitor_wrapper: bool = True,
Expand Down Expand Up @@ -100,6 +104,8 @@ def __init__(
self.ent_coef = ent_coef
self.vf_coef = vf_coef
self.max_grad_norm = max_grad_norm
self.rollout_buffer_class = rollout_buffer_class
self.rollout_buffer_kwargs = rollout_buffer_kwargs or {}

if _init_setup_model:
self._setup_model()
Expand All @@ -108,16 +114,21 @@ def _setup_model(self) -> None:
self._setup_lr_schedule()
self.set_random_seed(self.seed)

buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer
if self.rollout_buffer_class is None:
if isinstance(self.observation_space, spaces.Dict):
self.rollout_buffer_class = DictRolloutBuffer
else:
self.rollout_buffer_class = RolloutBuffer

self.rollout_buffer = buffer_cls(
self.rollout_buffer = self.rollout_buffer_class(
self.n_steps,
self.observation_space, # type: ignore[arg-type]
self.action_space,
device=self.device,
gamma=self.gamma,
gae_lambda=self.gae_lambda,
n_envs=self.n_envs,
**self.rollout_buffer_kwargs,
)
self.policy = self.policy_class( # type: ignore[assignment]
self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
Expand Down
7 changes: 7 additions & 0 deletions stable_baselines3/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from gymnasium import spaces
from torch.nn import functional as F

from stable_baselines3.common.buffers import RolloutBuffer
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, BasePolicy, MultiInputActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
Expand Down Expand Up @@ -52,6 +53,8 @@ class PPO(OnPolicyAlgorithm):
instead of action noise exploration (default: False)
:param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE
Default: -1 (only sample at the beginning of the rollout)
:param rollout_buffer_class: Rollout buffer class to use. If ``None``, it will be automatically selected.
:param rollout_buffer_kwargs: Keyword arguments to pass to the rollout buffer on creation
:param target_kl: Limit the KL divergence between updates,
because the clipping is not enough to prevent large update
see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213)
Expand Down Expand Up @@ -92,6 +95,8 @@ def __init__(
max_grad_norm: float = 0.5,
use_sde: bool = False,
sde_sample_freq: int = -1,
rollout_buffer_class: Optional[Type[RolloutBuffer]] = None,
rollout_buffer_kwargs: Optional[Dict[str, Any]] = None,
target_kl: Optional[float] = None,
stats_window_size: int = 100,
tensorboard_log: Optional[str] = None,
Expand All @@ -113,6 +118,8 @@ def __init__(
max_grad_norm=max_grad_norm,
use_sde=use_sde,
sde_sample_freq=sde_sample_freq,
rollout_buffer_class=rollout_buffer_class,
rollout_buffer_kwargs=rollout_buffer_kwargs,
stats_window_size=stats_window_size,
tensorboard_log=tensorboard_log,
policy_kwargs=policy_kwargs,
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.2.0a8
2.2.0a9
14 changes: 14 additions & 0 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch as th
from gymnasium import spaces

from stable_baselines3 import A2C
from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
Expand Down Expand Up @@ -150,3 +151,16 @@ def test_device_buffer(replay_buffer_cls, device):
assert value[key].device.type == desired_device
elif isinstance(value, th.Tensor):
assert value.device.type == desired_device


def test_custom_rollout_buffer():
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict())

with pytest.raises(TypeError, match="unexpected keyword argument 'wrong_keyword'"):
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(wrong_keyword=1))

with pytest.raises(TypeError, match="got multiple values for keyword argument 'gamma'"):
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=RolloutBuffer, rollout_buffer_kwargs=dict(gamma=1))

with pytest.raises(AssertionError, match="DictRolloutBuffer must be used with Dict obs space only"):
A2C("MlpPolicy", "Pendulum-v1", rollout_buffer_class=DictRolloutBuffer)

0 comments on commit 69afefc

Please sign in to comment.