Skip to content

Commit

Permalink
MultiDiscrete start value + info
Browse files Browse the repository at this point in the history
Supports start value of MultiDiscrete and now.
Also updated changelog, +1 test_case, removed env_checker warning for usage of startvalue
  • Loading branch information
JoshuaBluem committed Dec 6, 2024
1 parent d01691e commit 9851d0e
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 43 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Support for start value in discrete action spaces

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down
30 changes: 0 additions & 30 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,6 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
"""
return not isinstance(space, (spaces.Dict, spaces.Tuple))


def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
"""
Return False if a (Multi)Discrete space has a non-zero start.
"""
return np.allclose(space.start, np.zeros_like(space.start))


def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
"""
:param space: Observation or action space
:param space_type: information about whether it is an observation or action space
(for the warning message)
:param key: When the observation space comes from a Dict space, we pass the
corresponding key to have more precise warning messages. Defaults to "".
"""
if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space):
maybe_key = f"(key='{key}')" if key else ""
warnings.warn(
f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
"is not supported by Stable-Baselines3. "
f"You can use a wrapper or update your {space_type} space."
)


def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
Expand Down Expand Up @@ -87,7 +62,6 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
for key, space in observation_space.spaces.items():
if isinstance(space, spaces.Dict):
nested_dict = True
_check_non_zero_start(space, "observation", key)

if nested_dict:
warnings.warn(
Expand Down Expand Up @@ -115,17 +89,13 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"which is supported by SB3."
)

_check_non_zero_start(observation_space, "observation")

if isinstance(observation_space, spaces.Sequence):
warnings.warn(
"Sequence observation space is not supported by Stable-Baselines3. "
"You can pad your observation to have a fixed size instead.\n"
"Note: The checks for returned values are skipped."
)

_check_non_zero_start(action_space, "action")

if not _is_numpy_array_space(action_space):
warnings.warn(
"The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _sample_action(
# We store the scaled action in the buffer
buffer_action = scaled_action
action = self.policy.unscale_action(scaled_action)
elif isinstance(self.action_space, spaces.Discrete):
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# Discrete case: Shift action values so every action starts from zero
scaled_action = self.policy.scale_action(unscaled_action)

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def collect_rollouts(
# Otherwise, clip the actions to avoid out of bound error
# as we are sampling from an unbounded Gaussian distribution
unscaled_actions = np.clip(actions, self.action_space.low, self.action_space.high)
elif isinstance(self.action_space, spaces.Discrete):
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# Scale actions to match action-space bounds
unscaled_actions = self.policy.unscale_action(unscaled_actions) # type: ignore[assignment, arg-type]

Expand Down
6 changes: 3 additions & 3 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def predict(
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
elif isinstance(self.action_space, spaces.Discrete):
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# transform action to its action-space bounds starting from its defined start value
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]

Expand Down Expand Up @@ -406,7 +406,7 @@ def scale_action(self, action: np.ndarray) -> np.ndarray:
# Box case
low, high = self.action_space.low, self.action_space.high
scaled_action = 2.0 * ((action - low) / (high - low)) - 1.0
elif isinstance(self.action_space, spaces.Discrete):
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# discrete actions case
scaled_action = np.subtract(action, self.action_space.start)
else:
Expand All @@ -429,7 +429,7 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
if isinstance(self.action_space, spaces.Box):
low, high = self.action_space.low, self.action_space.high
unscaled_action = low + (0.5 * (scaled_action + 1.0) * (high - low))
elif isinstance(self.action_space, spaces.Discrete):
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# match discrete actions bounds
unscaled_action = np.add(scaled_action, self.action_space.start)
else:
Expand Down
35 changes: 27 additions & 8 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,20 @@ def __init__(self, nvec):
class ActionBoundsTestClass(DummyEnv):
def step(self, action):
if isinstance(self.action_space, spaces.Discrete):
assert np.all(action >= self.action_space.start), f"Discrete action {action} is below the lower bound {self.action_space.start}"
assert np.all(action <= self.action_space.start + self.action_space.n), f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}"
assert np.all(action >= self.action_space.start), (
f"Discrete action {action} is below the lower bound {self.action_space.start}")
assert np.all(action <= self.action_space.start + self.action_space.n), (
f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}")
elif isinstance(self.action_space, spaces.MultiDiscrete):
assert np.all(action >= self.action_space.start), (
f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}")
assert np.all(action <= self.action_space.start + self.action_space.nvec), (
f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}")
elif isinstance(self.action_space, spaces.Box):
assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}"
assert np.all(action <= self.action_space.high), f"Action {action} is above the upper bound {self.action_space.high}"
assert np.all(action >= self.action_space.low), (
f"Action {action} is below the lower bound {self.action_space.low}")
assert np.all(action <= self.action_space.high), (
f"Action {action} is above the upper bound {self.action_space.high}")
return self.observation_space.sample(), 0.0, False, False, {}

@pytest.mark.parametrize(
Expand Down Expand Up @@ -184,6 +193,7 @@ def test_float64_action_space(model_class, obs_space, action_space):
[
# on-policy test
(PPO, spaces.Discrete(5, start=-6543)),
(PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])),
(PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)),
# off-policy test
(DQN, spaces.Discrete(2, start=9923)),
Expand All @@ -210,11 +220,20 @@ def test_space_bounds(model_class, action_space):

action, _ = model.predict(initial_obs, deterministic=False)
if isinstance(action_space, spaces.Discrete):
assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}"
assert np.all(action <= action_space.start + action_space.n), f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}"
assert np.all(action >= action_space.start), (
f"Discrete action {action} is below the lower bound {action_space.start}")
assert np.all(action <= action_space.start + action_space.n), (
f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}")
elif isinstance(action_space, spaces.MultiDiscrete):
assert np.all(action >= action_space.start), (
f"MultiDiscrete action {action} is below the lower bound {action_space.start}")
assert np.all(action <= action_space.start + action_space.nvec), (
f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}")
elif isinstance(action_space, spaces.Box):
assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}"
assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}"
assert np.all(action >= action_space.low), (
f"Action {action} is below the lower bound {action_space.low}")
assert np.all(action <= action_space.high), (
f"Action {action} is above the upper bound {action_space.high}")


def test_multidim_binary_not_supported():
Expand Down

0 comments on commit 9851d0e

Please sign in to comment.