Skip to content

Commit

Permalink
Formalities
Browse files Browse the repository at this point in the history
reformatting
  • Loading branch information
JoshuaBluem committed Dec 6, 2024
1 parent 7eb6632 commit e93a43c
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 47 deletions.
2 changes: 1 addition & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1740,4 +1740,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers
@brn-dev @jmacglashan @kplers @JoshuaBluem
3 changes: 2 additions & 1 deletion stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
"""
return not isinstance(space, (spaces.Dict, spaces.Tuple))


def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
Expand Down Expand Up @@ -59,7 +60,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act

if isinstance(observation_space, spaces.Dict):
nested_dict = False
for key, space in observation_space.spaces.items():
for _key, space in observation_space.spaces.items():
if isinstance(space, spaces.Dict):
nested_dict = True

Expand Down
10 changes: 0 additions & 10 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,8 @@ def patched_step(_action):
spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
# Small image inside a dict
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
# Non zero start index
spaces.Discrete(3, start=-1),
# 2D MultiDiscrete
spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
# Non zero start index inside a Dict
spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
],
)
def test_non_default_spaces(new_obs_space):
Expand Down Expand Up @@ -166,10 +160,6 @@ def patched_step(_action):
spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
# Almost good, except for one dim
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
# Non zero start index
spaces.Discrete(3, start=-1),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
],
)
def test_non_default_action_spaces(new_action_space):
Expand Down
76 changes: 41 additions & 35 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,31 @@ def __init__(self, nvec):
BOX_SPACE_FLOAT32,
)


class ActionBoundsTestClass(DummyEnv):
def step(self, action):
if isinstance(self.action_space, spaces.Discrete):
assert np.all(action >= self.action_space.start), (
f"Discrete action {action} is below the lower bound {self.action_space.start}")
assert np.all(action <= self.action_space.start + self.action_space.n), (
f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}")
assert np.all(
action >= self.action_space.start
), f"Discrete action {action} is below the lower bound {self.action_space.start}"
assert np.all(
action <= self.action_space.start + self.action_space.n
), f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}"
elif isinstance(self.action_space, spaces.MultiDiscrete):
assert np.all(action >= self.action_space.start), (
f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}")
assert np.all(action <= self.action_space.start + self.action_space.nvec), (
f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}")
assert np.all(
action >= self.action_space.start
), f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}"
assert np.all(
action <= self.action_space.start + self.action_space.nvec
), f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}"
elif isinstance(self.action_space, spaces.Box):
assert np.all(action >= self.action_space.low), (
f"Action {action} is below the lower bound {self.action_space.low}")
assert np.all(action <= self.action_space.high), (
f"Action {action} is above the upper bound {self.action_space.high}")
assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}"
assert np.all(
action <= self.action_space.high
), f"Action {action} is above the upper bound {self.action_space.high}"
return self.observation_space.sample(), 0.0, False, False, {}


@pytest.mark.parametrize(
"env",
[
Expand Down Expand Up @@ -189,17 +195,17 @@ def test_float64_action_space(model_class, obs_space, action_space):


@pytest.mark.parametrize(
"model_class, action_space",
[
# on-policy test
(PPO, spaces.Discrete(5, start=-6543)),
(PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])),
(PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)),
# off-policy test
(DQN, spaces.Discrete(2, start=9923)),
(SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)),
],
)
"model_class, action_space",
[
# on-policy test
(PPO, spaces.Discrete(5, start=-6543)),
(PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])),
(PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)),
# off-policy test
(DQN, spaces.Discrete(2, start=9923)),
(SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)),
],
)
def test_space_bounds(model_class, action_space):
obs_space = BOX_SPACE_FLOAT32
env = ActionBoundsTestClass(obs_space, action_space)
Expand All @@ -220,20 +226,20 @@ def test_space_bounds(model_class, action_space):

action, _ = model.predict(initial_obs, deterministic=False)
if isinstance(action_space, spaces.Discrete):
assert np.all(action >= action_space.start), (
f"Discrete action {action} is below the lower bound {action_space.start}")
assert np.all(action <= action_space.start + action_space.n), (
f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}")
assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}"
assert np.all(
action <= action_space.start + action_space.n
), f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}"
elif isinstance(action_space, spaces.MultiDiscrete):
assert np.all(action >= action_space.start), (
f"MultiDiscrete action {action} is below the lower bound {action_space.start}")
assert np.all(action <= action_space.start + action_space.nvec), (
f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}")
assert np.all(
action >= action_space.start
), f"MultiDiscrete action {action} is below the lower bound {action_space.start}"
assert np.all(
action <= action_space.start + action_space.nvec
), f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}"
elif isinstance(action_space, spaces.Box):
assert np.all(action >= action_space.low), (
f"Action {action} is below the lower bound {action_space.low}")
assert np.all(action <= action_space.high), (
f"Action {action} is above the upper bound {action_space.high}")
assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}"
assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}"


def test_multidim_binary_not_supported():
Expand Down

0 comments on commit e93a43c

Please sign in to comment.