Skip to content

Commit

Permalink
Added test cases for start value
Browse files Browse the repository at this point in the history
Added tests to the pytest checks to confirm the correct action bounds of the predict and step methods
  • Loading branch information
JoshuaBluem committed Dec 6, 2024
1 parent c0aa574 commit d01691e
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ def __init__(self, nvec):
BOX_SPACE_FLOAT32,
)

class ActionBoundsTestClass(DummyEnv):
    """Dummy environment whose ``step`` asserts the action lies inside the action space.

    Used to check that the actions an algorithm sends to the environment
    respect the bounds of ``Discrete`` (including a non-zero ``start``) and
    ``Box`` action spaces. Other action space types are not checked.
    """

    def step(self, action):
        """Validate ``action`` against the action space bounds, then return a dummy transition.

        :param action: action received from the agent
        :return: ``(observation, reward, terminated, truncated, info)`` with a
            sampled observation, zero reward, both flags ``False`` and empty info
        :raises AssertionError: if the action is outside the action space bounds
        """
        if isinstance(self.action_space, spaces.Discrete):
            lower = self.action_space.start
            # Discrete(n, start) contains the n values start, ..., start + n - 1,
            # so the largest valid action is start + n - 1 (not start + n).
            upper = self.action_space.start + self.action_space.n - 1
            assert np.all(action >= lower), f"Discrete action {action} is below the lower bound {lower}"
            assert np.all(action <= upper), f"Discrete action {action} is above the upper bound {upper}"
        elif isinstance(self.action_space, spaces.Box):
            assert np.all(
                action >= self.action_space.low
            ), f"Action {action} is below the lower bound {self.action_space.low}"
            assert np.all(
                action <= self.action_space.high
            ), f"Action {action} is above the upper bound {self.action_space.high}"
        return self.observation_space.sample(), 0.0, False, False, {}

@pytest.mark.parametrize(
"env",
Expand Down Expand Up @@ -170,6 +179,44 @@ def test_float64_action_space(model_class, obs_space, action_space):
assert action.dtype == env.action_space.dtype


@pytest.mark.parametrize(
    "model_class, action_space",
    [
        # on-policy test
        (PPO, spaces.Discrete(5, start=-6543)),
        (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)),
        # off-policy test
        (DQN, spaces.Discrete(2, start=9923)),
        (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)),
    ],
)
def test_space_bounds(model_class, action_space):
    """Check that actions passed to ``step`` and returned by ``predict`` stay within bounds.

    The bounds are deliberately far from zero so that an algorithm ignoring
    ``Discrete.start`` or the ``Box`` limits produces out-of-range actions.
    Actions sent to the environment during ``learn`` are validated by
    ``ActionBoundsTestClass.step``; the action returned by ``predict`` is
    validated here.
    """
    obs_space = BOX_SPACE_FLOAT32
    env = ActionBoundsTestClass(obs_space, action_space)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
    if isinstance(env.observation_space, spaces.Dict):
        policy = "MultiInputPolicy"
    else:
        policy = "MlpPolicy"

    # Small networks and short rollouts/warm-up keep the test fast
    if model_class in [PPO, A2C]:
        kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12]))
    else:
        kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12]))

    model = model_class(policy, env, **kwargs)
    # Actions taken while learning are checked inside the env's step()
    model.learn(64)
    initial_obs, _ = env.reset()

    action, _ = model.predict(initial_obs, deterministic=False)
    if isinstance(action_space, spaces.Discrete):
        lower = action_space.start
        # Discrete(n, start) contains the n values start, ..., start + n - 1,
        # so the largest valid action is start + n - 1 (not start + n).
        upper = action_space.start + action_space.n - 1
        assert np.all(action >= lower), f"Discrete action {action} is below the lower bound {lower}"
        assert np.all(action <= upper), f"Discrete action {action} is above the upper bound {upper}"
    elif isinstance(action_space, spaces.Box):
        assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}"
        assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}"


def test_multidim_binary_not_supported():
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
Expand Down

0 comments on commit d01691e

Please sign in to comment.