Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentpierre authored Sep 8, 2020
1 parent 393808b commit f591cfa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
2 changes: 1 addition & 1 deletion gym-unity/gym_unity/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __init__(
# Set action spaces
if self.group_spec.is_action_discrete():
branches = self.group_spec.discrete_action_branches
if self.group_spec.action_shape == 1:
if self.group_spec.action_size == 1:
self._action_space = spaces.Discrete(branches[0])
else:
if flatten_branched:
Expand Down
21 changes: 21 additions & 0 deletions gym-unity/gym_unity/tests/test_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,27 @@ def test_branched_flatten():
assert isinstance(env.action_space, spaces.MultiDiscrete)


def test_action_space():
mock_env = mock.MagicMock()
mock_spec = create_mock_group_spec(
vector_action_space_type="discrete", vector_action_space_size=[5]
)
mock_decision_step, mock_terminal_step = create_mock_vector_steps(
mock_spec, num_agents=1
)
setup_mock_unityenvironment(
mock_env, mock_spec, mock_decision_step, mock_terminal_step
)

env = UnityToGymWrapper(mock_env, flatten_branched=True)
assert isinstance(env.action_space, spaces.Discrete)
assert env.action_space.n == 5

env = UnityToGymWrapper(mock_env, flatten_branched=False)
assert isinstance(env.action_space, spaces.Discrete)
assert env.action_space.n == 5


@pytest.mark.parametrize("use_uint8", [True, False], ids=["float", "uint8"])
def test_gym_wrapper_visual(use_uint8):
mock_env = mock.MagicMock()
Expand Down

0 comments on commit f591cfa

Please sign in to comment.