diff --git a/docs/api/environments/search_and_rescue.md b/docs/api/environments/search_and_rescue.md
new file mode 100644
index 000000000..0748af9c2
--- /dev/null
+++ b/docs/api/environments/search_and_rescue.md
@@ -0,0 +1,11 @@
+::: jumanji.environments.swarms.search_and_rescue.env.SearchAndRescue
+    selection:
+        members:
+            - __init__
+            - reset
+            - step
+            - observation_spec
+            - action_spec
+            - reward_spec
+            - render
+            - animate
diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md
new file mode 100644
index 000000000..a707a8553
--- /dev/null
+++ b/docs/environments/search_and_rescue.md
@@ -0,0 +1,80 @@
+# 🚁 Search & Rescue
+
+[//]: # (TODO: Add animated plot)
+
+Multi-agent environment modelling a group of agents searching a 2D environment
+for multiple targets. Agents are individually rewarded for finding a target
+that has not previously been detected.
+
+Each agent visualises a local region around itself, represented as a simple segmented
+view of the locations of other agents and targets in the vicinity. The environment
+is updated in the following sequence:
+
+- The velocities of the searching agents are updated, and consequently their positions.
+- The positions of the targets are updated.
+- Targets within detection range, and within an agent's view cone, are marked as found.
+- Agents are rewarded for locating previously unfound targets.
+- Local views of the environment are generated for each searching agent.
+
+The agents are allotted a fixed number of steps to locate the targets. The search
+space is a uniform square, wrapped at the boundaries.
+
+Many aspects of the environment can be customised:
+
+- Agent observations can include targets as well as other searcher agents.
+- Rewards can be shared between agents, or granted individually to each agent,
+  and can be scaled by time-step.
+- Target dynamics can be customised to model various search scenarios.
+
+## Observations
+
+- `searcher_views`: jax array (float) of shape `(num_searchers, channels, num_vision)`.
+  Each agent generates an independent observation, an array of values representing the distance
+  along a ray from the agent to the nearest neighbour or target, with each cell representing a
+  ray angle (with `num_vision` rays evenly distributed over the agent's field of vision).
+  For example, if an agent sees another agent straight ahead and `num_vision = 5` then
+  the observation array could be
+
+  ```
+  [-1.0, -1.0, 0.5, -1.0, -1.0]
+  ```
+
+  where `-1.0` indicates there are no agents along that ray, and `0.5` is the normalised
+  distance to the other agent. Channels in the segmented view are used to differentiate
+  between different agents/targets and can be customised. By default, the view has three
+  channels representing other agents, found targets, and unlocated targets.
+- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
+  remaining to be detected (i.e. 1.0 when no targets have been found).
+- `step`: int in the range `[0, time_limit]`. The current simulation step.
+- `positions`: jax array (float) of shape `(num_searchers, 2)`. Agent coordinates.
+
+## Actions
+
+Jax array (float) of shape `(num_searchers, 2)` in the range `[-1, 1]`. Each entry in the
+array represents an update to the corresponding agent's velocity in the next step. Searching agents
+update their velocity each step by rotating and accelerating/decelerating, where the
+values are `[rotation, acceleration]`. Values are clipped to the range `[-1, 1]`
+and then scaled by max rotation and acceleration parameters, i.e. the new values each
+step are given by
+
+```
+heading = heading + max_rotation * action[0]
+```
+
+and speed
+
+```
+speed = speed + max_acceleration * action[1]
+```
+
+Once applied, agent speeds are also clipped to the range given by the
+`min_speed` and `max_speed` parameters.
+
+## Rewards
+
+Jax array (float) of shape `(num_searchers,)`. Rewards are generated for each agent individually.
+
+Agents are rewarded +1 for locating a target that has not already been detected. It is possible
+for multiple agents to detect a target within the same step; in this case the reward can either
+be shared between the locating agents, or each agent can receive the full reward. Rewards can
+also be scaled by simulation step to encourage agents to develop efficient search patterns.
diff --git a/jumanji/__init__.py b/jumanji/__init__.py
index 60ba1da88..ec88111c3 100644
--- a/jumanji/__init__.py
+++ b/jumanji/__init__.py
@@ -142,3 +142,10 @@
 # LevelBasedForaging with a random generator with 8 grid size,
 # 2 agents and 2 food items and the maximum agent's level is 2.
 register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging")
+
+###
+# Swarm Environments
+###
+
+# Search-and-Rescue environment
+register(id="SearchAndRescue-v0", entry_point="jumanji.environments:SearchAndRescue")
diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py
index b8fca6b35..cb2518826 100644
--- a/jumanji/environments/__init__.py
+++ b/jumanji/environments/__init__.py
@@ -59,6 +59,7 @@
 from jumanji.environments.routing.snake.env import Snake
 from jumanji.environments.routing.sokoban.env import Sokoban
 from jumanji.environments.routing.tsp.env import TSP
+from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue
 
 
 def is_colab() -> bool:
diff --git a/jumanji/environments/routing/snake/types.py b/jumanji/environments/routing/snake/types.py
index 7617240e8..58b9b4d7b 100644
--- a/jumanji/environments/routing/snake/types.py
+++ b/jumanji/environments/routing/snake/types.py
@@ -40,10 +40,10 @@ def __add__(self, other: Position) -> Position:  # type: ignore[override]
 
 
 class Actions(IntEnum):
-    UP: int = 0
-    RIGHT: int = 1
-    DOWN: int = 2
-    LEFT: int = 3
+    UP = 0
+    RIGHT = 1
+    DOWN = 2
+    LEFT = 3
 
 
 @dataclass
diff --git a/jumanji/environments/swarms/__init__.py b/jumanji/environments/swarms/__init__.py
new file mode 100644
index 000000000..21db9ec1c
--- /dev/null
+++ b/jumanji/environments/swarms/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/jumanji/environments/swarms/common/__init__.py b/jumanji/environments/swarms/common/__init__.py
new file mode 100644
index 000000000..21db9ec1c
--- /dev/null
+++ b/jumanji/environments/swarms/common/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 InstaDeep Ltd. All rights reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/environments/swarms/common/common_test.py b/jumanji/environments/swarms/common/common_test.py new file mode 100644 index 000000000..fa7559fb4 --- /dev/null +++ b/jumanji/environments/swarms/common/common_test.py @@ -0,0 +1,188 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import jax.numpy as jnp +import matplotlib +import matplotlib.pyplot as plt +import pytest + +from jumanji.environments.swarms.common import types, updates, viewer + + +@pytest.fixture +def params() -> types.AgentParams: + return types.AgentParams( + max_rotate=0.5, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.5, + ) + + +@pytest.mark.parametrize( + "heading, speed, actions, expected", + [ + [0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)], + [0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)], + [jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)], + [jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)], + [1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)], + [0.0, 0.01, [0.0, 1.0], (0.0, 0.02)], + [0.0, 0.01, [0.0, -1.0], (0.0, 0.01)], + [0.0, 0.02, [0.0, -1.0], (0.0, 0.01)], + [0.0, 0.05, [0.0, -1.0], (0.0, 0.04)], + [0.0, 0.05, [0.0, 1.0], (0.0, 0.05)], + ], +) +def test_velocity_update( + params: types.AgentParams, + heading: float, + speed: float, + actions: List[float], + expected: Tuple[float, float], +) -> None: + state = types.AgentState( + pos=jnp.zeros((1, 2)), + heading=jnp.array([heading]), + speed=jnp.array([speed]), + ) + actions = jnp.array([actions]) + + new_heading, new_speed = updates.update_velocity(params, (actions, state)) + + assert jnp.isclose(new_heading[0], expected[0]) + assert jnp.isclose(new_speed[0], expected[1]) + + +@pytest.mark.parametrize( + "pos, heading, speed, expected, env_size", + [ + [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5], 1.0], + [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5], 1.0], + [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1], 1.0], + [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9], 1.0], + [[0.4, 0.2], 0.0, 0.2, [0.1, 0.2], 0.5], + [[0.1, 0.2], jnp.pi, 0.2, [0.4, 0.2], 0.5], + [[0.2, 0.4], 0.5 * jnp.pi, 0.2, [0.2, 0.1], 0.5], + [[0.2, 0.1], 1.5 * jnp.pi, 0.2, [0.2, 0.4], 0.5], + ], +) +def test_move( + pos: List[float], heading: float, speed: float, expected: List[float], env_size: float +) -> None: + pos = jnp.array(pos) + new_pos = updates.move(pos, heading, speed, env_size) + + assert jnp.allclose(new_pos, 
jnp.array(expected)) + + +@pytest.mark.parametrize( + "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed, env_size", + [ + [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01, 1.0], + [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01, 1.0], + [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01, 1.0], + [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02, 1.0], + [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01, 1.0], + [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05, 1.0], + [[0.495, 0.25], 0.0, 0.01, [0.0, 0.0], [0.005, 0.25], 0.0, 0.01, 0.5], + [[0.25, 0.005], 1.5 * jnp.pi, 0.01, [0.0, 0.0], [0.25, 0.495], 1.5 * jnp.pi, 0.01, 0.5], + ], +) +def test_state_update( + params: types.AgentParams, + pos: List[float], + heading: float, + speed: float, + actions: List[float], + expected_pos: List[float], + expected_heading: float, + expected_speed: float, + env_size: float, +) -> None: + state = types.AgentState( + pos=jnp.array([pos]), + heading=jnp.array([heading]), + speed=jnp.array([speed]), + ) + actions = jnp.array([actions]) + + new_state = updates.update_state(env_size, params, state, actions) + + assert isinstance(new_state, types.AgentState) + assert jnp.allclose(new_state.pos, jnp.array([expected_pos])) + assert jnp.allclose(new_state.heading, jnp.array([expected_heading])) + assert jnp.allclose(new_state.speed, jnp.array([expected_speed])) + + +def test_view_reduction() -> None: + view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5]) + view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2]) + result = updates.view_reduction_fn(view_a, view_b) + assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2])) + + +@pytest.mark.parametrize( + "pos, view_angle, env_size, expected", + [ + [[0.05, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]], + [[0.0, 0.05], 0.5, 1.0, [0.5, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.95], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, 0.5]], + [[0.95, 0.0], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]], + [[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.01, 0.0], 0.5, 1.0, [-1.0, 0.1, 0.1, 0.1, -1.0]], + [[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]], + ], +) +def test_view(pos: List[float], view_angle: float, env_size: float, expected: List[float]) -> None: + state_a = types.AgentState( + pos=jnp.zeros((2,)), + heading=0.0, + speed=0.0, + ) + + state_b = types.AgentState( + pos=jnp.array(pos), + heading=0.0, + speed=0.0, + ) + + obs = updates.view( + (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size + ) + assert jnp.allclose(obs, jnp.array(expected)) + + +def test_viewer_utils() -> None: + f, ax = plt.subplots() + f, ax = viewer.format_plot(f, ax, (1.0, 1.0)) + + assert isinstance(f, matplotlib.figure.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + + state = types.AgentState( + pos=jnp.zeros((3, 2)), + heading=jnp.zeros((3,)), + speed=jnp.zeros((3,)), + ) + + quiver = viewer.draw_agents(ax, state, "red") + + assert isinstance(quiver, matplotlib.quiver.Quiver) diff --git a/jumanji/environments/swarms/common/types.py b/jumanji/environments/swarms/common/types.py new file mode 100644 index 000000000..db3e0e0f4 --- /dev/null +++ b/jumanji/environments/swarms/common/types.py @@ -0,0 +1,53 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from chex import dataclass + +import chex + + +@dataclass(frozen=True) +class AgentParams: + """ + max_rotate: Max angle an agent can rotate during a step (a fraction of pi) + max_accelerate: Max change in speed during a step + min_speed: Minimum agent speed + max_speed: Maximum agent speed + view_angle: Agent view angle, as a fraction of pi either side of its heading + """ + + max_rotate: float + max_accelerate: float + min_speed: float + max_speed: float + view_angle: float + + +@dataclass +class AgentState: + """ + State of multiple agents of a single type + + pos: 2d position of the (centre of the) agents + heading: Heading of the agents (in radians) + speed: Speed of the agents + """ + + pos: chex.Array # (num_agents, 2) + heading: chex.Array # (num_agents,) + speed: chex.Array # (num_agents,) diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py new file mode 100644 index 000000000..25f99e1e5 --- /dev/null +++ b/jumanji/environments/swarms/common/updates.py @@ -0,0 +1,207 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import chex +import esquilax +import jax +import jax.numpy as jnp + +from jumanji.environments.swarms.common import types + + +@esquilax.transforms.amap +def update_velocity( + params: types.AgentParams, + x: Tuple[chex.Array, types.AgentState], +) -> Tuple[chex.Numeric, chex.Numeric]: + """ + Get the updated agent heading and speeds from actions + + Args: + params: Agent parameters. + x: Agent rotation and acceleration actions. + + Returns: + float: New agent heading. + float: New agent speed. + """ + actions, boid = x + rotation = actions[0] * params.max_rotate * jnp.pi + acceleration = actions[1] * params.max_accelerate + + new_heading = (boid.heading + rotation) % (2 * jnp.pi) + new_speeds = jnp.clip( + boid.speed + acceleration, + min=params.min_speed, + max=params.max_speed, + ) + + return new_heading, new_speeds + + +def move(pos: chex.Array, heading: chex.Array, speed: chex.Array, env_size: float) -> chex.Array: + """ + Get updated agent positions from current speed and heading + + Args: + pos: Agent position. + heading: Agent heading (angle). + speed: Agent speed. + env_size: Size of the environment. + + Returns: + jax array (float32): Updated agent positions. 
+ """ + d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)]) + return (pos + d_pos) % env_size + + +def update_state( + env_size: float, + params: types.AgentParams, + state: types.AgentState, + actions: chex.Array, +) -> types.AgentState: + """ + Update the state of a group of agents from a sample of actions + + Args: + env_size: Size of the environment. + params: Agent parameters. + state: Current agent states. + actions: Agent actions, i.e. a 2D array of action for each agent. + + Returns: + AgentState: Updated state of the agents after applying steering + actions and updating positions. + """ + actions = jnp.clip(actions, min=-1.0, max=1.0) + headings, speeds = update_velocity(params, (actions, state)) + positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size) + + return types.AgentState( + pos=positions, + speed=speeds, + heading=headings, + ) + + +def view_reduction_fn(view_a: chex.Array, view_b: chex.Array) -> chex.Array: + """ + Binary view reduction function for use in Esquilax spatial transformation. + + Handles reduction where a value of -1.0 indicates no + agent in view-range. Returns the min value if they + are both positive, but the max value if one or both of + the values is -1.0. + + Args: + view_a: View vector. + view_b: View vector. + + Returns: + jax array (float32): View vector indicating the + shortest distance to the nearest neighbour or + -1.0 if no agent is present along a ray. + """ + return jnp.where( + jnp.logical_or(view_a < 0.0, view_b < 0.0), + jnp.maximum(view_a, view_b), + jnp.minimum(view_a, view_b), + ) + + +def angular_width( + viewing_pos: chex.Array, + viewed_pos: chex.Array, + viewing_heading: chex.Array, + i_range: float, + agent_radius: float, + env_size: float, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + """ + Get the normalised distance, and angles to edges of another agent. + + Args: + viewing_pos: Co-ordinates of the viewing agent + viewed_pos: Co-ordinates of the viewed agent + viewing_heading: Heading of viewing agent + i_range: Interaction range + agent_radius: Agent visual radius + env_size: Environment size + + Returns: + Normalised distance between agents, and the left and right + angles to the edges of the agent. + """ + dx = esquilax.utils.shortest_vector(viewing_pos, viewed_pos, length=env_size) + dist = jnp.sqrt(jnp.sum(dx * dx)) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = esquilax.utils.shortest_vector(phi, viewing_heading, 2 * jnp.pi) + a_width = jnp.arctan2(agent_radius, dist) + norm_dist = dist / i_range + return norm_dist, dh - a_width, dh + a_width + + +def view( + params: Tuple[float, float], + viewing_agent: types.AgentState, + viewed_agent: types.AgentState, + *, + n_view: int, + i_range: float, + env_size: float, +) -> chex.Array: + """ + Simple agent view model + + Simple view model where the agents view angle is subdivided + into an array of values representing the distance from + the agent along rays from the agent, with rays evenly distributed + across the agents field of view. The limit of vision is set at 1.0. + The default value if no object is within range is -1.0. + Currently, this model assumes the viewed agent/objects are circular. + + Args: + params: Tuple containing agent view angle and view-radius. + viewing_agent: Viewing agent state. + viewed_agent: State of agent being viewed. + n_view: Static number of view rays/subdivisions (i.e. how + many cells the resulting array contains). + i_range: Static agent view/interaction range. 
+ env_size: Size of the environment. + + Returns: + jax array (float32): 1D array representing the distance + along a ray from the agent to another agent. + """ + view_angle, agent_radius = params + rays = jnp.linspace( + -view_angle * jnp.pi, + view_angle * jnp.pi, + n_view, + endpoint=True, + ) + d, left, right = angular_width( + viewing_agent.pos, + viewed_agent.pos, + viewing_agent.heading, + i_range, + agent_radius, + env_size, + ) + obs = jnp.where(jnp.logical_and(left < rays, rays < right), d, -1.0) + return obs diff --git a/jumanji/environments/swarms/common/viewer.py b/jumanji/environments/swarms/common/viewer.py new file mode 100644 index 000000000..4fc15c88b --- /dev/null +++ b/jumanji/environments/swarms/common/viewer.py @@ -0,0 +1,76 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import jax.numpy as jnp +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.quiver import Quiver + +from jumanji.environments.swarms.common.types import AgentState + + +def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver: + """Draw a flock/swarm of agent using a matplotlib quiver + + Args: + ax: Plot axes. + agent_states: Flock/swarm agent states. + color: Fill color of agents. + + Returns: + `Quiver`: Matplotlib quiver, can also be used and + updated when animating. + """ + q = ax.quiver( + agent_states.pos[:, 0], + agent_states.pos[:, 1], + jnp.cos(agent_states.heading), + jnp.sin(agent_states.heading), + color=color, + pivot="middle", + ) + return q + + +def format_plot( + fig: Figure, ax: Axes, env_dims: Tuple[float, float], border: float = 0.01 +) -> Tuple[Figure, Axes]: + """Format a flock/swarm plot, remove ticks and bound to the environment dimensions. + + Args: + fig: Matplotlib figure. + ax: Matplotlib axes. + env_dims: Environment dimensions (i.e. its boundaries). + border: Border padding to apply around plot. + + Returns: + Figure: Formatted figure. + Axes: Formatted axes. + """ + fig.subplots_adjust( + top=1.0 - border, + bottom=border, + right=1.0 - border, + left=border, + hspace=0, + wspace=0, + ) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_xlim(0, env_dims[0]) + ax.set_ylim(0, env_dims[1]) + + return fig, ax diff --git a/jumanji/environments/swarms/search_and_rescue/__init__.py b/jumanji/environments/swarms/search_and_rescue/__init__.py new file mode 100644 index 000000000..f4959771a --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .env import SearchAndRescue diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py new file mode 100644 index 000000000..8158bacac --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/conftest.py @@ -0,0 +1,81 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax.random +import pytest + +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue, observations + + +@pytest.fixture +def env() -> SearchAndRescue: + return SearchAndRescue( + target_contact_range=0.05, + searcher_max_rotate=0.2, + searcher_max_accelerate=0.01, + searcher_min_speed=0.01, + searcher_max_speed=0.05, + searcher_view_angle=0.5, + time_limit=10, + ) + + +class FixtureRequest: + """Just used for typing""" + + param: observations.ObservationFn + + +@pytest.fixture( + params=[ + observations.AgentObservationFn( + num_vision=32, + vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, + ), + observations.AgentAndTargetObservationFn( + num_vision=32, + vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, + ), + observations.AgentAndAllTargetObservationFn( + num_vision=32, + vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, + ), + ] +) +def multi_obs_env(request: FixtureRequest) -> SearchAndRescue: + return SearchAndRescue( + target_contact_range=0.05, + searcher_max_rotate=0.2, + searcher_max_accelerate=0.01, + searcher_min_speed=0.01, + searcher_max_speed=0.05, + searcher_view_angle=0.5, + time_limit=10, + observation=request.param, + ) + + +@pytest.fixture +def key() -> chex.PRNGKey: + return jax.random.PRNGKey(101) diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py new file mode 100644 index 000000000..1a6897769 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py @@ -0,0 +1,69 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax + +from jumanji.environments.swarms.search_and_rescue.types import TargetState + + +class TargetDynamics(abc.ABC): + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: + """Interface for target state update function. + + NOTE: Target positions should be inside the bounds + of the environment. Out-of-bound co-ordinates can + lead to unexpected behaviour. + + Args: + key: Random key. + targets: Current target states. + env_size: Environment size. + + Returns: + Updated target states. + """ + + +class RandomWalk(TargetDynamics): + def __init__(self, step_size: float): + """ + Simple random walk target dynamics. + + Target positions are updated with random steps, sampled uniformly + from the range `[-step-size, step-size]`. + + Args: + step_size: Maximum random step-size in each axis. + """ + self.step_size = step_size + + def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: + """Update target state. + + Args: + key: Random key. + targets: Current target states. + env_size: Environment size. + + Returns: + Updated target states. + """ + d_pos = jax.random.uniform(key, targets.pos.shape) + d_pos = self.step_size * 2.0 * (d_pos - 0.5) + pos = (targets.pos + d_pos) % env_size + return TargetState(pos=pos, vel=targets.vel, found=targets.found) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py new file mode 100644 index 000000000..ad7c7789a --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -0,0 +1,395 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
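[Editor's note: the following is an illustrative sketch, not part of the PR diff.] The `TargetDynamics` interface defined above can be extended with alternative target behaviours and passed to the environment via the `target_dynamics` argument. As a minimal, hypothetical example (the `ConstantDrift` class and its `drift` parameter are introduced here purely for illustration):

```python
import chex
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue import SearchAndRescue
from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics
from jumanji.environments.swarms.search_and_rescue.types import TargetState


class ConstantDrift(TargetDynamics):
    """Hypothetical dynamics where all targets drift with a fixed velocity."""

    def __init__(self, drift: chex.Array):
        self.drift = drift

    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
        # Keep positions inside the (wrapped) environment bounds, as required
        # by the TargetDynamics interface docstring.
        pos = (targets.pos + self.drift) % env_size
        return TargetState(pos=pos, vel=targets.vel, found=targets.found)


# Plug the custom dynamics into the environment.
env = SearchAndRescue(target_dynamics=ConstantDrift(jnp.array([0.001, 0.0])))
```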
+
+from functools import cached_property, partial
+from typing import Optional, Sequence, Tuple
+
+import chex
+import esquilax
+import jax
+import jax.numpy as jnp
+from esquilax.transforms import spatial
+from matplotlib.animation import FuncAnimation
+
+from jumanji import specs
+from jumanji.env import Environment
+from jumanji.environments.swarms.common.types import AgentParams
+from jumanji.environments.swarms.common.updates import update_state
+from jumanji.environments.swarms.search_and_rescue import utils
+from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics
+from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator
+from jumanji.environments.swarms.search_and_rescue.observations import (
+    AgentAndAllTargetObservationFn,
+    ObservationFn,
+)
+from jumanji.environments.swarms.search_and_rescue.reward import RewardFn, SharedRewardFn
+from jumanji.environments.swarms.search_and_rescue.types import Observation, State, TargetState
+from jumanji.environments.swarms.search_and_rescue.viewer import SearchAndRescueViewer
+from jumanji.types import TimeStep, restart, termination, transition
+from jumanji.viewer import Viewer
+
+
+class SearchAndRescue(Environment):
+    """A multi-agent search environment
+
+    Environment modelling a collection of agents collectively searching
+    for a set of targets in a 2D environment. Agents are rewarded
+    (individually) for coming within a fixed range of a target that has
+    not already been detected. Agents visualise their local environment
+    (i.e. the location of other agents and targets) via a simple segmented
+    view model. The environment area is a uniform square space with wrapped
+    boundaries.
+
+    An episode will terminate if all targets have been located by the team of
+    searching agents.
+
+    - observation: `Observation`
+        searcher_views: jax array (float) of shape (num_searchers, channels, num_vision)
+            Individual local views of positions of other agents and targets, where
+            channels can be used to differentiate between agents and targets.
+            Each entry in the view indicates the distance to another agent/target
+            along a ray from the agent, and is -1.0 if nothing is in range along the ray.
+            The view model can be customised using an `ObservationFn` implementation, e.g.
+            the view can include agents and all targets, agents and found targets, or
+            just other agents.
+        targets_remaining: (float) Number of targets remaining to be found from
+            the total, scaled to the range [0, 1] (i.e. a value of 1.0 indicates
+            all the targets are still to be found).
+        step: (int) current simulation step.
+        positions: jax array (float) of shape (num_searchers, 2) search agent positions.
+
+    - action: jax array (float) of shape (num_searchers, 2)
+        Array of individual agent actions. Each agent's action rotates and
+        accelerates/decelerates the agent as [rotation, acceleration] in the range
+        [-1, 1]. These values are then scaled to update agent velocities within
+        given parameters (i.e. a value of ±1 is the maximum acceleration/rotation).
+
+    - reward: jax array (float) of shape (num_searchers,)
+        Arrays of individual agent rewards. A reward of +1 is granted when an agent
+        comes into contact range with a target that has not yet been found, and
+        the target is within the searcher's view cone. Rewards can be shared
+        between agents if a target is simultaneously detected by multiple agents,
+        or each can be provided the full reward individually.
+
+    - state: `State`
+        - searchers: `AgentState`
+            - pos: jax array (float) of shape (num_searchers, 2) in the range [0, env_size].
+            - heading: jax array (float) of shape (num_searchers,) in
+                the range [0, 2π].
+            - speed: jax array (float) of shape (num_searchers,) in the
+                range [min_speed, max_speed].
+        - targets: `TargetState`
+            - pos: jax array (float) of shape (num_targets, 2) in the range [0, env_size].
+            - vel: jax array (float) of shape (num_targets, 2).
+            - found: jax array (bool) of shape (num_targets,) flag indicating if
+                target has been located by an agent.
+        - key: jax array (uint32) of shape (2,)
+        - step: int representing the current simulation step.
+
+    ```python
+    from jumanji.environments import SearchAndRescue
+
+    env = SearchAndRescue()
+    key = jax.random.PRNGKey(0)
+    state, timestep = jax.jit(env.reset)(key)
+    env.render(state)
+    action = env.action_spec.generate_value()
+    state, timestep = jax.jit(env.step)(state, action)
+    env.render(state)
+    ```
+    """
+
+    def __init__(
+        self,
+        target_contact_range: float = 0.04,
+        searcher_max_rotate: float = 0.25,
+        searcher_max_accelerate: float = 0.005,
+        searcher_min_speed: float = 0.01,
+        searcher_max_speed: float = 0.02,
+        searcher_view_angle: float = 0.5,
+        time_limit: int = 400,
+        viewer: Optional[Viewer[State]] = None,
+        target_dynamics: Optional[TargetDynamics] = None,
+        generator: Optional[Generator] = None,
+        reward_fn: Optional[RewardFn] = None,
+        observation: Optional[ObservationFn] = None,
+    ) -> None:
+        """Instantiates a `SearchAndRescue` environment
+
+        Args:
+            target_contact_range: Range at which a searcher will 'find' a target.
+            searcher_max_rotate: Maximum rotation searcher agents can
+                turn within a step. Should be a value from [0, 1]
+                representing a fraction of π-radians.
+            searcher_max_accelerate: Magnitude of the maximum
+                acceleration/deceleration a searcher agent can apply within a step.
+            searcher_min_speed: Minimum speed a searcher agent can move at.
+            searcher_max_speed: Maximum speed a searcher agent can move at.
+            searcher_view_angle: Searcher agent local view angle. Should be
+                a value from [0, 1] representing a fraction of π-radians.
+                The view cone of an agent extends ± this angle (in π-radians)
+                either side of its heading, e.g. 0.5 means searchers have a
+                180° view angle in total.
+            time_limit: Maximum number of environment steps allowed for search.
+            viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`.
+            target_dynamics: Target object dynamics model, implemented as a
+                `TargetDynamics` interface. Defaults to `RandomWalk`.
+            generator: Initial state `Generator` instance. Defaults to `RandomGenerator`
+                with 50 targets and 2 searchers.
+            reward_fn: Reward aggregation function. Defaults to `SharedRewardFn` where
+                agents share rewards if they locate a target simultaneously.
+            observation: Agent observation view generation function. Defaults to
+                `AgentAndAllTargetObservationFn` where all targets (found and unfound)
+                and other searching agents are included in the generated view.
+        """
+
+        self.target_contact_range = target_contact_range
+
+        self.searcher_params = AgentParams(
+            max_rotate=searcher_max_rotate,
+            max_accelerate=searcher_max_accelerate,
+            min_speed=searcher_min_speed,
+            max_speed=searcher_max_speed,
+            view_angle=searcher_view_angle,
+        )
+        self.time_limit = time_limit
+        self._target_dynamics = target_dynamics or RandomWalk(0.001)
+        self.generator = generator or RandomGenerator(num_targets=50, num_searchers=2)
+        self._viewer = viewer or SearchAndRescueViewer()
+        self._reward_fn = reward_fn or SharedRewardFn()
+        self._observation = observation or AgentAndAllTargetObservationFn(
+            num_vision=64,
+            vision_range=0.25,
+            view_angle=searcher_view_angle,
+            agent_radius=0.02,
+            env_size=self.generator.env_size,
+        )
+        super().__init__()
+
+    def __repr__(self) -> str:
+        return "\n".join(
+            [
+                "Search & rescue multi-agent environment:",
+                f" - num searchers: {self.generator.num_searchers}",
+                f" - num targets: {self.generator.num_targets}",
+                f" - max searcher rotation: {self.searcher_params.max_rotate}",
+                f" - max searcher acceleration: {self.searcher_params.max_accelerate}",
+                f" - searcher min speed: {self.searcher_params.min_speed}",
+                f" - searcher max speed: {self.searcher_params.max_speed}",
+                f" - search vision range: {self._observation.vision_range}",
+                f" - search view angle: {self._observation.view_angle}",
+                f" - target contact range: {self.target_contact_range}",
+                f" - num vision: {self._observation.num_vision}",
+                f" - agent radius: {self._observation.agent_radius}",
+                f" - time limit: {self.time_limit}",
+                f" - env size: {self.generator.env_size}",
+                f" - target dynamics: {self._target_dynamics.__class__.__name__}",
+                f" - generator: {self.generator.__class__.__name__}",
+                f" - reward fn: {self._reward_fn.__class__.__name__}",
+                f" - observation fn: {self._observation.__class__.__name__}",
+            ]
+        )
+
+    def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
+        """Initialise searcher and target states.
+
+        Args:
+            key: Random key used to reset the environment.
+
+        Returns:
+            state: Initial environment state.
+            timestep: TimeStep with individual search agent views.
+        """
+        state = self.generator(key, self.searcher_params)
+        timestep = restart(observation=self._state_to_observation(state), shape=(self.num_agents,))
+        return state, timestep
+
+    def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Observation]]:
+        """Environment update.
+
+        Update searcher velocities and consequently their positions,
+        mark found targets, and generate rewards and local observations.
+
+        Args:
+            state: Environment state.
+            actions: 2D array of searcher steering actions.
+
+        Returns:
+            state: Updated searcher and target positions and velocities.
+            timestep: Transition timestep with individual agent local observations.
+ """ + key, target_key = jax.random.split(state.key, num=2) + searchers = update_state( + self.generator.env_size, self.searcher_params, state.searchers, actions + ) + targets = self._target_dynamics(target_key, state.targets, self.generator.env_size) + + # Searchers return an array of flags of any targets they are in range of, + # and that have not already been located, result shape here is (n-searcher, n-targets) + targets_found = spatial( + utils.searcher_detect_targets, + reduction=esquilax.reductions.logical_or((self.generator.num_targets,)), + i_range=self.target_contact_range, + dims=self.generator.env_size, + )( + self.searcher_params.view_angle, + searchers, + (jnp.arange(self.generator.num_targets), targets), + pos=searchers.pos, + pos_b=targets.pos, + env_size=self.generator.env_size, + n_targets=self.generator.num_targets, + ) + + rewards = self._reward_fn(targets_found, state.step, self.time_limit) + + targets_found = jnp.any(targets_found, axis=0) + # Targets need to remain found if they already have been + targets_found = jnp.logical_or(targets_found, state.targets.found) + + state = State( + searchers=searchers, + targets=TargetState(pos=targets.pos, vel=targets.vel, found=targets_found), + key=key, + step=state.step + 1, + ) + observation = self._state_to_observation(state) + observation = jax.lax.stop_gradient(observation) + timestep = jax.lax.cond( + jnp.logical_or(state.step >= self.time_limit, jnp.all(targets_found)), + partial(termination, shape=(self.num_agents,)), + partial(transition, shape=(self.num_agents,)), + rewards, + observation, + ) + return state, timestep + + def _state_to_observation(self, state: State) -> Observation: + searcher_views = self._observation(state) + return Observation( + searcher_views=searcher_views, + targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, + step=state.step, + positions=state.searchers.pos / self.generator.env_size, + ) + + @cached_property + def num_agents(self) -> int: + return self.generator.num_searchers + + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. + + Local searcher agent views representing the distance to the + closest neighbouring agents and targets in the environment. + + Returns: + observation_spec: Search-and-rescue observation spec + """ + searcher_views = specs.BoundedArray( + shape=( + self.num_agents, + self._observation.num_channels, + self._observation.num_vision, + ), + minimum=-1.0, + maximum=1.0, + dtype=float, + name="searcher_views", + ) + return specs.Spec( + Observation, + "ObservationSpec", + searcher_views=searcher_views, + targets_remaining=specs.BoundedArray( + shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=jnp.float32 + ), + step=specs.BoundedArray( + shape=(), minimum=0, maximum=self.time_limit, name="step", dtype=jnp.int32 + ), + positions=specs.BoundedArray( + shape=(self.num_agents, 2), + minimum=0.0, + maximum=1.0, + name="positions", + dtype=jnp.float32, + ), + ) + + @cached_property + def action_spec(self) -> specs.BoundedArray: + """Returns the action spec. + + 2d array of individual agent actions. Each agents action is + an array representing [rotation, acceleration] in the range + [-1, 1]. + + Returns: + action_spec: Action array spec + """ + return specs.BoundedArray( + shape=(self.generator.num_searchers, 2), + minimum=-1.0, + maximum=1.0, + dtype=float, + ) + + @cached_property + def reward_spec(self) -> specs.BoundedArray: + """Returns the reward spec. 
+ + Array of individual rewards for each agent. + + Returns: + reward_spec: Reward array spec. + """ + return specs.BoundedArray( + shape=(self.generator.num_searchers,), + minimum=0.0, + maximum=float(self.generator.num_targets), + dtype=float, + ) + + def render(self, state: State) -> None: + """Render a frame of the environment for a given state using matplotlib. + + Args: + state: State object. + """ + self._viewer.render(state) + + def animate( + self, + states: Sequence[State], + interval: int = 100, + save_path: Optional[str] = None, + ) -> FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive + timesteps. + interval: delay between frames in milliseconds. + save_path: the path where the animation file should be saved. If it + is None, the plot will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + return self._viewer.animate(states, interval=interval, save_path=save_path) + + def close(self) -> None: + """Perform any necessary cleanup.""" + self._viewer.close() diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py new file mode 100644 index 000000000..743546194 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -0,0 +1,276 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib +import matplotlib.pyplot as plt +import py +import pytest + +from jumanji.environments.swarms.common.types import AgentParams, AgentState +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk +from jumanji.environments.swarms.search_and_rescue.types import ( + Observation, + State, + TargetState, +) +from jumanji.testing.env_not_smoke import check_env_does_not_smoke, check_env_specs_does_not_smoke +from jumanji.types import StepType, TimeStep + + +def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: + """ + Check newly initialised state has expected array shapes + and initial timestep. 
+ """ + state, timestep = env.reset(key) + assert isinstance(state, State) + + assert isinstance(state.searchers, AgentState) + assert state.searchers.pos.shape == (env.generator.num_searchers, 2) + assert state.searchers.speed.shape == (env.generator.num_searchers,) + assert state.searchers.heading.shape == (env.generator.num_searchers,) + + assert isinstance(state.targets, TargetState) + assert state.targets.pos.shape == (env.generator.num_targets, 2) + assert state.targets.vel.shape == (env.generator.num_targets, 2) + assert state.targets.found.shape == (env.generator.num_targets,) + assert jnp.array_equal( + state.targets.found, jnp.full((env.generator.num_targets,), False, dtype=bool) + ) + assert state.step == 0 + + assert isinstance(timestep.observation, Observation) + assert timestep.observation.searcher_views.shape == ( + env.generator.num_searchers, + env._observation.num_channels, + env._observation.num_vision, + ) + assert timestep.step_type == StepType.FIRST + assert timestep.reward.shape == (env.generator.num_searchers,) + + +@pytest.mark.parametrize("env_size", [1.0, 0.2, 10.0]) +def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> None: + """ + Run several steps of the environment with random actions and + check states (i.e. positions, heading, speeds) all fall + inside expected ranges. + """ + n_steps = env.time_limit + env.generator.env_size = env_size + env.time_limit = 22 + + def step( + carry: Tuple[chex.PRNGKey, State], _: None + ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: + k, state = carry + k, k_search = jax.random.split(k) + actions = jax.random.uniform(k_search, (env.num_agents, 2), minval=-1.0, maxval=1.0) + new_state, timestep = env.step(state, actions) + return (k, new_state), (state, timestep) + + init_state, _ = env.reset(key) + (_, final_state), (state_history, timesteps) = jax.lax.scan( + step, (key, init_state), length=n_steps + ) + + assert isinstance(state_history, State) + + assert state_history.searchers.pos.shape == (n_steps, env.num_agents, 2) + assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= env_size)) + assert state_history.searchers.speed.shape == (n_steps, env.num_agents) + assert jnp.all( + (env.searcher_params.min_speed <= state_history.searchers.speed) + & (state_history.searchers.speed <= env.searcher_params.max_speed) + ) + assert state_history.searchers.speed.shape == (n_steps, env.num_agents) + assert jnp.all( + (0.0 <= state_history.searchers.heading) & (state_history.searchers.heading <= 2.0 * jnp.pi) + ) + + assert state_history.targets.pos.shape == (n_steps, env.generator.num_targets, 2) + assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= env_size)) + + assert timesteps.observation.positions.shape == (n_steps, env.num_agents, 2) + assert jnp.all( + (0.0 <= timesteps.observation.positions) & (timesteps.observation.positions <= 1.0) + ) + + +def test_env_does_not_smoke(multi_obs_env: SearchAndRescue) -> None: + """Test that we can run an episode without any errors.""" + multi_obs_env.time_limit = 10 + + def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: + return jax.random.uniform( + action_key, (multi_obs_env.generator.num_searchers, 2), minval=-1.0, maxval=1.0 + ) + + check_env_does_not_smoke(multi_obs_env, select_action=select_action) + + +def test_env_specs_do_not_smoke(multi_obs_env: SearchAndRescue) -> None: + """Test that we can access specs without any errors.""" + 
check_env_specs_does_not_smoke(multi_obs_env) + + +def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: + # Keep targets in one location + env._target_dynamics = RandomWalk(step_size=0.0) + env.generator.num_targets = 1 + + # Agent facing wrong direction should not see target + state = State( + searchers=AgentState( + pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([jnp.pi]), speed=jnp.array([0.0]) + ), + targets=TargetState( + pos=jnp.array([[0.54, 0.5]]), vel=jnp.zeros((1, 2)), found=jnp.array([False]) + ), + key=key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert not state.targets.found[0] + assert timestep.reward[0] == 0 + + # Rotated agent should detect target + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([0.0]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert timestep.reward[0] == 1 + + # Searcher should only get rewards once + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert timestep.reward[0] == 0 + + # Once detected target should remain detected if agent moves away + state = State( + searchers=AgentState( + pos=jnp.array([[0.0, 0.0]]), + heading=state.searchers.heading, + speed=state.searchers.speed, + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert timestep.reward[0] == 0 + + +def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: + # Keep targets in one location + env._target_dynamics = RandomWalk(step_size=0.0) + env.searcher_params = AgentParams( + max_rotate=0.1, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.25, + ) + env.generator.num_targets = 2 + + # Agent facing wrong direction should not see target + state = State( + searchers=AgentState( + pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([0.5 * jnp.pi]), speed=jnp.array([0.0]) + ), + targets=TargetState( + pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), + vel=jnp.zeros((2, 2)), + found=jnp.array([False, False]), + ), + key=key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert not state.targets.found[0] + assert not state.targets.found[1] + assert timestep.reward[0] == 0 + + # Rotated agent should detect first target + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([0.0]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert not state.targets.found[1] + assert timestep.reward[0] == 1 + + # Rotated agent should not detect another agent + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([1.5 * jnp.pi]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert not state.targets.found[1] + assert timestep.reward[0] == 0 + + # Rotated agent again should see second agent + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([jnp.pi]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert state.targets.found[1] + assert timestep.reward[0] == 1 + + +def 
test_search_and_rescue_render(monkeypatch: pytest.MonkeyPatch, env: SearchAndRescue) -> None: + """Check that the render method builds the figure but does not display it.""" + monkeypatch.setattr(plt, "show", lambda fig: None) + step_fn = jax.jit(env.step) + state, timestep = env.reset(jax.random.PRNGKey(0)) + action = env.action_spec.generate_value() + state, timestep = step_fn(state, action) + env.render(state) + env.close() + + +def test_search_and_rescue__animation(env: SearchAndRescue, tmpdir: py.path.local) -> None: + """Check that the animation method creates the animation correctly and can save to a gif.""" + step_fn = jax.jit(env.step) + state, _ = env.reset(jax.random.PRNGKey(0)) + states = [state] + action = env.action_spec.generate_value() + state, _ = step_fn(state, action) + states.append(state) + animation = env.animate(states, interval=200, save_path=None) + assert isinstance(animation, matplotlib.animation.Animation) + + path = str(tmpdir.join("/anim.gif")) + animation.save(path, writer=matplotlib.animation.PillowWriter(fps=10), dpi=60) diff --git a/jumanji/environments/swarms/search_and_rescue/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py new file mode 100644 index 000000000..2425475f6 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -0,0 +1,97 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax +import jax.numpy as jnp + +from jumanji.environments.swarms.common.types import AgentParams, AgentState +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState + + +class Generator(abc.ABC): + def __init__(self, num_searchers: int, num_targets: int, env_size: float = 1.0) -> None: + """Interface for instance generation for the `SearchAndRescue` environment. + + Args: + num_searchers: Number of searcher agents + num_targets: Number of search targets + env_size: Size (dimensions of the environment), default 1.0. + """ + self.num_searchers = num_searchers + self.num_targets = num_targets + self.env_size = env_size + + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: + """Generate initial agent positions and velocities. + + Args: + key: random key. + searcher_params: Searcher aagent `AgentParams`. + + Returns: + Initial agent `State`. + """ + + +class RandomGenerator(Generator): + def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: + """Generate random initial agent positions and velocities, and random target positions. + + Args: + key: random key. + searcher_params: Searcher `AgentParams`. + + Returns: + state: the generated state. 
+ """ + key, searcher_key, target_key = jax.random.split(key, num=3) + + k_pos, k_head, k_speed = jax.random.split(searcher_key, 3) + positions = jax.random.uniform( + k_pos, (self.num_searchers, 2), minval=0.0, maxval=self.env_size + ) + headings = jax.random.uniform( + k_head, (self.num_searchers,), minval=0.0, maxval=2.0 * jnp.pi + ) + speeds = jax.random.uniform( + k_speed, + (self.num_searchers,), + minval=searcher_params.min_speed, + maxval=searcher_params.max_speed, + ) + searcher_state = AgentState( + pos=positions, + speed=speeds, + heading=headings, + ) + + target_pos = jax.random.uniform( + target_key, (self.num_targets, 2), minval=0.0, maxval=self.env_size + ) + target_vel = jnp.zeros((self.num_targets, 2)) + + state = State( + searchers=searcher_state, + targets=TargetState( + pos=target_pos, + vel=target_vel, + found=jnp.full((self.num_targets,), False, dtype=bool), + ), + key=key, + ) + return state diff --git a/jumanji/environments/swarms/search_and_rescue/generator_test.py b/jumanji/environments/swarms/search_and_rescue/generator_test.py new file mode 100644 index 000000000..e6552359c --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/generator_test.py @@ -0,0 +1,44 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import chex +import jax.numpy as jnp +import pytest + +from jumanji.environments.swarms.common.types import AgentParams +from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator +from jumanji.environments.swarms.search_and_rescue.types import State + + +@pytest.mark.parametrize("env_size", [1.0, 0.5]) +def test_random_generator(key: chex.PRNGKey, env_size: float) -> None: + params = AgentParams( + max_rotate=0.5, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.5, + ) + generator = RandomGenerator(num_searchers=100, num_targets=101, env_size=env_size) + + assert isinstance(generator, Generator) + + state = generator(key, params) + + assert isinstance(state, State) + assert state.searchers.pos.shape == (generator.num_searchers, 2) + assert jnp.all(0.0 <= state.searchers.pos) and jnp.all(state.searchers.pos <= env_size) + assert state.targets.pos.shape == (generator.num_targets, 2) + assert jnp.all(0.0 <= state.targets.pos) and jnp.all(state.targets.pos <= env_size) + assert not jnp.any(state.targets.found) + assert state.step == 0 diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py new file mode 100644 index 000000000..6a7e587c0 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -0,0 +1,392 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Tuple + +import chex +import esquilax +import jax.numpy as jnp + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.common.updates import angular_width, view, view_reduction_fn +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState + + +def view_reduction(view_shape: Tuple[int, ...]) -> esquilax.reductions.Reduction: + return esquilax.reductions.Reduction( + fn=view_reduction_fn, + id=-jnp.ones(view_shape), + ) + + +class ObservationFn(abc.ABC): + def __init__( + self, + num_channels: int, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Base class for observation function mapping state to individual agent views. + + Maps states to an array of individual local agent views of + the environment, with shape (n-agents, n-channels, n-vision). + Channels can be used to differentiate between agent types or + statuses. + + Args: + num_channels: Number of channels in agent view. + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of π). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + self.num_channels = num_channels + self.num_vision = num_vision + self.vision_range = vision_range + self.view_angle = view_angle + self.agent_radius = agent_radius + self.env_size = env_size + + @abc.abstractmethod + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views (n-agents, n-channels, n-vision). + """ + + +class AgentObservationFn(ObservationFn): + def __init__( + self, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Observation that only visualises other search agents in proximity. + + Args: + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of π). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + super().__init__( + 1, + num_vision, + vision_range, + view_angle, + agent_radius, + env_size, + ) + + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views of shape + (n-agents, 1, n-vision). 
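+            Values default to -1.0 along rays where no other agent is
+            visible; otherwise a cell holds the distance to the closest
+            agent along that ray, scaled by the vision range.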
+ """ + searcher_views = esquilax.transforms.spatial( + view, + reduction=view_reduction((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + (self.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + return searcher_views[:, jnp.newaxis] + + +def found_target_view( + params: Tuple[float, float], + searcher: AgentState, + target: TargetState, + *, + n_view: int, + i_range: float, + env_size: float, +) -> chex.Array: + """ + Return view of a target, dependent on target status. + + This function is intended to be mapped over agents-target + pairs by Esquilax. + + Args: + params: View angle and target visual radius. + searcher: Searcher agent state + target: Target state + n_view: Number of values in view array. + i_range: Vision range + env_size: Environment size + + Returns: + Segmented agent view of target. + """ + view_angle, agent_radius = params + rays = jnp.linspace( + -view_angle * jnp.pi, + view_angle * jnp.pi, + n_view, + endpoint=True, + ) + d, left, right = angular_width( + searcher.pos, + target.pos, + searcher.heading, + i_range, + agent_radius, + env_size, + ) + checks = jnp.logical_and(target.found, jnp.logical_and(left < rays, rays < right)) + obs = jnp.where(checks, d, -1.0) + return obs + + +class AgentAndTargetObservationFn(ObservationFn): + def __init__( + self, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Vision model that contains other agents and found targets. + + Searchers and targets are visualised as individual channels. + Targets are only included if they have been located already. + + Args: + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of π). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + self.vision_range = vision_range + self.view_angle = view_angle + self.agent_radius = agent_radius + self.env_size = env_size + super().__init__( + 2, + num_vision, + vision_range, + view_angle, + agent_radius, + env_size, + ) + + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views of shape + (n-agents, 2, n-vision). Other agents are shown + in channel 0, and located targets in channel 1. 
+ """ + searcher_views = esquilax.transforms.spatial( + view, + reduction=view_reduction((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + (self.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + target_views = esquilax.transforms.spatial( + found_target_view, + reduction=view_reduction((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + (self.view_angle, self.agent_radius), + state.searchers, + state.targets, + pos=state.searchers.pos, + pos_b=state.targets.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + return jnp.hstack([searcher_views[:, jnp.newaxis], target_views[:, jnp.newaxis]]) + + +def all_target_view( + params: Tuple[float, float], + searcher: AgentState, + target: TargetState, + *, + n_view: int, + i_range: float, + env_size: float, +) -> chex.Array: + """ + Return view of a target, dependent on target status. + + This function is intended to be mapped over agents target + pairs by Esquilax. + + Args: + params: View angle and target visual radius. + searcher: Searcher agent state. + target: Target state. + n_view: Number of value sin view array. + i_range: Vision range. + env_size: Environment size. + + Returns: + Segmented agent view of target. + """ + view_angle, agent_radius = params + rays = jnp.linspace( + -view_angle * jnp.pi, + view_angle * jnp.pi, + n_view, + endpoint=True, + ) + d, left, right = angular_width( + searcher.pos, + target.pos, + searcher.heading, + i_range, + agent_radius, + env_size, + ) + ray_checks = jnp.logical_and(left < rays, rays < right) + checks_a = jnp.logical_and(target.found, ray_checks) + checks_b = jnp.logical_and(~target.found, ray_checks) + obs = [jnp.where(checks_a, d, -1.0), jnp.where(checks_b, d, -1.0)] + obs = jnp.vstack(obs) + return obs + + +class AgentAndAllTargetObservationFn(ObservationFn): + def __init__( + self, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Vision model that contains other agents, and all targets. + + Searchers and targets are visualised as individual channels, + with found and unfound targets also shown on different channels. + + Args: + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of π). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + self.vision_range = vision_range + self.view_angle = view_angle + self.agent_radius = agent_radius + self.env_size = env_size + super().__init__( + 3, + num_vision, + vision_range, + view_angle, + agent_radius, + env_size, + ) + + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views of shape + (n-agents, 3, n-vision). Other agents are shown + in channel 0, located targets channel 1, and un-located + targets in channel 2. 
+ """ + searcher_views = esquilax.transforms.spatial( + view, + reduction=view_reduction((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + (self.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + target_views = esquilax.transforms.spatial( + all_target_view, + reduction=view_reduction((2, self.num_vision)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + (self.view_angle, self.agent_radius), + state.searchers, + state.targets, + pos=state.searchers.pos, + pos_b=state.targets.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + return jnp.hstack([searcher_views[:, jnp.newaxis], target_views]) diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py new file mode 100644 index 000000000..a95dcd271 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -0,0 +1,355 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import chex +import jax.numpy as jnp +import pytest + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.search_and_rescue import observations +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState + +VISION_RANGE = 0.2 +VIEW_ANGLE = 0.5 + + +@pytest.mark.parametrize( + "searcher_positions, searcher_headings, env_size, view_updates", + [ + # Both out of view range + ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), + # Both view each other + ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), + # One facing wrong direction + ( + [[0.25, 0.5], [0.2, 0.5]], + [jnp.pi, jnp.pi], + 1.0, + [(0, 5, 0.25)], + ), + # Only see closest neighbour + ( + [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], + [jnp.pi, 0.0, 0.0], + 1.0, + [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], + ), + # Observed around wrapped edge + ( + [[0.025, 0.5], [0.975, 0.5]], + [jnp.pi, 0.0], + 1.0, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [[0.025, 0.25], [0.475, 0.25]], + [jnp.pi, 0.0], + 0.5, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + ], +) +def test_searcher_view( + key: chex.PRNGKey, + searcher_positions: List[List[float]], + searcher_headings: List[float], + env_size: float, + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test agent-only view model generates expected array with different + configurations of agents. 
+ """ + + searcher_positions = jnp.array(searcher_positions) + searcher_headings = jnp.array(searcher_headings) + searcher_speed = jnp.zeros(searcher_headings.shape) + + state = State( + searchers=AgentState( + pos=searcher_positions, heading=searcher_headings, speed=searcher_speed + ), + targets=TargetState( + pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool) + ), + key=key, + ) + + observe_fn = observations.AgentObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + + expected = jnp.full((searcher_headings.shape[0], 1, observe_fn.num_vision), -1.0) + + for i, idx, val in view_updates: + expected = expected.at[i, 0, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) + + +@pytest.mark.parametrize( + "searcher_positions, searcher_headings, env_size, view_updates", + [ + # Both out of view range + ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), + # Both view each other + ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), + # One facing wrong direction + ( + [[0.25, 0.5], [0.2, 0.5]], + [jnp.pi, jnp.pi], + 1.0, + [(0, 5, 0.25)], + ), + # Only see closest neighbour + ( + [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], + [jnp.pi, 0.0, 0.0], + 1.0, + [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], + ), + # Observed around wrapped edge + ( + [[0.025, 0.5], [0.975, 0.5]], + [jnp.pi, 0.0], + 1.0, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [[0.025, 0.25], [0.475, 0.25]], + [jnp.pi, 0.0], + 0.5, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + ], +) +def test_search_and_target_view_searchers( + key: chex.PRNGKey, + searcher_positions: List[List[float]], + searcher_headings: List[float], + env_size: float, + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test agent+target view model generates expected array with different + configurations of agents only. 
+ """ + + n_agents = len(searcher_headings) + searcher_positions = jnp.array(searcher_positions) + searcher_headings = jnp.array(searcher_headings) + searcher_speed = jnp.zeros(searcher_headings.shape) + + state = State( + searchers=AgentState( + pos=searcher_positions, heading=searcher_headings, speed=searcher_speed + ), + targets=TargetState( + pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool) + ), + key=key, + ) + + observe_fn = observations.AgentAndTargetObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + assert obs.shape == (n_agents, 2, observe_fn.num_vision) + + expected = jnp.full((n_agents, 2, observe_fn.num_vision), -1.0) + + for i, idx, val in view_updates: + expected = expected.at[i, 0, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) + + +@pytest.mark.parametrize( + "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates", + [ + # Target out of view range + ([0.8, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, []), + # Target in view and found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, [(5, 0.25)]), + # Target in view but not found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], False, 1.0, []), + # Observed around wrapped edge + ( + [0.025, 0.5], + jnp.pi, + [0.975, 0.5], + True, + 1.0, + [(5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [0.025, 0.25], + jnp.pi, + [0.475, 0.25], + True, + 0.5, + [(5, 0.25)], + ), + ], +) +def test_search_and_target_view_targets( + key: chex.PRNGKey, + searcher_position: List[float], + searcher_heading: float, + target_position: List[float], + target_found: bool, + env_size: float, + view_updates: List[Tuple[int, float]], +) -> None: + """ + Test agent+target view model generates expected array with different + configurations of targets only. 
+ """ + + searcher_position = jnp.array([searcher_position]) + searcher_heading = jnp.array([searcher_heading]) + searcher_speed = jnp.zeros((1,)) + target_position = jnp.array([target_position]) + target_found = jnp.array([target_found]) + + state = State( + searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), + targets=TargetState( + pos=target_position, + vel=jnp.zeros_like(target_position), + found=target_found, + ), + key=key, + ) + + observe_fn = observations.AgentAndTargetObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + assert obs.shape == (1, 2, observe_fn.num_vision) + + expected = jnp.full((1, 2, observe_fn.num_vision), -1.0) + + for idx, val in view_updates: + expected = expected.at[0, 1, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) + + +@pytest.mark.parametrize( + "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates", + [ + # Target out of view range + ([0.8, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, []), + # Target in view and found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, [(1, 5, 0.25)]), + # Target in view but not found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], False, 1.0, [(2, 5, 0.25)]), + # Observed around wrapped edge found + ( + [0.025, 0.5], + jnp.pi, + [0.975, 0.5], + True, + 1.0, + [(1, 5, 0.25)], + ), + # Observed around wrapped edge not found + ( + [0.025, 0.5], + jnp.pi, + [0.975, 0.5], + False, + 1.0, + [(2, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [0.025, 0.25], + jnp.pi, + [0.475, 0.25], + True, + 0.5, + [(1, 5, 0.25)], + ), + ], +) +def test_search_and_all_target_view_targets( + key: chex.PRNGKey, + searcher_position: List[float], + searcher_heading: float, + target_position: List[float], + target_found: bool, + env_size: float, + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test agent+target view model generates expected array with different + configurations of targets only. + """ + + searcher_position = jnp.array([searcher_position]) + searcher_heading = jnp.array([searcher_heading]) + searcher_speed = jnp.zeros((1,)) + target_position = jnp.array([target_position]) + target_found = jnp.array([target_found]) + + state = State( + searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), + targets=TargetState( + pos=target_position, + vel=jnp.zeros_like(target_position), + found=target_found, + ), + key=key, + ) + + observe_fn = observations.AgentAndAllTargetObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + assert obs.shape == (1, 3, observe_fn.num_vision) + + expected = jnp.full((1, 3, observe_fn.num_vision), -1.0) + + for i, idx, val in view_updates: + expected = expected.at[0, i, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py b/jumanji/environments/swarms/search_and_rescue/reward.py new file mode 100644 index 000000000..756fec884 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -0,0 +1,109 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax.numpy as jnp + + +class RewardFn(abc.ABC): + """Abstract class for `SearchAndRescue` rewards.""" + + @abc.abstractmethod + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + """The reward function used in the `SearchAndRescue` environment. + + Args: + found_targets: Array of boolean flags indicating if an + agent has found a target. + + Returns: + Individual reward for each agent. + """ + + +def _normalise_rewards(rewards: chex.Array) -> chex.Array: + norms = jnp.sum(rewards, axis=0)[jnp.newaxis] + rewards = jnp.where(norms > 0, rewards / norms, rewards) + return rewards + + +def _scale_rewards(rewards: chex.Array, step: int, time_limit: int) -> chex.Array: + scale = (time_limit - step) / time_limit + return scale * rewards + + +class SharedRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets + + Targets detected by multiple agents share rewards. Agents + can receive rewards for detecting multiple targets. + """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + rewards = _normalise_rewards(rewards) + rewards = jnp.sum(rewards, axis=1) + return rewards + + +class SharedScaledRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets and scale by timestep + + Targets detected by multiple agents share rewards. Agents + can receive rewards for detecting multiple targets. + Rewards are linearly scaled by the current time step such that + rewards received are 0 at the final step. + """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + rewards = _normalise_rewards(rewards) + rewards = jnp.sum(rewards, axis=1) + rewards = _scale_rewards(rewards, step, time_limit) + return rewards + + +class IndividualRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets + + Each agent that detects a target receives a +1 reward + even if a target is detected by multiple agents. + """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + rewards = jnp.sum(rewards, axis=1) + return rewards + + +class IndividualScaledRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets and scale by timestep + + Each agent that detects a target receives a +1 reward + even if a target is detected by multiple agents. + Rewards are linearly scaled by the current time step such that + rewards received are 0 at the final step. 
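+
+    For example, with a time limit of 10 steps, a target detected at
+    step 5 yields a reward of 0.5 rather than 1.0.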
+ """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + rewards = jnp.sum(rewards, axis=1) + rewards = _scale_rewards(rewards, step, time_limit) + return rewards diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py new file mode 100644 index 000000000..bd49f69e8 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -0,0 +1,71 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import chex +import jax.numpy as jnp +import pytest + +from jumanji.environments.swarms.search_and_rescue import reward + + +@pytest.fixture +def target_states() -> chex.Array: + return jnp.array([[False, True, True], [False, False, True]], dtype=bool) + + +def test_shared_rewards(target_states: chex.Array) -> None: + shared_rewards = reward.SharedRewardFn()(target_states, 0, 10) + + assert shared_rewards.shape == (2,) + assert shared_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_rewards, jnp.array([1.5, 0.5])) + + +def test_individual_rewards(target_states: chex.Array) -> None: + individual_rewards = reward.IndividualRewardFn()(target_states, 0, 10) + + assert individual_rewards.shape == (2,) + assert individual_rewards.dtype == jnp.float32 + assert jnp.allclose(individual_rewards, jnp.array([2.0, 1.0])) + + +def test_shared_scaled_rewards(target_states: chex.Array) -> None: + reward_fn = reward.SharedScaledRewardFn() + + shared_scaled_rewards = reward_fn(target_states, 0, 10) + + assert shared_scaled_rewards.shape == (2,) + assert shared_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_scaled_rewards, jnp.array([1.5, 0.5])) + + shared_scaled_rewards = reward_fn(target_states, 10, 10) + + assert shared_scaled_rewards.shape == (2,) + assert shared_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_scaled_rewards, jnp.array([0.0, 0.0])) + + +def test_individual_scaled_rewards(target_states: chex.Array) -> None: + reward_fn = reward.IndividualScaledRewardFn() + + individual_scaled_rewards = reward_fn(target_states, 0, 10) + + assert individual_scaled_rewards.shape == (2,) + assert individual_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(individual_scaled_rewards, jnp.array([2.0, 1.0])) + + individual_scaled_rewards = reward_fn(target_states, 10, 10) + + assert individual_scaled_rewards.shape == (2,) + assert individual_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(individual_scaled_rewards, jnp.array([0.0, 0.0])) diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py new file mode 100644 index 000000000..3fba76438 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -0,0 +1,85 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING, NamedTuple + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from chex import dataclass + +import chex + +from jumanji.environments.swarms.common.types import AgentState + + +@dataclass +class TargetState: + """ + The state of the rescue targets. + + pos: 2d position of the target agents. + velocity: 2d velocity of the target agents. + found: Boolean flag indicating if the + target has been located by a searcher. + """ + + pos: chex.Array # (num_targets, 2) + vel: chex.Array # (num_targets, 2) + found: chex.Array # (num_targets,) + + +@dataclass +class State: + """ + searchers: Searcher agent states. + targets: Search target state. + key: JAX random key. + step: Environment step number + """ + + searchers: AgentState + targets: TargetState + key: chex.PRNGKey + step: int = 0 + + +class Observation(NamedTuple): + """ + Individual observations for searching agents and information + on number of remaining steps and ratio of targets to be found. + + Each agent generates an independent observation, an array of + values representing the distance along a ray from the agent to + the nearest neighbour, with each cell representing a ray angle + (with `num_vision` rays evenly distributed over the agents + field of vision). + + The co-ordinates of each agent are also included in the + observation for debug and use in global observations. + + For example if an agent sees another agent straight ahead and + `num_vision = 5` then the observation array could be + + ``` + [-1.0, -1.0, 0.5, -1.0, -1.0] + ``` + + where `-1.0` indicates there is no agents along that ray, + and `0.5` is the normalised distance to the other agent. + """ + + searcher_views: chex.Array # (num_searchers, num_vision) + targets_remaining: chex.Numeric # () + step: chex.Numeric # () + positions: chex.Array # (num_searchers, 2) diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py new file mode 100644 index 000000000..ed52b33c5 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -0,0 +1,84 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
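+
+# Utility functions for checking whether targets fall inside a searcher's
+# view cone and for flagging newly detected targets.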
+ +from typing import Tuple + +import chex +import jax.numpy as jnp +from esquilax.utils import shortest_vector + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.search_and_rescue.types import TargetState + + +def _check_target_in_view( + searcher_pos: chex.Array, + target_pos: chex.Array, + searcher_heading: chex.Array, + searcher_view_angle: float, + env_size: float, +) -> chex.Array: + """ + Check if a target is inside the view-cone of a searcher. + + Args: + searcher_pos: Searcher position. + target_pos: Target position. + searcher_heading: Searcher heading angle. + searcher_view_angle: Searcher view angle. + env_size: Size of the environment. + + Returns: + bool: Flag indicating if a target is within view. + """ + dx = shortest_vector(searcher_pos, target_pos, length=env_size) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = shortest_vector(phi, searcher_heading, 2 * jnp.pi) + searcher_view_angle = searcher_view_angle * jnp.pi + return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + + +def searcher_detect_targets( + searcher_view_angle: float, + searcher: AgentState, + target: Tuple[chex.Array, TargetState], + *, + env_size: float, + n_targets: int, +) -> chex.Array: + """ + Return array of flags indicating if a target has been located + + Sets the flag at the target index if the target is within the + searchers view cone, and has not already been detected. + + Args: + searcher_view_angle: View angle of searching agents + representing a fraction of π from the agents heading. + searcher: State of the searching agent (i.e. the agent + position and heading) + target: Index and State of the target (i.e. its position and + search status). + env_size: size of the environment. + n_targets: Number of search targets (static). + + Returns: + array of boolean flags, set if a target at the index has been found. + """ + target_idx, target = target + target_found = jnp.zeros((n_targets,), dtype=bool) + can_see = _check_target_in_view( + searcher.pos, target.pos, searcher.heading, searcher_view_angle, env_size + ) + return target_found.at[target_idx].set(jnp.logical_and(~target.found, can_see)) diff --git a/jumanji/environments/swarms/search_and_rescue/utils_test.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py new file mode 100644 index 000000000..af4f11a37 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/utils_test.py @@ -0,0 +1,92 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial +from typing import List + +import chex +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics +from jumanji.environments.swarms.search_and_rescue.types import TargetState +from jumanji.environments.swarms.search_and_rescue.utils import ( + searcher_detect_targets, +) + + +def test_random_walk_dynamics(key: chex.PRNGKey) -> None: + n_targets = 50 + pos_0 = jnp.full((n_targets, 2), 0.5) + + s0 = TargetState( + pos=pos_0, vel=jnp.zeros((n_targets, 2)), found=jnp.zeros((n_targets,), dtype=bool) + ) + + dynamics = RandomWalk(0.1) + assert isinstance(dynamics, TargetDynamics) + s1 = dynamics(key, s0, 1.0) + + assert isinstance(s1, TargetState) + assert s1.pos.shape == (n_targets, 2) + assert jnp.array_equal(s0.found, s1.found) + assert jnp.all(jnp.abs(s0.pos - s1.pos) < 0.1) + + +@pytest.mark.parametrize( + "pos, heading, view_angle, target_state, expected, env_size", + [ + ([0.1, 0.0], 0.0, 0.5, False, False, 1.0), + ([0.1, 0.0], jnp.pi, 0.5, False, True, 1.0), + ([0.1, 0.0], jnp.pi, 0.5, True, False, 1.0), + ([0.9, 0.0], jnp.pi, 0.5, False, False, 1.0), + ([0.9, 0.0], 0.0, 0.5, False, True, 1.0), + ([0.9, 0.0], 0.0, 0.5, True, False, 1.0), + ([0.0, 0.1], 1.5 * jnp.pi, 0.5, True, False, 1.0), + ([0.1, 0.0], 0.5 * jnp.pi, 0.5, False, True, 1.0), + ([0.1, 0.0], 0.5 * jnp.pi, 0.4, False, False, 1.0), + ([0.4, 0.0], 0.0, 0.5, False, False, 1.0), + ([0.4, 0.0], 0.0, 0.5, False, True, 0.5), + ], +) +def test_target_found( + pos: List[float], + heading: float, + view_angle: float, + target_state: bool, + expected: bool, + env_size: float, +) -> None: + target = TargetState( + pos=jnp.zeros((2,)), + vel=jnp.zeros((2,)), + found=target_state, + ) + + searcher = AgentState( + pos=jnp.array(pos), + heading=heading, + speed=0.0, + ) + + found = jax.jit(partial(searcher_detect_targets, env_size=env_size, n_targets=1))( + view_angle, + searcher, + (jnp.arange(1), target), + ) + + assert found.shape == (1,) + assert found[0] == expected diff --git a/jumanji/environments/swarms/search_and_rescue/viewer.py b/jumanji/environments/swarms/search_and_rescue/viewer.py new file mode 100644 index 000000000..81655820f --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/viewer.py @@ -0,0 +1,163 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
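+
+# Matplotlib-based viewer used to render and animate `SearchAndRescue` states.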
+ +from typing import Optional, Sequence, Tuple + +import jax.numpy as jnp +import matplotlib.animation +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.artist import Artist +from matplotlib.layout_engine import TightLayoutEngine + +import jumanji +import jumanji.environments +from jumanji.environments.swarms.common.viewer import draw_agents, format_plot +from jumanji.environments.swarms.search_and_rescue.types import State +from jumanji.viewer import Viewer + + +class SearchAndRescueViewer(Viewer[State]): + def __init__( + self, + figure_name: str = "SearchAndRescue", + figure_size: Tuple[float, float] = (6.0, 6.0), + searcher_color: str = "blue", + target_found_color: str = "green", + target_lost_color: str = "red", + env_size: Tuple[float, float] = (1.0, 1.0), + ) -> None: + """Viewer for the `SearchAndRescue` environment. + + Args: + figure_name: The window name to be used when initialising the window. + figure_size: Tuple (height, width) of the matplotlib figure window. + searcher_color: Color of searcher agent markers (arrows). + target_found_color: Color of target markers when they have been found. + target_lost_color: Color of target markers when they are still to be found. + env_size: Tuple environment spatial dimensions, used to set the plot region. + """ + self._figure_name = figure_name + self._figure_size = figure_size + self.searcher_color = searcher_color + self.target_colors = np.array([target_lost_color, target_found_color]) + self._animation: Optional[matplotlib.animation.Animation] = None + self.env_size = env_size + + def render(self, state: State) -> None: + """Render a frame of the environment for a given state using matplotlib. + + Args: + state: State object containing the current dynamics of the environment. + """ + self._clear_display() + fig, ax = self._get_fig_ax() + self._draw(ax, state) + self._update_display(fig) + + def animate( + self, states: Sequence[State], interval: int, save_path: Optional[str] + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of states. + + Args: + states: sequence of `State` corresponding to subsequent timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation object that can be saved as a GIF, MP4, or rendered with HTML. 
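+
+            If `save_path` is None the animation can still be saved
+            afterwards, for example with
+            `animation.save(path, writer=matplotlib.animation.PillowWriter(fps=10))`.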
+ """ + if not states: + raise ValueError(f"The states argument has to be non-empty, got {states}.") + fig, ax = plt.subplots(num=f"{self._figure_name}Anim", figsize=self._figure_size) + fig, ax = format_plot(fig, ax, self.env_size) + + searcher_quiver = draw_agents(ax, states[0].searchers, self.searcher_color) + target_scatter = ax.scatter( + states[0].targets.pos[:, 0], states[0].targets.pos[:, 1], marker="o" + ) + + def make_frame(state: State) -> Tuple[Artist, Artist]: + searcher_quiver.set_offsets(state.searchers.pos) + searcher_quiver.set_UVC( + jnp.cos(state.searchers.heading), jnp.sin(state.searchers.heading) + ) + target_colors = self.target_colors[state.targets.found.astype(jnp.int32)] + target_scatter.set_offsets(state.targets.pos) + target_scatter.set_color(target_colors) + return searcher_quiver, target_scatter + + matplotlib.rc("animation", html="jshtml") + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=states, + interval=interval, + blit=False, + ) + + if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + plt.close(self._figure_name) + + def _draw(self, ax: plt.Axes, state: State) -> None: + ax.clear() + draw_agents(ax, state.searchers, self.searcher_color) + target_colors = self.target_colors[state.targets.found.astype(jnp.int32)] + ax.scatter( + state.targets.pos[:, 0], state.targets.pos[:, 1], marker="o", color=target_colors + ) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + exists = plt.fignum_exists(self._figure_name) + if exists: + fig = plt.figure(self._figure_name) + ax = fig.get_axes()[0] + else: + fig = plt.figure(self._figure_name, figsize=self._figure_size) + fig.set_layout_engine(layout=TightLayoutEngine(pad=False, w_pad=0.0, h_pad=0.0)) + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + + fig, ax = format_plot(fig, ax, self.env_size) + return fig, ax + + def _update_display(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self._figure_name) + else: + # Required to update render when not using Jupyter Notebook. 
+ fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) diff --git a/jumanji/training/configs/env/search_and_rescue.yaml b/jumanji/training/configs/env/search_and_rescue.yaml new file mode 100644 index 000000000..11c5b1b16 --- /dev/null +++ b/jumanji/training/configs/env/search_and_rescue.yaml @@ -0,0 +1,24 @@ +name: search_and_rescue +registered_version: SearchAndRescue-v0 + +network: + layers: [128, 128] + +training: + num_epochs: 50 + num_learner_steps_per_epoch: 400 + n_steps: 20 + total_batch_size: 128 + +evaluation: + eval_total_batch_size: 2000 + greedy_eval_total_batch_size: 2000 + +a2c: + normalize_advantage: False + discount_factor: 0.997 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 3e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 587941bd1..999ec42c8 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -80,6 +80,10 @@ make_actor_critic_networks_rubiks_cube, ) from jumanji.training.networks.rubiks_cube.random import make_random_policy_rubiks_cube +from jumanji.training.networks.search_and_rescue.actor_critic import ( + make_actor_critic_search_and_rescue, +) +from jumanji.training.networks.search_and_rescue.random import make_random_policy_search_and_rescue from jumanji.training.networks.sliding_tile_puzzle.actor_critic import ( make_actor_critic_networks_sliding_tile_puzzle, ) diff --git a/jumanji/training/networks/distribution.py b/jumanji/training/networks/distribution.py index 03262136d..cce681314 100644 --- a/jumanji/training/networks/distribution.py +++ b/jumanji/training/networks/distribution.py @@ -21,6 +21,7 @@ import chex import jax import jax.numpy as jnp +from distrax import Normal class Distribution(abc.ABC): @@ -85,3 +86,27 @@ def kl_divergence( # type: ignore[override] probs = jax.nn.softmax(self.logits) log_probs_other = jax.nn.log_softmax(other.logits) return jnp.sum(jnp.where(probs == 0, 0.0, probs * (log_probs - log_probs_other)), axis=-1) + + +class NormalDistribution(Distribution): + """Normal distribution (with log standard deviations).""" + + def __init__(self, means: chex.Array, log_stds: chex.Array): + self.dist = Normal(loc=means, scale=jnp.exp(log_stds)) + + def mode(self) -> chex.Array: + return self.dist.mode() + + def log_prob(self, x: chex.Array) -> chex.Array: + return self.dist.log_prob(x) + + def entropy(self) -> chex.Array: + return self.dist.entropy() + + def kl_divergence( # type: ignore[override] + self, other: NormalDistribution + ) -> chex.Array: + return self.dist.kl_divergence(other) + + def sample(self, seed: chex.PRNGKey) -> chex.Array: + return self.dist.sample(seed=seed) diff --git a/jumanji/training/networks/parametric_distribution.py b/jumanji/training/networks/parametric_distribution.py index 325c17e9b..e153128a0 100644 --- a/jumanji/training/networks/parametric_distribution.py +++ b/jumanji/training/networks/parametric_distribution.py @@ -15,17 +15,22 @@ """Adapted from Brax.""" import abc -from typing import Any +from typing import Any, Tuple import chex import jax.numpy as jnp import numpy as np -from jumanji.training.networks.distribution import CategoricalDistribution, Distribution +from jumanji.training.networks.distribution import ( + CategoricalDistribution, + Distribution, + NormalDistribution, +) from jumanji.training.networks.postprocessor import ( 
FactorisedActionSpaceReshapeBijector, IdentityBijector, Postprocessor, + TanhBijector, ) @@ -166,10 +171,64 @@ def __init__(self, action_spec_num_values: chex.ArrayNumpy): Args: action_spec_num_values: the dimensions of each of the factors in the action space""" num_actions = int(np.prod(action_spec_num_values)) - posprocessor = FactorisedActionSpaceReshapeBijector( + postprocessor = FactorisedActionSpaceReshapeBijector( action_spec_num_values=action_spec_num_values ) - super().__init__(param_size=num_actions, postprocessor=posprocessor, event_ndims=0) + super().__init__(param_size=num_actions, postprocessor=postprocessor, event_ndims=0) def create_dist(self, parameters: chex.Array) -> CategoricalDistribution: return CategoricalDistribution(logits=parameters) + + +class ContinuousActionSpaceNormalTanhDistribution(ParametricDistribution): + """Normal distribution for continuous action spaces""" + + def __init__(self, n_actions: int, threshold: float = 0.999): + super().__init__( + param_size=n_actions, + postprocessor=TanhBijector(), + event_ndims=1, + ) + self._inverse_threshold = self._postprocessor.inverse(threshold) + self._log_epsilon = jnp.log(1.0 - threshold) + + def create_dist(self, parameters: Tuple[chex.Array, chex.Array]) -> NormalDistribution: + return NormalDistribution(means=parameters[0], log_stds=parameters[1]) + + def log_prob( + self, parameters: Tuple[chex.Array, chex.Array], raw_actions: chex.Array + ) -> chex.Array: + """Compute the log probability of raw actions when transformed""" + dist = self.create_dist(parameters) + + log_prob_left = dist.dist.log_cdf(-self._inverse_threshold) - self._log_epsilon + log_prob_right = ( + dist.dist.log_survival_function(self._inverse_threshold) - self._log_epsilon + ) + + clipped_actions = jnp.clip(raw_actions, -self._inverse_threshold, self._inverse_threshold) + raw_log_probs = dist.log_prob(clipped_actions) + raw_log_probs -= self._postprocessor.forward_log_det_jacobian(clipped_actions) + + log_probs = jnp.where( + raw_actions <= -self._inverse_threshold, + log_prob_left, + jnp.where( + raw_actions >= self._inverse_threshold, + log_prob_right, + raw_log_probs, + ), + ) + # Sum over non-batch axes + log_probs = jnp.sum(log_probs, axis=tuple(range(1, log_probs.ndim))) + + return log_probs + + def entropy(self, parameters: Tuple[chex.Array, chex.Array], seed: chex.PRNGKey) -> chex.Array: + """Return the entropy of the given distribution.""" + dist = self.create_dist(parameters) + entropy = dist.entropy() + entropy += self._postprocessor.forward_log_det_jacobian(dist.sample(seed=seed)) + # Sum over non-batch axes + entropy = jnp.sum(entropy, axis=tuple(range(1, entropy.ndim))) + return entropy diff --git a/jumanji/training/networks/postprocessor.py b/jumanji/training/networks/postprocessor.py index e97fe3a19..20ba59421 100644 --- a/jumanji/training/networks/postprocessor.py +++ b/jumanji/training/networks/postprocessor.py @@ -17,6 +17,7 @@ import abc import chex +import distrax import jax.numpy as jnp @@ -47,6 +48,23 @@ def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: return jnp.zeros_like(x, x.dtype) +class TanhBijector(Postprocessor): + """Tanh Bijector for continuous actions.""" + + def __init__(self) -> None: + super().__init__() + self._tanh = distrax.Tanh() + + def forward(self, x: chex.Array) -> chex.Array: + return self._tanh.forward(x) + + def inverse(self, y: chex.Array) -> chex.Array: + return self._tanh.inverse(y) + + def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: + return 
self._tanh.forward_log_det_jacobian(x) + + class FactorisedActionSpaceReshapeBijector(Postprocessor): """Identity bijector that reshapes (flattens and unflattens) a sequential action.""" diff --git a/jumanji/training/networks/search_and_rescue/__init__.py b/jumanji/training/networks/search_and_rescue/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/search_and_rescue/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/search_and_rescue/actor_critic.py b/jumanji/training/networks/search_and_rescue/actor_critic.py new file mode 100644 index 000000000..93dd0f40b --- /dev/null +++ b/jumanji/training/networks/search_and_rescue/actor_critic.py @@ -0,0 +1,116 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from math import prod +from typing import Sequence, Tuple, Union + +import chex +import haiku as hk +import jax +import jax.numpy as jnp + +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue.types import Observation +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + ContinuousActionSpaceNormalTanhDistribution, +) + + +def make_actor_critic_search_and_rescue( + search_and_rescue: SearchAndRescue, + layers: Sequence[int], +) -> ActorCriticNetworks: + """ + Initialise networks for the search-and-rescue network. + + Note: This network is intended to accept environment observations + with agent views of shape [n-agents, n-view], but then + Returns a flattened array of actions for each agent (these + are reshaped by the wrapped environment). + + Args: + search_and_rescue: `SearchAndRescue` environment. + layers: List of hidden layer dimensions. + + Returns: + Continuous action space MLP action and critic networks. 
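+
+        The policy network outputs per-agent action means and log standard
+        deviations (parameterising a tanh-squashed Normal distribution),
+        and the critic outputs a single value estimate per observation.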
+ """ + n_actions = prod(search_and_rescue.action_spec.shape) + parametric_action_distribution = ContinuousActionSpaceNormalTanhDistribution(n_actions) + policy_network = make_actor_network( + layers=layers, n_agents=search_and_rescue.generator.num_searchers, n_actions=n_actions + ) + value_network = make_critic_network(layers=layers) + + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +def embedding(x: chex.Array) -> chex.Array: + n_channels = x.shape[-2] + x = hk.Conv1D(2 * n_channels, 3, data_format="NCW")(x) + x = jax.nn.relu(x) + x = hk.MaxPool(2, 2, "SAME", channel_axis=-2)(x) + return x + + +def make_critic_network(layers: Sequence[int]) -> FeedForwardNetwork: + # Shape names: + # B: batch size + # N: number of agents + # C: Observation channels + # O: observation size + + def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]: + x = observation.searcher_views # (B, N, C, O) + x = hk.vmap(embedding, split_rng=False)(x) + x = hk.Flatten()(x) + value = hk.nets.MLP([*layers, 1])(x) # (B,) + return jnp.squeeze(value, axis=-1) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_actor_network(layers: Sequence[int], n_agents: int, n_actions: int) -> FeedForwardNetwork: + # Shape names: + # B: batch size + # N: number of agents + # C: Observation channels + # O: observation size + # A: Number of actions + + def log_std_params(x: chex.Array) -> chex.Array: + return hk.get_parameter("log_stds", shape=x.shape, init=hk.initializers.Constant(0.1)) + + def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]: + x = observation.searcher_views # (B, N, C, O) + x = hk.vmap(embedding, split_rng=False)(x) + x = hk.Flatten()(x) + + means = hk.nets.MLP([*layers, n_agents * n_actions])(x) # (B, N * A) + means = hk.Reshape(output_shape=(n_agents, n_actions))(means) # (B, N, A) + + log_stds = hk.vmap(log_std_params, split_rng=False)(means) # (B, N, A) + + return means, log_stds + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/search_and_rescue/random.py b/jumanji/training/networks/search_and_rescue/random.py new file mode 100644 index 000000000..8f5741c0b --- /dev/null +++ b/jumanji/training/networks/search_and_rescue/random.py @@ -0,0 +1,51 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import chex +import jax.random + +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue.types import Observation +from jumanji.training.networks.protocols import RandomPolicy + + +class SearchAndRescueRandomPolicy(RandomPolicy): + def __init__(self, n_actions: int): + self.n_actions = n_actions + + def __call__( + self, + observation: Observation, + key: chex.PRNGKey, + ) -> chex.Array: + """A random policy given an environment-specific observation. + + Args: + observation: environment observation. + key: random key for action selection. + + Returns: + action + """ + shape = ( + observation.searcher_views.shape[0], + observation.searcher_views.shape[1], + self.n_actions, + ) + return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) + + +def make_random_policy_search_and_rescue(search_and_rescue: SearchAndRescue) -> RandomPolicy: + """Make random policy for Search & Rescue.""" + return SearchAndRescueRandomPolicy(search_and_rescue.action_spec.shape[-1]) diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index 7daa0201d..b5fbdca1c 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -41,6 +41,7 @@ PacMan, RobotWarehouse, RubiksCube, + SearchAndRescue, SlidingTilePuzzle, Snake, Sokoban, @@ -90,7 +91,7 @@ def setup_logger(cfg: DictConfig) -> Logger: def _make_raw_env(cfg: DictConfig) -> Environment: env = jumanji.make(cfg.env.registered_version) - if cfg.env.name in {"lbf", "connector"}: + if cfg.env.name in {"lbf", "connector", "search_and_rescue"}: # Convert a multi-agent environment to a single-agent environment env = MultiToSingleWrapper(env) return env @@ -206,6 +207,11 @@ def _setup_random_policy(cfg: DictConfig, env: Environment) -> RandomPolicy: elif cfg.env.name == "lbf": assert isinstance(env.unwrapped, LevelBasedForaging) random_policy = networks.make_random_policy_lbf() + elif cfg.env.name == "search_and_rescue": + assert isinstance(env.unwrapped, SearchAndRescue) + random_policy = networks.make_random_policy_search_and_rescue( + search_and_rescue=env.unwrapped + ) else: raise ValueError(f"Environment name not found. Got {cfg.env.name}.") return random_policy @@ -421,6 +427,12 @@ def _setup_actor_critic_neworks(cfg: DictConfig, env: Environment) -> ActorCriti transformer_key_size=cfg.env.network.transformer_key_size, transformer_mlp_units=cfg.env.network.transformer_mlp_units, ) + elif cfg.env.name == "search_and_rescue": + assert isinstance(env.unwrapped, SearchAndRescue) + actor_critic_networks = networks.make_actor_critic_search_and_rescue( + search_and_rescue=env.unwrapped, + layers=cfg.env.network.layers, + ) else: raise ValueError(f"Environment name not found. 
Got {cfg.env.name}.") return actor_critic_networks diff --git a/mkdocs.yml b/mkdocs.yml index 096f7f8b2..8e2ca2f87 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -43,6 +43,8 @@ nav: - Sokoban: environments/sokoban.md - Snake: environments/snake.md - TSP: environments/tsp.md + - Swarms: + - SearchAndRescue: environments/search_and_rescue.md - User Guides: - Advanced Usage: guides/advanced_usage.md - Registration: guides/registration.md @@ -77,6 +79,8 @@ nav: - Sokoban: api/environments/sokoban.md - Snake: api/environments/snake.md - TSP: api/environments/tsp.md + - Swarms: + - SearchAndRescue: api/environments/search_and_rescue.md - Wrappers: api/wrappers.md - Types: api/types.md diff --git a/pyproject.toml b/pyproject.toml index ea4fabe5c..f637dca31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ module = [ "requests.*", "pkg_resources.*", "PIL.*", + "distrax.*", ] ignore_missing_imports = true diff --git a/requirements/requirements-train.txt b/requirements/requirements-train.txt index 890d0f983..55c0032e0 100644 --- a/requirements/requirements-train.txt +++ b/requirements/requirements-train.txt @@ -1,3 +1,4 @@ +distrax>=0.1.5 dm-haiku hydra-core==1.3 neptune-client==0.16.15 diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0047cfda2..b6bb622bd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,8 +1,9 @@ chex>=0.1.3 dm-env>=1.5 +esquilax>=2.0.0 gymnasium>=1.0 huggingface-hub -jax>=0.2.26,<0.4.36 +jax>=0.2.26 matplotlib~=3.7.4 numpy>=1.19.5 Pillow>=9.0.0