From e93a43cbeb763ce93cc654002b53aadf097dc9bc Mon Sep 17 00:00:00 2001 From: JoshuaB <72627997+JoshuaBluem@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:35:19 +0100 Subject: [PATCH] Formalities reformatting --- docs/misc/changelog.rst | 2 +- stable_baselines3/common/env_checker.py | 3 +- tests/test_envs.py | 10 ---- tests/test_spaces.py | 76 +++++++++++++------------ 4 files changed, 44 insertions(+), 47 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0fab8b56c..2af50ab6d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -1740,4 +1740,4 @@ And all the contributors: @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean -@brn-dev @jmacglashan @kplers +@brn-dev @jmacglashan @kplers @JoshuaBluem diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 75b79cef8..a43fced0b 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -16,6 +16,7 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: """ return not isinstance(space, (spaces.Dict, spaces.Tuple)) + def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines @@ -59,7 +60,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act if isinstance(observation_space, spaces.Dict): nested_dict = False - for key, space in observation_space.spaces.items(): + for _key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True diff --git a/tests/test_envs.py b/tests/test_envs.py index 2fbce120c..293a188fb 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -121,14 +121,8 @@ def patched_step(_action): spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), # Small image inside a dict spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), - # Non zero start index - spaces.Discrete(3, start=-1), # 2D MultiDiscrete spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])), - # Non zero start index (MultiDiscrete) - spaces.MultiDiscrete([4, 4], start=[1, 0]), - # Non zero start index inside a Dict - spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], ) def test_non_default_spaces(new_obs_space): @@ -166,10 +160,6 @@ def patched_step(_action): spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), - # Non zero start index - spaces.Discrete(3, start=-1), - # Non zero start index (MultiDiscrete) - spaces.MultiDiscrete([4, 4], start=[1, 0]), ], ) def test_non_default_action_spaces(new_action_space): diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 1e3b2dcb4..59b546da6 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -53,25 +53,31 @@ def __init__(self, nvec): BOX_SPACE_FLOAT32, ) + class ActionBoundsTestClass(DummyEnv): def step(self, action): if isinstance(self.action_space, spaces.Discrete): - assert np.all(action >= self.action_space.start), ( - f"Discrete action {action} is below the lower bound {self.action_space.start}") - assert np.all(action <= self.action_space.start + self.action_space.n), ( - f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}") + assert np.all( + action >= self.action_space.start + ), f"Discrete action {action} is below the lower bound {self.action_space.start}" + assert np.all( + action <= self.action_space.start + self.action_space.n + ), f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}" elif isinstance(self.action_space, spaces.MultiDiscrete): - assert np.all(action >= self.action_space.start), ( - f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}") - assert np.all(action <= self.action_space.start + self.action_space.nvec), ( - f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}") + assert np.all( + action >= self.action_space.start + ), f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}" + assert np.all( + action <= self.action_space.start + self.action_space.nvec + ), f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}" elif isinstance(self.action_space, spaces.Box): - assert np.all(action >= self.action_space.low), ( - f"Action {action} is below the lower bound {self.action_space.low}") - assert np.all(action <= self.action_space.high), ( - f"Action {action} is above the upper bound {self.action_space.high}") + assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}" + assert np.all( + action <= self.action_space.high + ), f"Action {action} is above the upper bound {self.action_space.high}" return self.observation_space.sample(), 0.0, False, False, {} + @pytest.mark.parametrize( "env", [ @@ -189,17 +195,17 @@ def test_float64_action_space(model_class, obs_space, action_space): @pytest.mark.parametrize( - "model_class, action_space", - [ - # on-policy test - (PPO, spaces.Discrete(5, start=-6543)), - (PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])), - (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)), - # off-policy test - (DQN, spaces.Discrete(2, start=9923)), - (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)), - ], - ) + "model_class, action_space", + [ + # on-policy test + (PPO, spaces.Discrete(5, start=-6543)), + (PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])), + (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)), + # off-policy test + (DQN, spaces.Discrete(2, start=9923)), + (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)), + ], +) def test_space_bounds(model_class, action_space): obs_space = BOX_SPACE_FLOAT32 env = ActionBoundsTestClass(obs_space, action_space) @@ -220,20 +226,20 @@ def test_space_bounds(model_class, action_space): action, _ = model.predict(initial_obs, deterministic=False) if isinstance(action_space, spaces.Discrete): - assert np.all(action >= action_space.start), ( - f"Discrete action {action} is below the lower bound {action_space.start}") - assert np.all(action <= action_space.start + action_space.n), ( - f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}") + assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}" + assert np.all( + action <= action_space.start + action_space.n + ), f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}" elif isinstance(action_space, spaces.MultiDiscrete): - assert np.all(action >= action_space.start), ( - f"MultiDiscrete action {action} is below the lower bound {action_space.start}") - assert np.all(action <= action_space.start + action_space.nvec), ( - f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}") + assert np.all( + action >= action_space.start + ), f"MultiDiscrete action {action} is below the lower bound {action_space.start}" + assert np.all( + action <= action_space.start + action_space.nvec + ), f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}" elif isinstance(action_space, spaces.Box): - assert np.all(action >= action_space.low), ( - f"Action {action} is below the lower bound {action_space.low}") - assert np.all(action <= action_space.high), ( - f"Action {action} is above the upper bound {action_space.high}") + assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}" + assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}" def test_multidim_binary_not_supported():