From 3597d9e407c9f6fea93fb98dcd3ef6bad12c481e Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Mon, 4 Nov 2024 10:16:57 +0000 Subject: [PATCH 01/19] feat: Implement predator prey env (#1) * Initial prototype * feat: Add environment tests * fix: Update esquilax version to fix type issues * docs: Add docstrings * docs: Add docstrings * test: Test multiple reward types * test: Add smoke tests and add max-steps check * feat: Implement pred-prey environment viewer * refactor: Pull out common viewer functionality * test: Add reward and view tests * test: Add rendering tests and add test docstrings * docs: Add predator-prey environment documentation page * docs: Cleanup docstrings * docs: Cleanup docstrings --- docs/environments/predator_prey.md | 53 ++ jumanji/environments/__init__.py | 1 + jumanji/environments/swarms/__init__.py | 13 + .../environments/swarms/common/__init__.py | 13 + jumanji/environments/swarms/common/types.py | 53 ++ jumanji/environments/swarms/common/updates.py | 183 ++++++ jumanji/environments/swarms/common/viewer.py | 73 +++ .../swarms/predator_prey/__init__.py | 14 + .../environments/swarms/predator_prey/env.py | 551 ++++++++++++++++++ .../swarms/predator_prey/env_test.py | 379 ++++++++++++ .../swarms/predator_prey/types.py | 71 +++ .../swarms/predator_prey/updates.py | 129 ++++ .../swarms/predator_prey/viewer.py | 156 +++++ mkdocs.yml | 2 + requirements/requirements.txt | 1 + 15 files changed, 1692 insertions(+) create mode 100644 docs/environments/predator_prey.md create mode 100644 jumanji/environments/swarms/__init__.py create mode 100644 jumanji/environments/swarms/common/__init__.py create mode 100644 jumanji/environments/swarms/common/types.py create mode 100644 jumanji/environments/swarms/common/updates.py create mode 100644 jumanji/environments/swarms/common/viewer.py create mode 100644 jumanji/environments/swarms/predator_prey/__init__.py create mode 100644 jumanji/environments/swarms/predator_prey/env.py create mode 100644 jumanji/environments/swarms/predator_prey/env_test.py create mode 100644 jumanji/environments/swarms/predator_prey/types.py create mode 100644 jumanji/environments/swarms/predator_prey/updates.py create mode 100644 jumanji/environments/swarms/predator_prey/viewer.py diff --git a/docs/environments/predator_prey.md b/docs/environments/predator_prey.md new file mode 100644 index 000000000..bfd8f96c2 --- /dev/null +++ b/docs/environments/predator_prey.md @@ -0,0 +1,53 @@ +# Predator-Prey Flock Environment + +[//]: # (TODO: Add animated plot) + +Environment modelling two competing flocks/swarms of agents: + +- Predator agents are rewarded for contacting prey agents, or for proximity to prey agents. +- Prey agents are conversely penalised for being contacted by, or for proximity to predators. + +Each set of agents can consist of multiple agents, each independently +updated, and with their own independent observations. The agents occupy a square +space with periodic boundary conditions. Agents have a limited view range, i.e. they +only partially observe their local environment (and the locations of neighbouring agents within +range). Rewards are also assigned individually to each agent dependent on their local state. 
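The wrapped (toroidal) geometry means displacements and distances between agents are taken modulo the arena size, so agents near opposite edges can still observe and contact one another. The environment relies on Esquilax utilities for this; the sketch below is only illustrative (the helper name and unit side length are assumptions, not the library's implementation):

```python
import jax.numpy as jnp

def wrapped_displacement(a: jnp.ndarray, b: jnp.ndarray, side: float = 1.0) -> jnp.ndarray:
    """Displacement from a to b on a square space with periodic boundaries."""
    d = (b - a) % side
    # Take the representative closest to zero in each dimension.
    return jnp.where(d > 0.5 * side, d - side, d)

# Agents near opposite edges are close neighbours once the boundary wraps.
a = jnp.array([0.975, 0.5])
b = jnp.array([0.025, 0.5])
print(wrapped_displacement(a, b))                   # ~[0.05, 0.0]
print(jnp.linalg.norm(wrapped_displacement(a, b)))  # ~0.05
```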
+ +## Observation + +Each agent generates an independent observation, an array of values +representing the distance along a ray from the agent to the nearest neighbour, with +each cell representing a ray angle (with `num_vision` rays evenly distributed over the agents +field of vision). Prey and prey agent types are visualised independently to allow agents +to observe both local position and type. + +- `predators`: jax array (float) of shape `(num_predators, 2 * num_vision)` in the unit interval. +- `prey`: jax array (float) of shape `(num_prey, 2 * num_vision)` in the unit interval. + +## Action + +Agents can update their velocity each step by rotating and accelerating/decelerating. Values +are clipped to the range `[-1, 1]` and then scaled by max rotation and acceleration +parameters. Agents are restricted to velocities within a fixed range of speeds. + +- `predators`: jax array (float) of shape (num_predators, 2) each corresponding to `[rotation, acceleration]`. +- `prey`: jax array (float) of shape (num_prey, 2) each corresponding to `[rotation, acceleration]`. + +## Reward + +Rewards can be either sparse or proximity-based. + +### Sparse + +- `predators`: jax array (float) of shape `(num_predators,)`, predators are rewarded a fixed amount + for coming into contact with a prey agent. If they are in contact with multiple prey, only the + nearest agent is selected. +- `prey`: jax array (float) of shape `(num_predators,)`, prey are penalised a fix negative amount if + they come into contact with a predator agent. + +### Proximity + +- `predators`: jax array (float) of shape `(num_predators,)`, predators are rewarded with an amount + scaled linearly with the distance to the prey agents, summed over agents in range. +- `prey`: jax array (float) of shape `(num_predators,)`, prey are penalised by an amount scaled linearly + with distance from predator agents, summed over all predators in range. diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index d69fbbf8e..7fe39af30 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -58,6 +58,7 @@ from jumanji.environments.routing.snake.env import Snake from jumanji.environments.routing.sokoban.env import Sokoban from jumanji.environments.routing.tsp.env import TSP +from jumanji.environments.swarms.predator_prey import PredatorPrey def is_colab() -> bool: diff --git a/jumanji/environments/swarms/__init__.py b/jumanji/environments/swarms/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/environments/swarms/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/environments/swarms/common/__init__.py b/jumanji/environments/swarms/common/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/environments/swarms/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/environments/swarms/common/types.py b/jumanji/environments/swarms/common/types.py new file mode 100644 index 000000000..82ec12c87 --- /dev/null +++ b/jumanji/environments/swarms/common/types.py @@ -0,0 +1,53 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from chex import dataclass + +import chex + + +@dataclass +class AgentParams: + """ + max_rotate: Max angle an agent can rotate during a step (a fraction of pi) + max_accelerate: Max change in speed during a step + min_speed: Minimum agent speed + max_speed: Maximum agent speed + view_angle: Agent view angle, as a fraction of pi either side of its heading + """ + + max_rotate: float + max_accelerate: float + min_speed: float + max_speed: float + view_angle: float + + +@dataclass +class AgentState: + """ + State of multiple agents of a single type + + pos: 2d position of the (centre of the) agents + heading: Heading of the agents (in radians) + speed: Speed of the agents + """ + + pos: chex.Array + heading: chex.Array + speed: chex.Array diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py new file mode 100644 index 000000000..e10f7d733 --- /dev/null +++ b/jumanji/environments/swarms/common/updates.py @@ -0,0 +1,183 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import chex +import esquilax +import jax +import jax.numpy as jnp + +from . import types +from .types import AgentParams + + +@esquilax.transforms.amap +def update_velocity( + _: chex.PRNGKey, + params: types.AgentParams, + x: Tuple[chex.Array, types.AgentState], +) -> Tuple[float, float]: + """ + Get the updated agent heading and speeds from actions + + Args: + _: Dummy JAX random key. + params: Agent parameters. 
+ x: Agent rotation and acceleration actions. + + Returns: + float: New agent heading. + float: New agent speed. + """ + actions, boid = x + rotation = actions[0] * params.max_rotate * jnp.pi + acceleration = actions[1] * params.max_accelerate + + new_heading = (boid.heading + rotation) % (2 * jnp.pi) + new_speeds = jnp.clip( + boid.speed + acceleration, + min=params.min_speed, + max=params.max_speed, + ) + + return new_heading, new_speeds + + +@esquilax.transforms.amap +def move( + _: chex.PRNGKey, _params: None, x: Tuple[chex.Array, float, float] +) -> chex.Array: + """ + Get updated agent positions from current speed and heading + + Args: + _: Dummy JAX random key. + _params: unused parameters. + x: Tuple containing current agent position, heading and speed. + + Returns: + jax array (float32): Updated agent position + """ + pos, heading, speed = x + d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)]) + return (pos + d_pos) % 1.0 + + +def init_state( + n: int, params: types.AgentParams, key: chex.PRNGKey +) -> types.AgentState: + """ + Randomly initialise state of a group of agents + + Args: + n: Number of agents to initialise. + params: Agent parameters. + key: JAX random key. + + Returns: + AgentState: Random agent states (i.e. position, headings, and speeds) + """ + k1, k2, k3 = jax.random.split(key, 3) + + positions = jax.random.uniform(k1, (n, 2)) + speeds = jax.random.uniform( + k2, (n,), minval=params.min_speed, maxval=params.max_speed + ) + headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jax.numpy.pi) + + return types.AgentState( + pos=positions, + speed=speeds, + heading=headings, + ) + + +def update_state( + key: chex.PRNGKey, params: AgentParams, state: types.AgentState, actions: chex.Array +) -> types.AgentState: + """ + Update the state of a group of agents from a sample of actions + + Args: + key: Dummy JAX random key. + params: Agent parameters. + state: Current agent states. + actions: Agent actions, i.e. a 2D array of action for each agent. + + Returns: + AgentState: Updated state of the agents after applying steering + actions and updating positions. + """ + actions = jax.numpy.clip(actions, min=-1.0, max=1.0) + headings, speeds = update_velocity(key, params, (actions, state)) + positions = move(key, None, (state.pos, headings, speeds)) + + return types.AgentState( + pos=positions, + speed=speeds, + heading=headings, + ) + + +def view( + _: chex.PRNGKey, + params: Tuple[float, float], + a: types.AgentState, + b: types.AgentState, + *, + n_view: int, + i_range: float, +) -> chex.Array: + """ + Simple agent view model + + Simple view model where the agents view angle is subdivided + into an array of values representing the distance from + the agent along a rays from the agent, with rays evenly distributed. + across the agents field of view. The limit of vision is set at 1.0, + which is also the default value if no object is within range. + Currently, this model assumes the viewed objects are circular. + + Args: + _: Dummy JAX random key. + params: Tuple containing agent view angle and view-radius. + a: Viewing agent state. + b: State of agent being viewed. + n_view: Static number of view rays/subdivisions (i.e. how + many cells the resulting array contains). + i_range: Static agent view/interaction range. + + Returns: + jax array (float32): 1D array representing the distance + along a ray from the agent to another agent. 
+ """ + view_angle, radius = params + rays = jnp.linspace( + -view_angle * jnp.pi, + view_angle * jnp.pi, + n_view, + endpoint=True, + ) + dx = esquilax.utils.shortest_vector(a.pos, b.pos) + d = jnp.sqrt(jnp.sum(dx * dx)) / i_range + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = esquilax.utils.shortest_vector(phi, a.heading, 2 * jnp.pi) + + angular_width = jnp.arctan2(radius, d) + left = dh - angular_width + right = dh + angular_width + + obs = jnp.where(jnp.logical_and(left < rays, rays < right), d, 1.0) + return obs diff --git a/jumanji/environments/swarms/common/viewer.py b/jumanji/environments/swarms/common/viewer.py new file mode 100644 index 000000000..4da841267 --- /dev/null +++ b/jumanji/environments/swarms/common/viewer.py @@ -0,0 +1,73 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import jax.numpy as jnp +from matplotlib.axes import Axes +from matplotlib.figure import Figure +from matplotlib.quiver import Quiver + +from .types import AgentState + + +def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver: + """Draw a flock/swarm of agent using a matplotlib quiver + + Args: + ax: Plot axes. + agent_states: Flock/swarm agent states. + color: Fill color of agents. + + Returns: + `Quiver`: Matplotlib quiver, can also be used and + updated when animating. + """ + q = ax.quiver( + agent_states.pos[:, 0], + agent_states.pos[:, 1], + jnp.cos(agent_states.heading), + jnp.sin(agent_states.heading), + color=color, + pivot="middle", + ) + return q + + +def format_plot(fig: Figure, ax: Axes, border: float = 0.01) -> Tuple[Figure, Axes]: + """Format a flock/swarm plot, remove ticks and bound to the unit interval + + Args: + fig: Matplotlib figure. + ax: Matplotlib axes. + border: Border padding to apply around plot. + + Returns: + Figure: Formatted figure. + Axes: Formatted axes. + """ + fig.subplots_adjust( + top=1.0 - border, + bottom=border, + right=1.0 - border, + left=border, + hspace=0, + wspace=0, + ) + ax.set_xticks([]) + ax.set_yticks([]) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + + return fig, ax diff --git a/jumanji/environments/swarms/predator_prey/__init__.py b/jumanji/environments/swarms/predator_prey/__init__.py new file mode 100644 index 000000000..4fef030a7 --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
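The `view` function above spreads `n_view` rays evenly over an agent's field of vision and records the range-scaled distance to a neighbour along any ray that neighbour blocks, with `1.0` meaning nothing is visible along that ray. A rough standalone sketch of the same idea for a single viewed agent is shown below; the helper name and default values are illustrative only, and in the environment `view` is applied pairwise through Esquilax spatial transforms rather than called directly:

```python
import jax.numpy as jnp

def ray_view(pos_a, heading_a, pos_b, *, view_angle=0.5, radius=0.05, n_view=11, i_range=0.2):
    # Rays spread evenly over +/- view_angle * pi either side of the agent heading.
    rays = jnp.linspace(-view_angle * jnp.pi, view_angle * jnp.pi, n_view)
    dx = (pos_b - pos_a + 0.5) % 1.0 - 0.5               # wrapped displacement on the unit square
    d = jnp.linalg.norm(dx) / i_range                    # distance as a fraction of the view range
    phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi)       # bearing to the viewed agent
    dh = ((heading_a - phi + jnp.pi) % (2 * jnp.pi)) - jnp.pi  # bearing relative to the heading
    half_width = jnp.arctan2(radius, d)                  # angular half-width of the viewed agent
    hit = (dh - half_width < rays) & (rays < dh + half_width)
    return jnp.where(hit, d, 1.0)                        # 1.0 marks an empty ray

# A predator at (0.25, 0.5) heading towards -x sees prey at (0.2, 0.5) on its centre ray.
obs = ray_view(jnp.array([0.25, 0.5]), jnp.pi, jnp.array([0.2, 0.5]))
print(obs)  # centre cell ~0.25, all other cells 1.0
```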
+from .env import PredatorPrey diff --git a/jumanji/environments/swarms/predator_prey/env.py b/jumanji/environments/swarms/predator_prey/env.py new file mode 100644 index 000000000..a905e4569 --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/env.py @@ -0,0 +1,551 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import cached_property +from typing import Optional, Tuple + +import chex +import jax +import jax.numpy as jnp +from esquilax.transforms import nearest_neighbour, spatial + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.swarms.common.types import AgentParams +from jumanji.environments.swarms.common.updates import init_state, update_state, view +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + +from .types import Actions, Observation, Rewards, State +from .updates import ( + distance_predator_rewards, + distance_prey_rewards, + sparse_predator_rewards, + sparse_prey_rewards, +) +from .viewer import PredatorPreyViewer + + +class PredatorPrey(Environment): + """A predator and prey flock environment + + Environment modelling two swarms of agent types, predators + who are rewarded for avoiding pre agents, and conversely + prey agent who are rewarded for touching/catching + prey agents. Both agent types can consist of a large + number of individual agents, each with individual (local) + observations, actions, and rewards. Agents interact + on a uniform space with wrapped boundaries. + + - observation: `Observation` + Arrays representing each agent's local view of the environment. + Each cell of the array represent the distance from the agent + two the nearest other agents in the environment. Each agent type + is observed independently. + + - predators: jax array (float) of shape (num_predators, 2 * num_vision) + - prey: jax array (float) of shape (num_prey, 2 * num_vision) + + - action: `Actions` + Arrays of individual agent actions. Each agents actions rotate and + accelerate/decelerate the agent as [rotation, acceleration] on the range + [-1, 1]. These values are then scaled to update agent velocities within + given parameters. + + - predators: jax array (float) of shape (num_predators, 2) + - prey: jax array (float) of shape (num_prey, 2) + + - reward: `Rewards` + Arrays of individual agent rewards. Rewards can either be generated + sparsely applied when agent collide, or be generated based on distance + to other agents (hence they are dependent on the number and density + of agents). + + - predators: jax array (float) of shape (num_predators,) + - prey: jax array (float) of shape (num_prey,) + + - state: `State` + - predators: `AgentState` + - pos: jax array (float) of shape (num_predators, 2) in the range [0, 1]. + - heading: jax array (float) of shape (num_predators,) in + the range [0, 2pi]. + - speed: jax array (float) of shape (num_predators,) in the + range [min_speed, max_speed]. 
+ - prey: `AgentState` + - pos: jax array (float) of shape (num_prey, 2) in the range [0, 1]. + - heading: jax array (float) of shape (num_prey,) in + the range [0, 2pi]. + - speed: jax array (float) of shape (num_prey,) in the + range [min_speed, max_speed]. + - key: jax array (uint32) of shape (2,) + - step: int representing the current simulation step. + + + ```python + from jumanji.environments import PredatorPrey + env = PredatorPrey( + num_predators=2, + num_prey=10, + prey_vision_range=0.1, + predator_vision_range=0.1, + num_vision=10, + agent_radius=0.01, + sparse_rewards=True, + prey_penalty=0.1, + predator_rewards=0.2, + predator_max_rotate=0.1, + predator_max_accelerate=0.01, + predator_min_speed=0.01, + predator_max_speed=0.05, + predator_view_angle=0.5, + prey_max_rotate=0.1, + prey_max_accelerate=0.01, + prey_min_speed=0.01, + prey_max_speed=0.05, + prey_view_angle=0.5, + ) + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec.generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` + """ + + def __init__( + self, + num_predators: int, + num_prey: int, + prey_vision_range: float, + predator_vision_range: float, + num_vision: int, + agent_radius: float, + sparse_rewards: bool, + prey_penalty: float, + predator_rewards: float, + predator_max_rotate: float, + predator_max_accelerate: float, + predator_min_speed: float, + predator_max_speed: float, + predator_view_angle: float, + prey_max_rotate: float, + prey_max_accelerate: float, + prey_min_speed: float, + prey_max_speed: float, + prey_view_angle: float, + max_steps: int = 10_000, + viewer: Optional[Viewer[State]] = None, + ) -> None: + """Instantiates a `PredatorPrey` environment + + Note: + The environment is square with dimensions + `[1.0, 1.0]` so parameters should be scaled + appropriately. Also note that performance is + dependent on agent vision and interaction ranges, + where larger values can lead to large number of + agent interactions. + + Args: + num_predators: Number of predator agents. + num_prey: Number of prey agents. + prey_vision_range: Prey agent vision range. + predator_vision_range: Predator agent vision range. + num_vision: Number of cells/subdivisions in agent + view models. Larger numbers provide a more accurate + view, at the cost of the environment, at the cost + of performance and memory usage. + agent_radius: Radius of individual agents. This + effects both agent collision range and how + large they appear to other agents. + sparse_rewards: If `True` fix rewards will be applied + when agents are within a fixed collision range. If + `False` rewards are dependent on distance to + other agents with vision range. + prey_penalty: Penalty to apply to prey agents if + they interact with predator agents. This + value is negated when applied. + predator_rewards: Rewards provided to predator agents + when they interact with prey agents. + predator_max_rotate: Maximum rotation predator agents can + turn within a step. Should be a value from [0,1] + representing a fraction of pi radians. + predator_max_accelerate: Maximum acceleration/deceleration + a predator agent can apply within a step. + predator_min_speed: Minimum speed a predator agent can move at. + predator_max_speed: Maximum speed a predator agent can move at. + predator_view_angle: Predator agent local view angle. Should be + a value from [0,1] representing a fraction of pi radians. 
+ The view cone of an agent goes from +- of the view angle + relative to its heading. + prey_max_rotate: Maximum rotation prey agents can + turn within a step. Should be a value from [0,1] + representing a fraction of pi radians. + prey_max_accelerate: Maximum acceleration/deceleration + a prey agent can apply within a step. + prey_min_speed: Minimum speed a prey agent can move at. + prey_max_speed: Maximum speed a prey agent can move at. + prey_view_angle: Prey agent local view angle. Should be + a value from [0,1] representing a fraction of pi radians. + The view cone of an agent goes from +- of the view angle + relative to its heading. + max_steps: Maximum number of environment steps before termination + viewer: `Viewer` used for rendering. Defaults to `PredatorPreyViewer`. + """ + self.num_predators = num_predators + self.num_prey = num_prey + self.prey_vision_range = prey_vision_range + self.predator_vision_range = predator_vision_range + self.num_vision = num_vision + self.agent_radius = agent_radius + self.sparse_rewards = sparse_rewards + self.prey_penalty = prey_penalty + self.predator_rewards = predator_rewards + self.predator_params = AgentParams( + max_rotate=predator_max_rotate, + max_accelerate=predator_max_accelerate, + min_speed=predator_min_speed, + max_speed=predator_max_speed, + view_angle=predator_view_angle, + ) + self.prey_params = AgentParams( + max_rotate=prey_max_rotate, + max_accelerate=prey_max_accelerate, + min_speed=prey_min_speed, + max_speed=prey_max_speed, + view_angle=prey_view_angle, + ) + self.max_steps = max_steps + super().__init__() + self._viewer = viewer or PredatorPreyViewer() + + def __repr__(self) -> str: + return "\n".join( + [ + "Predator-prey flock environment:", + f" - num predators: {self.num_predators}", + f" - num prey: {self.num_prey}", + f" - prey vision range: {self.prey_vision_range}", + f" - predator vision range: {self.predator_vision_range}" + f" - num vision: {self.num_vision}" + f" - agent radius: {self.agent_radius}" + f" - sparse-rewards: {self.sparse_rewards}", + f" - prey-penalty: {self.prey_penalty}", + f" - predator-rewards: {self.predator_rewards}", + ] + ) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """Randomly initialise predator and prey positions and velocities. + + Args: + key: Random key used to reset the environment. + + Returns: + state: Agent states. + timestep: TimeStep with individual agent local environment views. + """ + key, predator_key, prey_key = jax.random.split(key, num=3) + predator_state = init_state( + self.num_predators, self.predator_params, predator_key + ) + prey_state = init_state(self.num_prey, self.prey_params, prey_key) + state = State(predators=predator_state, prey=prey_state, key=key) + timestep = restart(observation=self._state_to_observation(state)) + return state, timestep + + def step( + self, state: State, action: Actions + ) -> Tuple[State, TimeStep[Observation]]: + """Environment update + + Update agent velocities and consequently their positions, + them generate new local views and rewards. + + Args: + state: Agent states. + action: Arrays of predator and prey individual actions. + + Returns: + state: Updated agent positions and velocities. + timestep: Transition timestep with individual agent local observations. 
+ """ + predators = update_state( + state.key, self.predator_params, state.predators, action.predators + ) + prey = update_state(state.key, self.prey_params, state.prey, action.prey) + + state = State( + predators=predators, prey=prey, key=state.key, step=state.step + 1 + ) + + if self.sparse_rewards: + rewards = self._state_to_sparse_rewards(state) + else: + rewards = self._state_to_distance_rewards(state) + + observation = self._state_to_observation(state) + timestep = jax.lax.cond( + state.step >= self.max_steps, + termination, + transition, + rewards, + observation, + ) + return state, timestep + + def _state_to_observation(self, state: State) -> Observation: + + prey_obs_predators = spatial( + view, + reduction=jnp.minimum, + default=jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.prey_vision_range, + )( + state.key, + (self.prey_params.view_angle, self.agent_radius), + state.prey, + state.predators, + pos=state.prey.pos, + pos_b=state.predators.pos, + n_view=self.num_vision, + i_range=self.prey_vision_range, + ) + prey_obs_prey = spatial( + view, + reduction=jnp.minimum, + default=jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.prey_vision_range, + )( + state.key, + (self.predator_params.view_angle, self.agent_radius), + state.prey, + state.prey, + pos=state.prey.pos, + n_view=self.num_vision, + i_range=self.prey_vision_range, + ) + predator_obs_prey = spatial( + view, + reduction=jnp.minimum, + default=jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.predator_vision_range, + )( + state.key, + (self.predator_params.view_angle, self.agent_radius), + state.predators, + state.prey, + pos=state.predators.pos, + pos_b=state.prey.pos, + n_view=self.num_vision, + i_range=self.predator_vision_range, + ) + predator_obs_predator = spatial( + view, + reduction=jnp.minimum, + default=jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.predator_vision_range, + )( + state.key, + (self.predator_params.view_angle, self.agent_radius), + state.predators, + state.predators, + pos=state.predators.pos, + n_view=self.num_vision, + i_range=self.predator_vision_range, + ) + + predator_obs = jnp.hstack([predator_obs_prey, predator_obs_predator]) + prey_obs = jnp.hstack([prey_obs_predators, prey_obs_prey]) + + return Observation( + predators=predator_obs, + prey=prey_obs, + ) + + def _state_to_sparse_rewards(self, state: State) -> Rewards: + prey_rewards = spatial( + sparse_prey_rewards, + reduction=jnp.add, + default=0.0, + include_self=False, + i_range=2 * self.agent_radius, + )( + state.key, + self.prey_penalty, + None, + None, + pos=state.prey.pos, + pos_b=state.predators.pos, + ) + predator_rewards = nearest_neighbour( + sparse_predator_rewards, + default=0.0, + i_range=2 * self.agent_radius, + )( + state.key, + self.predator_rewards, + None, + None, + pos=state.predators.pos, + pos_b=state.prey.pos, + ) + return Rewards( + predators=predator_rewards, + prey=prey_rewards, + ) + + def _state_to_distance_rewards(self, state: State) -> Rewards: + + prey_rewards = spatial( + distance_prey_rewards, + reduction=jnp.add, + default=0.0, + include_self=False, + i_range=self.prey_vision_range, + )( + state.key, + self.prey_penalty, + state.prey, + state.predators, + pos=state.prey.pos, + pos_b=state.predators.pos, + i_range=self.prey_vision_range, + ) + predator_rewards = spatial( + distance_predator_rewards, + reduction=jnp.add, + default=0.0, + include_self=False, + i_range=self.predator_vision_range, + )( + state.key, + 
self.predator_rewards, + state.predators, + state.prey, + pos=state.predators.pos, + pos_b=state.prey.pos, + i_range=self.prey_vision_range, + ) + + return Rewards( + predators=predator_rewards, + prey=prey_rewards, + ) + + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. + + Local predator and prey agent views representing + the distance to closest neighbours in the environment. + + Returns: + observation_spec: Predator-prey observation spec + """ + predators = specs.BoundedArray( + shape=(self.num_predators, 2 * self.num_vision), + minimum=0.0, + maximum=1.0, + dtype=float, + name="predators", + ) + prey = specs.BoundedArray( + shape=(self.num_prey, 2 * self.num_vision), + minimum=0.0, + maximum=1.0, + dtype=float, + name="prey", + ) + return specs.Spec( + Observation, + "ObservationSpec", + predators=predators, + prey=prey, + ) + + @cached_property + def action_spec(self) -> specs.Spec[Actions]: + """Returns the action spec. + + Arrays of individual agent actions. Each agents action is + an array representing [rotation, acceleration] in the range + [-1, 1]. + + Returns: + action_spec: Predator-prey action spec + """ + predators = specs.BoundedArray( + shape=(self.num_predators, 2), + minimum=-1.0, + maximum=1.0, + dtype=float, + name="predators", + ) + prey = specs.BoundedArray( + shape=(self.num_prey, 2), + minimum=-1.0, + maximum=1.0, + dtype=float, + name="prey", + ) + return specs.Spec( + Actions, + "ActionSpec", + predators=predators, + prey=prey, + ) + + @cached_property + def reward_spec(self) -> specs.Spec[Rewards]: # type: ignore[override] + """Returns the reward spec + + Individual rewards for predator and prey types. + + Returns: + reward_spec: Predator-prey reward spec + """ + predators = specs.Array( + shape=(self.num_predators,), + dtype=float, + name="predators", + ) + prey = specs.Array( + shape=(self.num_prey,), + dtype=float, + name="prey", + ) + return specs.Spec( + Rewards, + "rewardsSpec", + predators=predators, + prey=prey, + ) + + def render(self, state: State) -> None: + """Render a frame of the environment for a given state using matplotlib. + + Args: + state: State object containing the current dynamics of the environment. + """ + self._viewer.render(state) + + def close(self) -> None: + """Perform any necessary cleanup.""" + self._viewer.close() diff --git a/jumanji/environments/swarms/predator_prey/env_test.py b/jumanji/environments/swarms/predator_prey/env_test.py new file mode 100644 index 000000000..16658b9bd --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/env_test.py @@ -0,0 +1,379 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
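When `sparse_rewards` is disabled, the rewards accumulated in `_state_to_distance_rewards` are shaped linearly with separation and summed over neighbours in range: maximal in contact, falling to zero at the interaction range, with the prey penalty mirroring the predator reward. A rough numerical sketch of that shaping follows; the functions and the `0.1` interaction range here are illustrative, with the range used in practice being the configured vision range:

```python
import jax.numpy as jnp

def predator_distance_reward(reward, d, i_range):
    # Full reward in contact, decaying linearly to zero at the interaction range.
    return reward * (1.0 - d / i_range)

def prey_distance_penalty(penalty, d, i_range):
    # Mirror image for prey: -penalty in contact, zero at the interaction range.
    return penalty * (d / i_range - 1.0)

d = jnp.array([0.0, 0.05, 0.1])
print(predator_distance_reward(0.2, d, 0.1))  # [ 0.2   0.1   0.0 ]
print(prey_distance_penalty(0.1, d, 0.1))     # [-0.1  -0.05  0.0 ]
```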
+from typing import List, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib +import matplotlib.pyplot as plt +import py +import pytest + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.predator_prey import PredatorPrey +from jumanji.environments.swarms.predator_prey.types import ( + Actions, + Observation, + Rewards, + State, +) +from jumanji.testing.env_not_smoke import ( + check_env_does_not_smoke, + check_env_specs_does_not_smoke, +) +from jumanji.types import StepType, TimeStep + +PREDATOR_REWARD = 0.2 +PREY_PENALTY = 0.1 + + +@pytest.fixture +def env() -> PredatorPrey: + return PredatorPrey( + num_predators=2, + num_prey=10, + prey_vision_range=0.1, + predator_vision_range=0.2, + num_vision=11, + agent_radius=0.05, + sparse_rewards=True, + prey_penalty=PREY_PENALTY, + predator_rewards=PREDATOR_REWARD, + predator_max_rotate=0.1, + predator_max_accelerate=0.01, + predator_min_speed=0.01, + predator_max_speed=0.05, + predator_view_angle=0.5, + prey_max_rotate=0.1, + prey_max_accelerate=0.01, + prey_min_speed=0.01, + prey_max_speed=0.05, + prey_view_angle=0.5, + ) + + +def test_env_init(env: PredatorPrey) -> None: + """ + Check newly initialised state has expected array shapes + and initial timestep. + """ + k = jax.random.PRNGKey(101) + state, timestep = env.reset(k) + assert isinstance(state, State) + + assert isinstance(state.predators, AgentState) + assert state.predators.pos.shape == (env.num_predators, 2) + assert state.predators.speed.shape == (env.num_predators,) + assert state.predators.speed.shape == (env.num_predators,) + + assert isinstance(state.prey, AgentState) + assert state.prey.pos.shape == (env.num_prey, 2) + assert state.prey.speed.shape == (env.num_prey,) + assert state.prey.speed.shape == (env.num_prey,) + + assert isinstance(timestep.observation, Observation) + assert timestep.observation.predators.shape == ( + env.num_predators, + 2 * env.num_vision, + ) + assert timestep.observation.prey.shape == (env.num_prey, 2 * env.num_vision) + assert timestep.step_type == StepType.FIRST + + +@pytest.mark.parametrize("sparse_rewards", [True, False]) +def test_env_step(env: PredatorPrey, sparse_rewards: bool) -> None: + """ + Run several steps of the environment with random actions and + check states (i.e. positions, heading, speeds) all fall + inside expected ranges. 
+ """ + env.sparse_rewards = sparse_rewards + key = jax.random.PRNGKey(101) + n_steps = 22 + + def step( + carry: Tuple[chex.PRNGKey, State], _: None + ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: + k, state = carry + k, k_pred, k_prey = jax.random.split(k, num=3) + actions = Actions( + predators=jax.random.uniform( + k_pred, (env.num_predators, 2), minval=-1.0, maxval=1.0 + ), + prey=jax.random.uniform(k_prey, (env.num_prey, 2), minval=-1.0, maxval=1.0), + ) + new_state, timestep = env.step(state, actions) + return (k, new_state), (state, timestep) + + init_state, _ = env.reset(key) + (_, final_state), (state_history, timesteps) = jax.lax.scan( + step, (key, init_state), length=n_steps + ) + + assert isinstance(state_history, State) + + assert state_history.predators.pos.shape == (n_steps, env.num_predators, 2) + assert jnp.all( + (0.0 <= state_history.predators.pos) & (state_history.predators.pos <= 1.0) + ) + assert state_history.predators.speed.shape == (n_steps, env.num_predators) + assert jnp.all( + (env.predator_params.min_speed <= state_history.predators.speed) + & (state_history.predators.speed <= env.predator_params.max_speed) + ) + assert state_history.predators.speed.shape == (n_steps, env.num_predators) + assert jnp.all( + (0.0 <= state_history.predators.heading) + & (state_history.predators.heading <= 2.0 * jnp.pi) + ) + + assert state_history.prey.pos.shape == (n_steps, env.num_prey, 2) + assert jnp.all((0.0 <= state_history.prey.pos) & (state_history.prey.pos <= 1.0)) + assert state_history.prey.speed.shape == (n_steps, env.num_prey) + assert jnp.all( + (env.predator_params.min_speed <= state_history.prey.speed) + & (state_history.prey.speed <= env.predator_params.max_speed) + ) + assert state_history.prey.heading.shape == (n_steps, env.num_prey) + assert jnp.all( + (0.0 <= state_history.prey.heading) + & (state_history.prey.heading <= 2.0 * jnp.pi) + ) + + +@pytest.mark.parametrize("sparse_rewards", [True, False]) +def test_env_does_not_smoke(env: PredatorPrey, sparse_rewards: bool) -> None: + """Test that we can run an episode without any errors.""" + env.sparse_rewards = sparse_rewards + env.max_steps = 10 + + def select_action(action_key: chex.PRNGKey, _state: Observation) -> Actions: + predator_key, prey_key = jax.random.split(action_key) + return Actions( + predators=jax.random.uniform( + predator_key, (env.num_predators, 2), minval=-1.0, maxval=1.0 + ), + prey=jax.random.uniform( + prey_key, (env.num_prey, 2), minval=-1.0, maxval=1.0 + ), + ) + + check_env_does_not_smoke(env, select_action=select_action) + + +def test_env_specs_do_not_smoke(env: PredatorPrey) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(env) + + +@pytest.mark.parametrize( + "predator_pos, predator_heading, predator_view, prey_pos, prey_heading, prey_view", + [ + # Both out of view range + ([[0.8, 0.5]], [jnp.pi], [(0, 0, 1.0)], [[0.2, 0.5]], [0.0], [(0, 0, 1.0)]), + # In predator range but not prey + ([[0.35, 0.5]], [jnp.pi], [(0, 5, 0.75)], [[0.2, 0.5]], [0.0], [(0, 0, 1.0)]), + # Both view each other + ([[0.25, 0.5]], [jnp.pi], [(0, 5, 0.25)], [[0.2, 0.5]], [0.0], [(0, 5, 0.5)]), + # Prey facing wrong direction + ( + [[0.25, 0.5]], + [jnp.pi], + [(0, 5, 0.25)], + [[0.2, 0.5]], + [jnp.pi], + [(0, 0, 1.0)], + ), + # Prey sees closest predator + ( + [[0.35, 0.5], [0.25, 0.5]], + [jnp.pi, jnp.pi], + [(0, 5, 0.75), (0, 16, 0.5), (1, 5, 0.25)], + [[0.2, 0.5]], + [0.0], + [(0, 5, 0.5)], + ), + # Observed around wrapped 
edge + ( + [[0.025, 0.5]], + [jnp.pi], + [(0, 5, 0.25)], + [[0.975, 0.5]], + [0.0], + [(0, 5, 0.5)], + ), + ], +) +def test_view_observations( + env: PredatorPrey, + predator_pos: List[List[float]], + predator_heading: List[float], + predator_view: List[Tuple[int, int, float]], + prey_pos: List[List[float]], + prey_heading: List[float], + prey_view: List[Tuple[int, int, float]], +) -> None: + """ + Test view model generates expected array with different + configurations of agents. + """ + + predator_pos = jnp.array(predator_pos) + predator_heading = jnp.array(predator_heading) + predator_speed = jnp.zeros(predator_heading.shape) + + prey_pos = jnp.array(prey_pos) + prey_heading = jnp.array(prey_heading) + prey_speed = jnp.zeros(prey_heading.shape) + + state = State( + predators=AgentState( + pos=predator_pos, heading=predator_heading, speed=predator_speed + ), + prey=AgentState(pos=prey_pos, heading=prey_heading, speed=prey_speed), + key=jax.random.PRNGKey(101), + ) + + obs = env._state_to_observation(state) + + assert isinstance(obs, Observation) + + predator_expected = jnp.ones( + ( + predator_heading.shape[0], + 2 * env.num_vision, + ) + ) + for i, idx, val in predator_view: + predator_expected = predator_expected.at[i, idx].set(val) + + assert jnp.all(jnp.isclose(obs.predators, predator_expected)) + + prey_expected = jnp.ones( + ( + prey_heading.shape[0], + 2 * env.num_vision, + ) + ) + for i, idx, val in prey_view: + prey_expected = prey_expected.at[i, idx].set(val) + + assert jnp.all(jnp.isclose(obs.prey[0], prey_expected)) + + +@pytest.mark.parametrize( + "predator_pos, predator_reward, prey_pos, prey_reward", + [ + ([0.5, 0.5], 0.0, [0.8, 0.5], 0.0), + ([0.5, 0.5], PREDATOR_REWARD, [0.5999, 0.5], -PREY_PENALTY), + ([0.5, 0.5], PREDATOR_REWARD, [0.5001, 0.5], -PREY_PENALTY), + ], +) +def test_sparse_rewards( + env: PredatorPrey, + predator_pos: List[float], + predator_reward: float, + prey_pos: List[float], + prey_reward: float, +) -> None: + """ + Test sparse rewards are correctly assigned. + """ + + state = State( + predators=AgentState( + pos=jnp.array([predator_pos]), + heading=jnp.zeros((1,)), + speed=jnp.zeros((1,)), + ), + prey=AgentState( + pos=jnp.array([prey_pos]), + heading=jnp.zeros((1,)), + speed=jnp.zeros((1,)), + ), + key=jax.random.PRNGKey(101), + ) + + rewards = env._state_to_sparse_rewards(state) + assert isinstance(rewards, Rewards) + assert rewards.predators[0] == predator_reward + assert rewards.prey[0] == prey_reward + + +@pytest.mark.parametrize( + "predator_pos, predator_reward, prey_pos, prey_reward", + [ + ([0.5, 0.5], 0.0, [0.8, 0.5], 0.0), + ([0.5, 0.5], 0.5 * PREDATOR_REWARD, [0.55, 0.5], -0.5 * PREY_PENALTY), + ([0.5, 0.5], PREDATOR_REWARD, [0.5 + 1e-10, 0.5], -PREY_PENALTY), + ], +) +def test_distance_rewards( + env: PredatorPrey, + predator_pos: List[float], + predator_reward: float, + prey_pos: List[float], + prey_reward: float, +) -> None: + """ + Test rewards scaled with distance are correctly assigned. 
+ """ + + state = State( + predators=AgentState( + pos=jnp.array([predator_pos]), + heading=jnp.zeros((1,)), + speed=jnp.zeros((1,)), + ), + prey=AgentState( + pos=jnp.array([prey_pos]), + heading=jnp.zeros((1,)), + speed=jnp.zeros((1,)), + ), + key=jax.random.PRNGKey(101), + ) + + rewards = env._state_to_distance_rewards(state) + assert isinstance(rewards, Rewards) + assert jnp.isclose(rewards.predators[0], predator_reward) + assert jnp.isclose(rewards.prey[0], prey_reward) + + +def test_predator_prey_render( + monkeypatch: pytest.MonkeyPatch, env: PredatorPrey +) -> None: + """Check that the render method builds the figure but does not display it.""" + monkeypatch.setattr(plt, "show", lambda fig: None) + step_fn = jax.jit(env.step) + state, timestep = env.reset(jax.random.PRNGKey(0)) + action = env.action_spec.generate_value() + state, timestep = step_fn(state, action) + env.render(state) + env.close() + + +def test_snake__animation(env: PredatorPrey, tmpdir: py.path.local) -> None: + """Check that the animation method creates the animation correctly and can save to a gif.""" + step_fn = jax.jit(env.step) + state, _ = env.reset(jax.random.PRNGKey(0)) + states = [state] + action = env.action_spec.generate_value() + state, _ = step_fn(state, action) + states.append(state) + animation = env._viewer.animate(states, 200, None) + assert isinstance(animation, matplotlib.animation.Animation) + + path = str(tmpdir.join("/anim.gif")) + animation.save(path, writer=matplotlib.animation.PillowWriter(fps=10), dpi=60) diff --git a/jumanji/environments/swarms/predator_prey/types.py b/jumanji/environments/swarms/predator_prey/types.py new file mode 100644 index 000000000..3b12fd1ea --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/types.py @@ -0,0 +1,71 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dataclasses import dataclass +else: + from chex import dataclass + +import chex + +from jumanji.environments.swarms.common.types import AgentState + + +@dataclass +class State: + """ + predators: Predator agent states. + prey: Prey agent states. + key: JAX random key. + step: Environment step number + """ + + predators: AgentState + prey: AgentState + key: chex.PRNGKey + step: int = 0 + + +@dataclass +class Observation: + """ + predators: Local view of predator agents. + prey: Local view of prey agents. + """ + + predators: chex.Array + prey: chex.Array + + +@dataclass +class Actions: + """ + predators: Array of actions for predator agents. + prey: Array of actions for prey agents. + """ + + predators: chex.Array + prey: chex.Array + + +@dataclass +class Rewards: + """ + predators: Array of individual rewards for predator agents. + prey: Array of individual rewards for prey agents. 
+ """ + + predators: chex.Array + prey: chex.Array diff --git a/jumanji/environments/swarms/predator_prey/updates.py b/jumanji/environments/swarms/predator_prey/updates.py new file mode 100644 index 000000000..75bc536d0 --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/updates.py @@ -0,0 +1,129 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import chex +import esquilax + +from jumanji.environments.swarms.common import types + + +def sparse_prey_rewards( + _k: chex.PRNGKey, + penalty: float, + _prey: Optional[types.AgentState], + _predator: Optional[types.AgentState], +) -> float: + """Penalise a prey agent if contacted by a predator agent. + + Apply a negative penalty to prey agents that collide + with a prey agent. This function is applied using an + Esquilax spatial interaction. + + Args: + _k: Dummy JAX random key. + penalty: Penalty value. + _prey: Optional unused prey agent-state. + _predator: Optional unused predator agent-state. + + Returns: + float: Negative penalty applied to prey agent. + """ + return -penalty + + +def distance_prey_rewards( + _k: chex.PRNGKey, + penalty: float, + prey: types.AgentState, + predator: types.AgentState, + *, + i_range: float, +) -> Union[float, chex.Array]: + """Penalise a prey agent based on distance from a predator agent. + + Apply a negative penalty based on a distance between + agents. The penalty is a linear function of distance, + 0 at max distance up to `-penalty` at 0 distance. This function + can be used with an Esquilax spatial interaction to accumulate + rewards between agents. + + Args: + _k: Dummy JAX random key. + penalty: Maximum penalty applied. + prey: Prey agent-state. + predator: Predator agent-state. + i_range: Static interaction range. + + Returns: + float: Agent rewards. + """ + d = esquilax.utils.shortest_distance(prey.pos, predator.pos) / i_range + return penalty * (d - 1.0) + + +def sparse_predator_rewards( + _k: chex.PRNGKey, + reward: float, + _predator: Optional[types.AgentState], + _prey: Optional[types.AgentState], +) -> float: + """Reward a predator agent if it is within range of a prey agent + + Apply a fixed positive reward if a predator agent is within + a fixed range of a prey-agent. This function can + be used with an Esquilax spatial interaction to + apply rewards to agents in range. + + Args: + _k: Dummy JAX random key. + reward: Reward value to apply. + _predator: Optional unused agent-state. + _prey: Optional unused agent-state. + + Returns: + float: Predator agent reward. + """ + return reward + + +def distance_predator_rewards( + _k: chex.PRNGKey, + reward: float, + predator: types.AgentState, + prey: types.AgentState, + *, + i_range: float, +) -> Union[float, chex.Array]: + """Reward a predator agent based on distance from a prey agent. + + Apply a positive reward based on the linear distance between + a predator and prey agent. 
Rewards are zero at the max + interaction distance, and maximal at 0 range. This function + can be used with an Esquilax spatial interaction to accumulate + rewards between agents. + + Args: + _k: Dummy JAX random key. + reward: Maximum reward value. + predator: Predator agent-state. + prey: Prey agent-state. + i_range: Static interaction range. + + Returns: + float@ Predator agent reward. + """ + d = esquilax.utils.shortest_distance(predator.pos, prey.pos) / i_range + return reward * (1.0 - d) diff --git a/jumanji/environments/swarms/predator_prey/viewer.py b/jumanji/environments/swarms/predator_prey/viewer.py new file mode 100644 index 000000000..b94750a90 --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/viewer.py @@ -0,0 +1,156 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, Sequence, Tuple + +import jax.numpy as jnp +import matplotlib.animation +import matplotlib.pyplot as plt +from matplotlib.layout_engine import TightLayoutEngine + +import jumanji +import jumanji.environments +from jumanji.environments.swarms.common.viewer import draw_agents, format_plot +from jumanji.environments.swarms.predator_prey.types import State +from jumanji.viewer import Viewer + + +class PredatorPreyViewer(Viewer): + def __init__( + self, + figure_name: str = "PredatorPrey", + figure_size: Tuple[float, float] = (6.0, 6.0), + predator_color: str = "red", + prey_color: str = "green", + ) -> None: + """Viewer for the `PredatorPrey` environment. + + Args: + figure_name: the window name to be used when initialising the window. + figure_size: tuple (height, width) of the matplotlib figure window. + """ + self._figure_name = figure_name + self._figure_size = figure_size + self.predator_color = predator_color + self.prey_color = prey_color + self._animation: Optional[matplotlib.animation.Animation] = None + + def render(self, state: State) -> None: + """Render a frame of the environment for a given state using matplotlib. + + Args: + state: State object containing the current dynamics of the environment. + + """ + self._clear_display() + fig, ax = self._get_fig_ax() + self._draw(ax, state) + self._update_display(fig) + + def animate( + self, states: Sequence[State], interval: int, save_path: Optional[str] + ) -> matplotlib.animation.FuncAnimation: + """Create an animation from a sequence of states. + + Args: + states: sequence of `State` corresponding to subsequent timesteps. + interval: delay between frames in milliseconds, default to 200. + save_path: the path where the animation file should be saved. If it is None, the plot + will not be saved. + + Returns: + Animation object that can be saved as a GIF, MP4, or rendered with HTML. 
+ """ + if not states: + raise ValueError(f"The states argument has to be non-empty, got {states}.") + fig, ax = plt.subplots( + num=f"{self._figure_name}Anim", figsize=self._figure_size + ) + fig, ax = format_plot(fig, ax) + + predators_quiver = draw_agents(ax, states[0].predators, self.predator_color) + prey_quiver = draw_agents(ax, states[0].prey, self.prey_color) + + def make_frame(state: State) -> Any: + # Rather than redraw just update the quivers properties + predators_quiver.set_offsets(state.predators.pos) + predators_quiver.set_UVC( + jnp.cos(state.predators.heading), jnp.sin(state.predators.heading) + ) + prey_quiver.set_offsets(state.prey.pos) + prey_quiver.set_UVC( + jnp.cos(state.prey.heading), jnp.sin(state.prey.heading) + ) + return ((predators_quiver, prey_quiver),) + + matplotlib.rc("animation", html="jshtml") + self._animation = matplotlib.animation.FuncAnimation( + fig, + make_frame, + frames=states, + interval=interval, + blit=False, + ) + + if save_path: + self._animation.save(save_path) + + return self._animation + + def close(self) -> None: + """Perform any necessary cleanup. + + Environments will automatically :meth:`close()` themselves when + garbage collected or when the program exits. + """ + plt.close(self._figure_name) + + def _draw(self, ax: plt.Axes, state: State) -> None: + ax.clear() + draw_agents(ax, state.predators, self.predator_color) + draw_agents(ax, state.prey, self.prey_color) + + def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: + exists = plt.fignum_exists(self._figure_name) + if exists: + fig = plt.figure(self._figure_name) + ax = fig.get_axes()[0] + else: + fig = plt.figure(self._figure_name, figsize=self._figure_size) + fig.set_layout_engine( + layout=TightLayoutEngine(pad=False, w_pad=0.0, h_pad=0.0) + ) + if not plt.isinteractive(): + fig.show() + ax = fig.add_subplot() + + fig, ax = format_plot(fig, ax) + return fig, ax + + def _update_display(self, fig: plt.Figure) -> None: + if plt.isinteractive(): + # Required to update render when using Jupyter Notebook. + fig.canvas.draw() + if jumanji.environments.is_colab(): + plt.show(self._figure_name) + else: + # Required to update render when not using Jupyter Notebook. 
+ fig.canvas.draw_idle() + fig.canvas.flush_events() + + def _clear_display(self) -> None: + if jumanji.environments.is_colab(): + import IPython.display + + IPython.display.clear_output(True) diff --git a/mkdocs.yml b/mkdocs.yml index 928da794b..624f35ba0 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,6 +42,8 @@ nav: - Sokoban: environments/sokoban.md - Snake: environments/snake.md - TSP: environments/tsp.md + - Swarms: + - PredatorPrey: environments/predator_prey.md - User Guides: - Advanced Usage: guides/advanced_usage.md - Registration: guides/registration.md diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2e398c025..2e4054474 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,5 +1,6 @@ chex>=0.1.3 dm-env>=1.5 +esquilax>=1.0.2 gym>=0.22.0 huggingface-hub jax>=0.2.26 From 988339b7829d88cf5c2092aeadf61ccb431489ed Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:45:27 +0000 Subject: [PATCH 02/19] fix: PR fixes (#2) * refactor: Formatting fixes * fix: Implement rewards as class * refactor: Implement observation as NamedTuple * refactor: Implement initial state generator * docs: Update docstrings * refactor: Add env animate method * docs: Link env into API docs --- docs/api/environments/predator_prey.md | 11 + docs/environments/predator_prey.md | 19 +- jumanji/environments/swarms/common/types.py | 8 +- jumanji/environments/swarms/common/updates.py | 31 +- jumanji/environments/swarms/common/viewer.py | 2 +- .../environments/swarms/predator_prey/env.py | 165 ++++------- .../swarms/predator_prey/env_test.py | 28 +- .../swarms/predator_prey/generator.py | 70 +++++ .../swarms/predator_prey/rewards.py | 275 ++++++++++++++++++ .../swarms/predator_prey/types.py | 58 +++- .../swarms/predator_prey/updates.py | 129 -------- mkdocs.yml | 2 + 12 files changed, 496 insertions(+), 302 deletions(-) create mode 100644 docs/api/environments/predator_prey.md create mode 100644 jumanji/environments/swarms/predator_prey/generator.py create mode 100644 jumanji/environments/swarms/predator_prey/rewards.py delete mode 100644 jumanji/environments/swarms/predator_prey/updates.py diff --git a/docs/api/environments/predator_prey.md b/docs/api/environments/predator_prey.md new file mode 100644 index 000000000..52bf4e6e9 --- /dev/null +++ b/docs/api/environments/predator_prey.md @@ -0,0 +1,11 @@ +::: jumanji.environments.swarms.predator_prey.env.PredatorPrey + selection: + members: + - __init__ + - reset + - step + - observation_spec + - action_spec + - reward_spec + - render + - animate diff --git a/docs/environments/predator_prey.md b/docs/environments/predator_prey.md index bfd8f96c2..0a53ed927 100644 --- a/docs/environments/predator_prey.md +++ b/docs/environments/predator_prey.md @@ -35,19 +35,8 @@ parameters. Agents are restricted to velocities within a fixed range of speeds. ## Reward -Rewards can be either sparse or proximity-based. +Rewards are generated for each agent individually. They are generally dependent on proximity, so +their scale can depend on agent density and interaction ranges. -### Sparse - -- `predators`: jax array (float) of shape `(num_predators,)`, predators are rewarded a fixed amount - for coming into contact with a prey agent. If they are in contact with multiple prey, only the - nearest agent is selected. 
-- `prey`: jax array (float) of shape `(num_predators,)`, prey are penalised a fix negative amount if - they come into contact with a predator agent. - -### Proximity - -- `predators`: jax array (float) of shape `(num_predators,)`, predators are rewarded with an amount - scaled linearly with the distance to the prey agents, summed over agents in range. -- `prey`: jax array (float) of shape `(num_predators,)`, prey are penalised by an amount scaled linearly - with distance from predator agents, summed over all predators in range. +- `predators`: jax array (float) of shape `(num_predators,)`, individual predator agent rewards. +- `prey`: jax array (float) of shape `(num_prey,)`, individual prey rewards. diff --git a/jumanji/environments/swarms/common/types.py b/jumanji/environments/swarms/common/types.py index 82ec12c87..db3e0e0f4 100644 --- a/jumanji/environments/swarms/common/types.py +++ b/jumanji/environments/swarms/common/types.py @@ -21,7 +21,7 @@ import chex -@dataclass +@dataclass(frozen=True) class AgentParams: """ max_rotate: Max angle an agent can rotate during a step (a fraction of pi) @@ -48,6 +48,6 @@ class AgentState: speed: Speed of the agents """ - pos: chex.Array - heading: chex.Array - speed: chex.Array + pos: chex.Array # (num_agents, 2) + heading: chex.Array # (num_agents,) + speed: chex.Array # (num_agents,) diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index e10f7d733..de0b21138 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -19,8 +19,7 @@ import jax import jax.numpy as jnp -from . import types -from .types import AgentParams +from jumanji.environments.swarms.common import types @esquilax.transforms.amap @@ -55,22 +54,18 @@ def update_velocity( return new_heading, new_speeds -@esquilax.transforms.amap -def move( - _: chex.PRNGKey, _params: None, x: Tuple[chex.Array, float, float] -) -> chex.Array: +def move(pos: chex.Array, heading: chex.Array, speed: chex.Array) -> chex.Array: """ Get updated agent positions from current speed and heading Args: - _: Dummy JAX random key. - _params: unused parameters. - x: Tuple containing current agent position, heading and speed. + pos: Agent position + heading: Agent heading (angle). + speed: Agent speed Returns: jax array (float32): Updated agent position """ - pos, heading, speed = x d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)]) return (pos + d_pos) % 1.0 @@ -95,7 +90,7 @@ def init_state( speeds = jax.random.uniform( k2, (n,), minval=params.min_speed, maxval=params.max_speed ) - headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jax.numpy.pi) + headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jnp.pi) return types.AgentState( pos=positions, @@ -105,7 +100,10 @@ def init_state( def update_state( - key: chex.PRNGKey, params: AgentParams, state: types.AgentState, actions: chex.Array + key: chex.PRNGKey, + params: types.AgentParams, + state: types.AgentState, + actions: chex.Array, ) -> types.AgentState: """ Update the state of a group of agents from a sample of actions @@ -120,9 +118,9 @@ def update_state( AgentState: Updated state of the agents after applying steering actions and updating positions. 
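For reference, after this refactor `move` is a plain per-agent update that `update_state` vectorises with `jax.vmap`. A minimal standalone sketch of the same behaviour (the example values are hypothetical, not part of the patch):

```python
import jax
import jax.numpy as jnp


def move(pos: jnp.ndarray, heading: jnp.ndarray, speed: jnp.ndarray) -> jnp.ndarray:
    # Step along the agent's heading and wrap back into the unit square.
    d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)])
    return (pos + d_pos) % 1.0


positions = jnp.array([[0.95, 0.5], [0.1, 0.2]])
headings = jnp.array([0.0, jnp.pi / 2])
speeds = jnp.array([0.1, 0.05])

# vmap applies the per-agent update across the leading (agent) dimension.
new_positions = jax.vmap(move)(positions, headings, speeds)
# The first agent wraps across the periodic boundary to roughly [0.05, 0.5].
```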
""" - actions = jax.numpy.clip(actions, min=-1.0, max=1.0) + actions = jnp.clip(actions, min=-1.0, max=1.0) headings, speeds = update_velocity(key, params, (actions, state)) - positions = move(key, None, (state.pos, headings, speeds)) + positions = jax.vmap(move)(state.pos, headings, speeds) return types.AgentState( pos=positions, @@ -132,7 +130,7 @@ def update_state( def view( - _: chex.PRNGKey, + _key: chex.PRNGKey, params: Tuple[float, float], a: types.AgentState, b: types.AgentState, @@ -151,7 +149,8 @@ def view( Currently, this model assumes the viewed objects are circular. Args: - _: Dummy JAX random key. + _key: Dummy JAX random key, required by esquilax API, but + not used during the interaction. params: Tuple containing agent view angle and view-radius. a: Viewing agent state. b: State of agent being viewed. diff --git a/jumanji/environments/swarms/common/viewer.py b/jumanji/environments/swarms/common/viewer.py index 4da841267..7a5f029b5 100644 --- a/jumanji/environments/swarms/common/viewer.py +++ b/jumanji/environments/swarms/common/viewer.py @@ -19,7 +19,7 @@ from matplotlib.figure import Figure from matplotlib.quiver import Quiver -from .types import AgentState +from jumanji.environments.swarms.common.types import AgentState def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver: diff --git a/jumanji/environments/swarms/predator_prey/env.py b/jumanji/environments/swarms/predator_prey/env.py index a905e4569..92ec2fa2c 100644 --- a/jumanji/environments/swarms/predator_prey/env.py +++ b/jumanji/environments/swarms/predator_prey/env.py @@ -13,29 +13,33 @@ # limitations under the License. from functools import cached_property -from typing import Optional, Tuple +from typing import Optional, Sequence, Tuple import chex import jax import jax.numpy as jnp -from esquilax.transforms import nearest_neighbour, spatial +from esquilax.transforms import spatial +from matplotlib.animation import FuncAnimation from jumanji import specs from jumanji.env import Environment from jumanji.environments.swarms.common.types import AgentParams -from jumanji.environments.swarms.common.updates import init_state, update_state, view +from jumanji.environments.swarms.common.updates import update_state, view +from jumanji.environments.swarms.predator_prey.generator import ( + Generator, + RandomGenerator, +) +from jumanji.environments.swarms.predator_prey.rewards import DistanceRewards, RewardFn +from jumanji.environments.swarms.predator_prey.types import ( + Actions, + Observation, + Rewards, + State, +) +from jumanji.environments.swarms.predator_prey.viewer import PredatorPreyViewer from jumanji.types import TimeStep, restart, termination, transition from jumanji.viewer import Viewer -from .types import Actions, Observation, Rewards, State -from .updates import ( - distance_predator_rewards, - distance_prey_rewards, - sparse_predator_rewards, - sparse_prey_rewards, -) -from .viewer import PredatorPreyViewer - class PredatorPrey(Environment): """A predator and prey flock environment @@ -67,10 +71,9 @@ class PredatorPrey(Environment): - prey: jax array (float) of shape (num_prey, 2) - reward: `Rewards` - Arrays of individual agent rewards. Rewards can either be generated - sparsely applied when agent collide, or be generated based on distance - to other agents (hence they are dependent on the number and density - of agents). + Arrays of individual agent rewards. 
Rewards generally depend on + proximity to other agents, and so can vary dependent on + density and agent radius and vision ranges. - predators: jax array (float) of shape (num_predators,) - prey: jax array (float) of shape (num_prey,) @@ -133,8 +136,6 @@ def __init__( num_vision: int, agent_radius: float, sparse_rewards: bool, - prey_penalty: float, - predator_rewards: float, predator_max_rotate: float, predator_max_accelerate: float, predator_min_speed: float, @@ -147,6 +148,8 @@ def __init__( prey_view_angle: float, max_steps: int = 10_000, viewer: Optional[Viewer[State]] = None, + generator: Optional[Generator] = None, + reward_fn: Optional[RewardFn] = None, ) -> None: """Instantiates a `PredatorPrey` environment @@ -174,11 +177,6 @@ def __init__( when agents are within a fixed collision range. If `False` rewards are dependent on distance to other agents with vision range. - prey_penalty: Penalty to apply to prey agents if - they interact with predator agents. This - value is negated when applied. - predator_rewards: Rewards provided to predator agents - when they interact with prey agents. predator_max_rotate: Maximum rotation predator agents can turn within a step. Should be a value from [0,1] representing a fraction of pi radians. @@ -203,6 +201,8 @@ def __init__( relative to its heading. max_steps: Maximum number of environment steps before termination viewer: `Viewer` used for rendering. Defaults to `PredatorPreyViewer`. + generator: Initial state generator. Defaults to `RandomGenerator`. + reward_fn: Reward function. Defaults to `DistanceRewards`. """ self.num_predators = num_predators self.num_prey = num_prey @@ -211,8 +211,6 @@ def __init__( self.num_vision = num_vision self.agent_radius = agent_radius self.sparse_rewards = sparse_rewards - self.prey_penalty = prey_penalty - self.predator_rewards = predator_rewards self.predator_params = AgentParams( max_rotate=predator_max_rotate, max_accelerate=predator_max_accelerate, @@ -230,6 +228,10 @@ def __init__( self.max_steps = max_steps super().__init__() self._viewer = viewer or PredatorPreyViewer() + self._generator = generator or RandomGenerator(num_predators, num_prey) + self._reward_fn = reward_fn or DistanceRewards( + predator_vision_range, prey_vision_range, 1.0, 1.0 + ) def __repr__(self) -> str: return "\n".join( @@ -242,8 +244,8 @@ def __repr__(self) -> str: f" - num vision: {self.num_vision}" f" - agent radius: {self.agent_radius}" f" - sparse-rewards: {self.sparse_rewards}", - f" - prey-penalty: {self.prey_penalty}", - f" - predator-rewards: {self.predator_rewards}", + f" - generator: {self._generator.__class__.__name__}", + f" - reward-fn: {self._reward_fn.__class__.__name__}", ] ) @@ -257,12 +259,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: state: Agent states. timestep: TimeStep with individual agent local environment views. 
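As a usage sketch for the updated constructor with its default generator and reward function, the values below mirror the test fixture introduced in this patch; this assumes the patched `jumanji` package is importable and is illustrative rather than canonical:

```python
import jax

from jumanji.environments.swarms.predator_prey import PredatorPrey

env = PredatorPrey(
    num_predators=2,
    num_prey=10,
    prey_vision_range=0.1,
    predator_vision_range=0.2,
    num_vision=11,
    agent_radius=0.05,
    sparse_rewards=True,
    predator_max_rotate=0.1,
    predator_max_accelerate=0.01,
    predator_min_speed=0.01,
    predator_max_speed=0.05,
    predator_view_angle=0.5,
    prey_max_rotate=0.1,
    prey_max_accelerate=0.01,
    prey_min_speed=0.01,
    prey_max_speed=0.05,
    prey_view_angle=0.5,
    max_steps=10,
)

key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)  # initial state from the generator
action = env.action_spec.generate_value()  # zero rotation/acceleration
state, timestep = jax.jit(env.step)(state, action)
```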
""" - key, predator_key, prey_key = jax.random.split(key, num=3) - predator_state = init_state( - self.num_predators, self.predator_params, predator_key - ) - prey_state = init_state(self.num_prey, self.prey_params, prey_key) - state = State(predators=predator_state, prey=prey_state, key=key) + state = self._generator(key, self.predator_params, self.prey_params) timestep = restart(observation=self._state_to_observation(state)) return state, timestep @@ -290,12 +287,7 @@ def step( state = State( predators=predators, prey=prey, key=state.key, step=state.step + 1 ) - - if self.sparse_rewards: - rewards = self._state_to_sparse_rewards(state) - else: - rewards = self._state_to_distance_rewards(state) - + rewards = self._reward_fn(state) observation = self._state_to_observation(state) timestep = jax.lax.cond( state.step >= self.max_steps, @@ -379,76 +371,6 @@ def _state_to_observation(self, state: State) -> Observation: prey=prey_obs, ) - def _state_to_sparse_rewards(self, state: State) -> Rewards: - prey_rewards = spatial( - sparse_prey_rewards, - reduction=jnp.add, - default=0.0, - include_self=False, - i_range=2 * self.agent_radius, - )( - state.key, - self.prey_penalty, - None, - None, - pos=state.prey.pos, - pos_b=state.predators.pos, - ) - predator_rewards = nearest_neighbour( - sparse_predator_rewards, - default=0.0, - i_range=2 * self.agent_radius, - )( - state.key, - self.predator_rewards, - None, - None, - pos=state.predators.pos, - pos_b=state.prey.pos, - ) - return Rewards( - predators=predator_rewards, - prey=prey_rewards, - ) - - def _state_to_distance_rewards(self, state: State) -> Rewards: - - prey_rewards = spatial( - distance_prey_rewards, - reduction=jnp.add, - default=0.0, - include_self=False, - i_range=self.prey_vision_range, - )( - state.key, - self.prey_penalty, - state.prey, - state.predators, - pos=state.prey.pos, - pos_b=state.predators.pos, - i_range=self.prey_vision_range, - ) - predator_rewards = spatial( - distance_predator_rewards, - reduction=jnp.add, - default=0.0, - include_self=False, - i_range=self.predator_vision_range, - )( - state.key, - self.predator_rewards, - state.predators, - state.prey, - pos=state.predators.pos, - pos_b=state.prey.pos, - i_range=self.prey_vision_range, - ) - - return Rewards( - predators=predator_rewards, - prey=prey_rewards, - ) - @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. @@ -514,9 +436,10 @@ def action_spec(self) -> specs.Spec[Actions]: @cached_property def reward_spec(self) -> specs.Spec[Rewards]: # type: ignore[override] - """Returns the reward spec + """Returns the reward spec. - Individual rewards for predator and prey types. + Arrays of individual rewards for both predator and + prey types. Returns: reward_spec: Predator-prey reward spec @@ -546,6 +469,26 @@ def render(self, state: State) -> None: """ self._viewer.render(state) + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive + timesteps. + interval: delay between frames in milliseconds. + save_path: the path where the animation file should be saved. If it + is None, the plot will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. 
+ """ + return self._viewer.animate(states, interval=interval, save_path=save_path) + def close(self) -> None: """Perform any necessary cleanup.""" self._viewer.close() diff --git a/jumanji/environments/swarms/predator_prey/env_test.py b/jumanji/environments/swarms/predator_prey/env_test.py index 16658b9bd..3a4bfdcdb 100644 --- a/jumanji/environments/swarms/predator_prey/env_test.py +++ b/jumanji/environments/swarms/predator_prey/env_test.py @@ -23,6 +23,10 @@ from jumanji.environments.swarms.common.types import AgentState from jumanji.environments.swarms.predator_prey import PredatorPrey +from jumanji.environments.swarms.predator_prey.rewards import ( + DistanceRewards, + SparseRewards, +) from jumanji.environments.swarms.predator_prey.types import ( Actions, Observation, @@ -35,8 +39,11 @@ ) from jumanji.types import StepType, TimeStep +PREDATOR_VISION_RANGE = 0.2 +PREY_VISION_RANGE = 0.1 PREDATOR_REWARD = 0.2 PREY_PENALTY = 0.1 +AGENT_RADIUS = 0.05 @pytest.fixture @@ -44,13 +51,11 @@ def env() -> PredatorPrey: return PredatorPrey( num_predators=2, num_prey=10, - prey_vision_range=0.1, - predator_vision_range=0.2, + prey_vision_range=PREY_VISION_RANGE, + predator_vision_range=PREDATOR_VISION_RANGE, num_vision=11, - agent_radius=0.05, + agent_radius=AGENT_RADIUS, sparse_rewards=True, - prey_penalty=PREY_PENALTY, - predator_rewards=PREDATOR_REWARD, predator_max_rotate=0.1, predator_max_accelerate=0.01, predator_min_speed=0.01, @@ -282,7 +287,6 @@ def test_view_observations( ], ) def test_sparse_rewards( - env: PredatorPrey, predator_pos: List[float], predator_reward: float, prey_pos: List[float], @@ -306,7 +310,9 @@ def test_sparse_rewards( key=jax.random.PRNGKey(101), ) - rewards = env._state_to_sparse_rewards(state) + reward_fn = SparseRewards(AGENT_RADIUS, PREDATOR_REWARD, PREY_PENALTY) + rewards = reward_fn(state) + assert isinstance(rewards, Rewards) assert rewards.predators[0] == predator_reward assert rewards.prey[0] == prey_reward @@ -321,7 +327,6 @@ def test_sparse_rewards( ], ) def test_distance_rewards( - env: PredatorPrey, predator_pos: List[float], predator_reward: float, prey_pos: List[float], @@ -345,7 +350,10 @@ def test_distance_rewards( key=jax.random.PRNGKey(101), ) - rewards = env._state_to_distance_rewards(state) + reward_fn = DistanceRewards( + PREDATOR_VISION_RANGE, PREY_VISION_RANGE, PREDATOR_REWARD, PREY_PENALTY + ) + rewards = reward_fn(state) assert isinstance(rewards, Rewards) assert jnp.isclose(rewards.predators[0], predator_reward) assert jnp.isclose(rewards.prey[0], prey_reward) @@ -372,7 +380,7 @@ def test_snake__animation(env: PredatorPrey, tmpdir: py.path.local) -> None: action = env.action_spec.generate_value() state, _ = step_fn(state, action) states.append(state) - animation = env._viewer.animate(states, 200, None) + animation = env.animate(states, interval=200, save_path=None) assert isinstance(animation, matplotlib.animation.Animation) path = str(tmpdir.join("/anim.gif")) diff --git a/jumanji/environments/swarms/predator_prey/generator.py b/jumanji/environments/swarms/predator_prey/generator.py new file mode 100644 index 000000000..8c5549f9e --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/generator.py @@ -0,0 +1,70 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
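A short rollout and animation sketch, following the pattern used in the animation test above; it assumes an existing `env: PredatorPrey` instance, for example the one constructed in the earlier sketch:

```python
import jax

key = jax.random.PRNGKey(101)
state, _ = jax.jit(env.reset)(key)

step_fn = jax.jit(env.step)
states = [state]
for _ in range(25):
    action = env.action_spec.generate_value()
    state, _ = step_fn(state, action)
    states.append(state)

# Build the animation via the new animate method (wraps the viewer).
animation = env.animate(states, interval=200, save_path=None)
# animation.save("anim.gif")  # optionally write the animation to disk
```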
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax + +from jumanji.environments.swarms.common.types import AgentParams +from jumanji.environments.swarms.common.updates import init_state +from jumanji.environments.swarms.predator_prey.types import State + + +class Generator(abc.ABC): + def __init__(self, num_predators: int, num_prey: int) -> None: + """Interface for instance generation for the `PredatorPrey` environment. + + Args: + num_predators: Number of predator agents + num_prey: Number of prey agents + """ + self.num_predators = num_predators + self.num_prey = num_prey + + @abc.abstractmethod + def __call__( + self, key: chex.PRNGKey, predator_params: AgentParams, prey_params: AgentParams + ) -> State: + """Generate initial agent positions and velocities. + + Args: + key: random key. + predator_params: Predator `AgentParams`. + prey_params: Prey `AgentParams`. + + Returns: + Initial agent `State`. + """ + + +class RandomGenerator(Generator): + def __call__( + self, key: chex.PRNGKey, predator_params: AgentParams, prey_params: AgentParams + ) -> State: + """Generate random initial agent positions and velocities. + + Args: + key: random key. + predator_params: Predator `AgentParams`. + prey_params: Prey `AgentParams`. + + Returns: + state: the generated state. + """ + key, predator_key, prey_key = jax.random.split(key, num=3) + predator_state = init_state(self.num_predators, predator_params, predator_key) + prey_state = init_state(self.num_prey, prey_params, prey_key) + state = State(predators=predator_state, prey=prey_state, key=key) + return state diff --git a/jumanji/environments/swarms/predator_prey/rewards.py b/jumanji/environments/swarms/predator_prey/rewards.py new file mode 100644 index 000000000..50f81c8b6 --- /dev/null +++ b/jumanji/environments/swarms/predator_prey/rewards.py @@ -0,0 +1,275 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Union + +import chex +import jax.numpy as jnp +from esquilax.transforms import nearest_neighbour, spatial +from esquilax.utils import shortest_distance + +from jumanji.environments.swarms.predator_prey.types import Rewards, State + + +class RewardFn(abc.ABC): + """Abstract class for `PredatorPrey` rewards.""" + + @abc.abstractmethod + def __call__(self, state: State) -> Rewards: + """The reward function used in the `PredatorPrey` environment. + + Args: + state: `PredatorPrey` state. + + Returns: + The reward for the current step for individual agents. + """ + + +class SparseRewards(RewardFn): + """Sparse rewards applied when agents come into contact. 
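As an illustration of the `Generator` interface added above, here is a hypothetical generator that places every agent at a fixed starting position rather than sampling uniformly; `FixedGenerator` is an example only and not part of this patch:

```python
import chex
import jax.numpy as jnp

from jumanji.environments.swarms.common.types import AgentParams, AgentState
from jumanji.environments.swarms.predator_prey.generator import Generator
from jumanji.environments.swarms.predator_prey.types import State


class FixedGenerator(Generator):
    """Hypothetical generator: all agents start at the centre, heading east."""

    def __call__(
        self, key: chex.PRNGKey, predator_params: AgentParams, prey_params: AgentParams
    ) -> State:
        def init(n: int, params: AgentParams) -> AgentState:
            return AgentState(
                pos=jnp.full((n, 2), 0.5),
                heading=jnp.zeros((n,)),
                speed=jnp.full((n,), params.min_speed),
            )

        return State(
            predators=init(self.num_predators, predator_params),
            prey=init(self.num_prey, prey_params),
            key=key,
        )


generator = FixedGenerator(num_predators=2, num_prey=10)
```

Such a generator can then be passed to the environment via the `generator` argument.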
+ + Rewards applied when predators and prey come into contact + (i.e. overlap), positively rewarding predators and negatively + penalising prey. Attempts to model predators `capturing` prey. + """ + + def __init__( + self, agent_radius: float, predator_reward: float, prey_penalty: float + ) -> None: + """ + Initialise a sparse reward function. + + Args: + agent_radius: Radius of simulated agents. + predator_reward: Predator reward value. + prey_penalty: Prey penalty (this is negated when applied). + """ + self.agent_radius = agent_radius + self.prey_penalty = prey_penalty + self.predator_reward = predator_reward + + def prey_rewards( + self, + _key: chex.PRNGKey, + _params: None, + _prey: None, + _predator: None, + ) -> float: + """Penalise a prey agent if contacted by a predator agent. + + Apply a negative penalty to prey agents that collide + with a prey agent. This function is applied using an + Esquilax spatial interaction which accumulates rewards. + + Args: + _key: Dummy JAX random key . + _params: Dummy params (required by Esquilax). + _prey: Dummy agent-state (required by Esquilax). + _predator: Dummy agent-state (required by Esquilax). + + Returns: + float: Negative penalty applied to prey agent. + """ + return -self.prey_penalty + + def predator_rewards( + self, + _key: chex.PRNGKey, + _params: None, + _predator: None, + _prey: None, + ) -> float: + """Reward a predator agent if it is within range of a prey agent + (required by Esquilax) + Apply a fixed positive reward if a predator agent is within + a fixed range of a prey-agent. This function can + be used with an Esquilax spatial interaction to + apply rewards to agents in range. + + Args: + _key: Dummy JAX random key (required by Esquilax). + _params: Dummy params (required by Esquilax). + _prey: Dummy agent-state (required by Esquilax). + _predator: Dummy agent-state (required by Esquilax). + + Returns: + float: Predator agent reward. + """ + return self.predator_reward + + def __call__(self, state: State) -> Rewards: + prey_rewards = spatial( + self.prey_rewards, + reduction=jnp.add, + default=0.0, + include_self=False, + i_range=2 * self.agent_radius, + )( + state.key, + None, + None, + None, + pos=state.prey.pos, + pos_b=state.predators.pos, + ) + predator_rewards = nearest_neighbour( + self.predator_rewards, + default=0.0, + i_range=2 * self.agent_radius, + )( + state.key, + None, + None, + None, + pos=state.predators.pos, + pos_b=state.prey.pos, + ) + return Rewards( + predators=predator_rewards, + prey=prey_rewards, + ) + + +class DistanceRewards(RewardFn): + """Rewards based on proximity to other agents. + + Rewards generated based on an agents proximity to other + agents within their vision range. Predator rewards increase + as they get closer to prey, and prey rewards become + increasingly negative as they get closer to predators. + Rewards are summed over all other agents within range of + an agent. + """ + + def __init__( + self, + predator_vision_range: float, + prey_vision_range: float, + predator_reward: float, + prey_penalty: float, + ) -> None: + """ + Initialise a distance reward function + + Args: + predator_vision_range: Predator agent vision range. + prey_vision_range: Prey agent vision range. + predator_reward: Max reward value applied to + predator agents. + prey_penalty: Max reward value applied to prey agents + (this value is negated when applied). 
+ """ + self.predator_vision_range = predator_vision_range + self.prey_vision_range = prey_vision_range + self.prey_penalty = prey_penalty + self.predator_reward = predator_reward + + def prey_rewards( + self, + _key: chex.PRNGKey, + _params: None, + prey_pos: chex.Array, + predator_pos: chex.Array, + *, + i_range: float, + ) -> Union[float, chex.Array]: + """Penalise a prey agent based on distance from a predator agent. + + Apply a negative penalty based on a distance between + agents. The penalty is a linear function of distance, + 0 at max distance up to `-penalty` at 0 distance. This function + can be used with an Esquilax spatial interaction to accumulate + rewards between agents. + + Args: + _key: Dummy JAX random key (required by Esquilax). + _params: Dummy params (required by Esquilax). + prey_pos: Prey positions. + predator_pos: Predator positions. + i_range: Static interaction range. + + Returns: + float: Agent rewards. + """ + d = shortest_distance(prey_pos, predator_pos) / i_range + return self.prey_penalty * (d - 1.0) + + def predator_rewards( + self, + _key: chex.PRNGKey, + _params: None, + predator_pos: chex.Array, + prey_pos: chex.Array, + *, + i_range: float, + ) -> Union[float, chex.Array]: + """Reward a predator agent based on distance from a prey agent. + + Apply a positive reward based on the linear distance between + a predator and prey agent. Rewards are zero at the max + interaction distance, and maximal at 0 range. This function + can be used with an Esquilax spatial interaction to accumulate + rewards between agents. + + Args: + _key: Dummy JAX random key (required by Esquilax). + _params: Dummy parameters (required by Esquilax). + predator_pos: Predator position. + prey_pos: Prey position. + i_range: Static interaction range. + + Returns: + float: Predator agent reward. + """ + d = shortest_distance(predator_pos, prey_pos) / i_range + return self.predator_reward * (1.0 - d) + + def __call__(self, state: State) -> Rewards: + prey_rewards = spatial( + self.prey_rewards, + reduction=jnp.add, + default=0.0, + include_self=False, + i_range=self.prey_vision_range, + )( + state.key, + None, + state.prey.pos, + state.predators.pos, + pos=state.prey.pos, + pos_b=state.predators.pos, + i_range=self.prey_vision_range, + ) + predator_rewards = spatial( + self.predator_rewards, + reduction=jnp.add, + default=0.0, + include_self=False, + i_range=self.predator_vision_range, + )( + state.key, + None, + state.predators.pos, + state.prey.pos, + pos=state.predators.pos, + pos_b=state.prey.pos, + i_range=self.prey_vision_range, + ) + + return Rewards( + predators=predator_rewards, + prey=prey_rewards, + ) diff --git a/jumanji/environments/swarms/predator_prey/types.py b/jumanji/environments/swarms/predator_prey/types.py index 3b12fd1ea..d6680a09b 100644 --- a/jumanji/environments/swarms/predator_prey/types.py +++ b/jumanji/environments/swarms/predator_prey/types.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: from dataclasses import dataclass @@ -38,17 +38,6 @@ class State: step: int = 0 -@dataclass -class Observation: - """ - predators: Local view of predator agents. - prey: Local view of prey agents. 
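For completeness, a minimal sketch of a custom reward function against the `RewardFn` interface defined above; `ZeroRewards` is a hypothetical example and not part of this patch:

```python
import jax.numpy as jnp

from jumanji.environments.swarms.predator_prey.rewards import RewardFn
from jumanji.environments.swarms.predator_prey.types import Rewards, State


class ZeroRewards(RewardFn):
    """Hypothetical reward function assigning zero reward to every agent."""

    def __call__(self, state: State) -> Rewards:
        return Rewards(
            predators=jnp.zeros(state.predators.pos.shape[0]),
            prey=jnp.zeros(state.prey.pos.shape[0]),
        )
```

A custom reward function like this can be supplied through the environment's `reward_fn` argument in place of the default `DistanceRewards`.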
- """ - - predators: chex.Array - prey: chex.Array - - @dataclass class Actions: """ @@ -56,8 +45,8 @@ class Actions: prey: Array of actions for prey agents. """ - predators: chex.Array - prey: chex.Array + predators: chex.Array # (num_predators, 2) + prey: chex.Array # (num_prey, 2) @dataclass @@ -67,5 +56,42 @@ class Rewards: prey: Array of individual rewards for prey agents. """ - predators: chex.Array - prey: chex.Array + predators: chex.Array # (num_predators,) + prey: chex.Array # (num_prey,) + + +class Observation(NamedTuple): + """ + Individual observations for predator and prey agents. + + Each agent generates an independent observation, an array of + values representing the distance along a ray from the agent to + the nearest neighbour, with each cell representing a ray angle + (with `num_vision` rays evenly distributed over the agents + field of vision). Prey and prey agent types are visualised + independently to allow agents to observe both local position and type. + + For example if a prey agent sees a predator straight ahead and + `num_vision = 5` then the observation array could be + + ``` + [1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ``` + + or if it observes another prey agent + + ``` + [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0] + ``` + + where `1.0` indicates there is no agents along that ray, + and `0.5` is the normalised distance to the other agent. + + - `predators`: jax array (float) of shape `(num_predators, 2 * num_vision)` + in the unit interval. + - `prey`: jax array (float) of shape `(num_prey, 2 * num_vision)` in the + unit interval. + """ + + predators: chex.Array # (num_predators, num_vision) + prey: chex.Array # (num_prey, num_vision) diff --git a/jumanji/environments/swarms/predator_prey/updates.py b/jumanji/environments/swarms/predator_prey/updates.py deleted file mode 100644 index 75bc536d0..000000000 --- a/jumanji/environments/swarms/predator_prey/updates.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Optional, Union - -import chex -import esquilax - -from jumanji.environments.swarms.common import types - - -def sparse_prey_rewards( - _k: chex.PRNGKey, - penalty: float, - _prey: Optional[types.AgentState], - _predator: Optional[types.AgentState], -) -> float: - """Penalise a prey agent if contacted by a predator agent. - - Apply a negative penalty to prey agents that collide - with a prey agent. This function is applied using an - Esquilax spatial interaction. - - Args: - _k: Dummy JAX random key. - penalty: Penalty value. - _prey: Optional unused prey agent-state. - _predator: Optional unused predator agent-state. - - Returns: - float: Negative penalty applied to prey agent. 
- """ - return -penalty - - -def distance_prey_rewards( - _k: chex.PRNGKey, - penalty: float, - prey: types.AgentState, - predator: types.AgentState, - *, - i_range: float, -) -> Union[float, chex.Array]: - """Penalise a prey agent based on distance from a predator agent. - - Apply a negative penalty based on a distance between - agents. The penalty is a linear function of distance, - 0 at max distance up to `-penalty` at 0 distance. This function - can be used with an Esquilax spatial interaction to accumulate - rewards between agents. - - Args: - _k: Dummy JAX random key. - penalty: Maximum penalty applied. - prey: Prey agent-state. - predator: Predator agent-state. - i_range: Static interaction range. - - Returns: - float: Agent rewards. - """ - d = esquilax.utils.shortest_distance(prey.pos, predator.pos) / i_range - return penalty * (d - 1.0) - - -def sparse_predator_rewards( - _k: chex.PRNGKey, - reward: float, - _predator: Optional[types.AgentState], - _prey: Optional[types.AgentState], -) -> float: - """Reward a predator agent if it is within range of a prey agent - - Apply a fixed positive reward if a predator agent is within - a fixed range of a prey-agent. This function can - be used with an Esquilax spatial interaction to - apply rewards to agents in range. - - Args: - _k: Dummy JAX random key. - reward: Reward value to apply. - _predator: Optional unused agent-state. - _prey: Optional unused agent-state. - - Returns: - float: Predator agent reward. - """ - return reward - - -def distance_predator_rewards( - _k: chex.PRNGKey, - reward: float, - predator: types.AgentState, - prey: types.AgentState, - *, - i_range: float, -) -> Union[float, chex.Array]: - """Reward a predator agent based on distance from a prey agent. - - Apply a positive reward based on the linear distance between - a predator and prey agent. Rewards are zero at the max - interaction distance, and maximal at 0 range. This function - can be used with an Esquilax spatial interaction to accumulate - rewards between agents. - - Args: - _k: Dummy JAX random key. - reward: Maximum reward value. - predator: Predator agent-state. - prey: Prey agent-state. - i_range: Static interaction range. - - Returns: - float@ Predator agent reward. 
- """ - d = esquilax.utils.shortest_distance(predator.pos, prey.pos) / i_range - return reward * (1.0 - d) diff --git a/mkdocs.yml b/mkdocs.yml index fc2ef282e..83ccfb049 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -79,6 +79,8 @@ nav: - Sokoban: api/environments/sokoban.md - Snake: api/environments/snake.md - TSP: api/environments/tsp.md + - Swarms: + - PredatorPrey: api/environments/predator_prey.md - Wrappers: api/wrappers.md - Types: api/types.md From b4cce01e1c6fce7ab2d0b08b9d22197cd86818f1 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Wed, 6 Nov 2024 21:27:04 +0000 Subject: [PATCH 03/19] style: Run updated pre-commit --- jumanji/environments/swarms/common/updates.py | 8 ++---- .../environments/swarms/predator_prey/env.py | 13 +++------- .../swarms/predator_prey/env_test.py | 26 +++++-------------- .../swarms/predator_prey/rewards.py | 4 +-- .../swarms/predator_prey/viewer.py | 12 +++------ 5 files changed, 16 insertions(+), 47 deletions(-) diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index de0b21138..353fdcf9e 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -70,9 +70,7 @@ def move(pos: chex.Array, heading: chex.Array, speed: chex.Array) -> chex.Array: return (pos + d_pos) % 1.0 -def init_state( - n: int, params: types.AgentParams, key: chex.PRNGKey -) -> types.AgentState: +def init_state(n: int, params: types.AgentParams, key: chex.PRNGKey) -> types.AgentState: """ Randomly initialise state of a group of agents @@ -87,9 +85,7 @@ def init_state( k1, k2, k3 = jax.random.split(key, 3) positions = jax.random.uniform(k1, (n, 2)) - speeds = jax.random.uniform( - k2, (n,), minval=params.min_speed, maxval=params.max_speed - ) + speeds = jax.random.uniform(k2, (n,), minval=params.min_speed, maxval=params.max_speed) headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jnp.pi) return types.AgentState( diff --git a/jumanji/environments/swarms/predator_prey/env.py b/jumanji/environments/swarms/predator_prey/env.py index 92ec2fa2c..00923cc8d 100644 --- a/jumanji/environments/swarms/predator_prey/env.py +++ b/jumanji/environments/swarms/predator_prey/env.py @@ -263,9 +263,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state)) return state, timestep - def step( - self, state: State, action: Actions - ) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: Actions) -> Tuple[State, TimeStep[Observation]]: """Environment update Update agent velocities and consequently their positions, @@ -279,14 +277,10 @@ def step( state: Updated agent positions and velocities. timestep: Transition timestep with individual agent local observations. 
""" - predators = update_state( - state.key, self.predator_params, state.predators, action.predators - ) + predators = update_state(state.key, self.predator_params, state.predators, action.predators) prey = update_state(state.key, self.prey_params, state.prey, action.prey) - state = State( - predators=predators, prey=prey, key=state.key, step=state.step + 1 - ) + state = State(predators=predators, prey=prey, key=state.key, step=state.step + 1) rewards = self._reward_fn(state) observation = self._state_to_observation(state) timestep = jax.lax.cond( @@ -299,7 +293,6 @@ def step( return state, timestep def _state_to_observation(self, state: State) -> Observation: - prey_obs_predators = spatial( view, reduction=jnp.minimum, diff --git a/jumanji/environments/swarms/predator_prey/env_test.py b/jumanji/environments/swarms/predator_prey/env_test.py index 3a4bfdcdb..ca6d3a706 100644 --- a/jumanji/environments/swarms/predator_prey/env_test.py +++ b/jumanji/environments/swarms/predator_prey/env_test.py @@ -114,9 +114,7 @@ def step( k, state = carry k, k_pred, k_prey = jax.random.split(k, num=3) actions = Actions( - predators=jax.random.uniform( - k_pred, (env.num_predators, 2), minval=-1.0, maxval=1.0 - ), + predators=jax.random.uniform(k_pred, (env.num_predators, 2), minval=-1.0, maxval=1.0), prey=jax.random.uniform(k_prey, (env.num_prey, 2), minval=-1.0, maxval=1.0), ) new_state, timestep = env.step(state, actions) @@ -130,9 +128,7 @@ def step( assert isinstance(state_history, State) assert state_history.predators.pos.shape == (n_steps, env.num_predators, 2) - assert jnp.all( - (0.0 <= state_history.predators.pos) & (state_history.predators.pos <= 1.0) - ) + assert jnp.all((0.0 <= state_history.predators.pos) & (state_history.predators.pos <= 1.0)) assert state_history.predators.speed.shape == (n_steps, env.num_predators) assert jnp.all( (env.predator_params.min_speed <= state_history.predators.speed) @@ -140,8 +136,7 @@ def step( ) assert state_history.predators.speed.shape == (n_steps, env.num_predators) assert jnp.all( - (0.0 <= state_history.predators.heading) - & (state_history.predators.heading <= 2.0 * jnp.pi) + (0.0 <= state_history.predators.heading) & (state_history.predators.heading <= 2.0 * jnp.pi) ) assert state_history.prey.pos.shape == (n_steps, env.num_prey, 2) @@ -153,8 +148,7 @@ def step( ) assert state_history.prey.heading.shape == (n_steps, env.num_prey) assert jnp.all( - (0.0 <= state_history.prey.heading) - & (state_history.prey.heading <= 2.0 * jnp.pi) + (0.0 <= state_history.prey.heading) & (state_history.prey.heading <= 2.0 * jnp.pi) ) @@ -170,9 +164,7 @@ def select_action(action_key: chex.PRNGKey, _state: Observation) -> Actions: predators=jax.random.uniform( predator_key, (env.num_predators, 2), minval=-1.0, maxval=1.0 ), - prey=jax.random.uniform( - prey_key, (env.num_prey, 2), minval=-1.0, maxval=1.0 - ), + prey=jax.random.uniform(prey_key, (env.num_prey, 2), minval=-1.0, maxval=1.0), ) check_env_does_not_smoke(env, select_action=select_action) @@ -244,9 +236,7 @@ def test_view_observations( prey_speed = jnp.zeros(prey_heading.shape) state = State( - predators=AgentState( - pos=predator_pos, heading=predator_heading, speed=predator_speed - ), + predators=AgentState(pos=predator_pos, heading=predator_heading, speed=predator_speed), prey=AgentState(pos=prey_pos, heading=prey_heading, speed=prey_speed), key=jax.random.PRNGKey(101), ) @@ -359,9 +349,7 @@ def test_distance_rewards( assert jnp.isclose(rewards.prey[0], prey_reward) -def test_predator_prey_render( - 
monkeypatch: pytest.MonkeyPatch, env: PredatorPrey -) -> None: +def test_predator_prey_render(monkeypatch: pytest.MonkeyPatch, env: PredatorPrey) -> None: """Check that the render method builds the figure but does not display it.""" monkeypatch.setattr(plt, "show", lambda fig: None) step_fn = jax.jit(env.step) diff --git a/jumanji/environments/swarms/predator_prey/rewards.py b/jumanji/environments/swarms/predator_prey/rewards.py index 50f81c8b6..c51ad9b97 100644 --- a/jumanji/environments/swarms/predator_prey/rewards.py +++ b/jumanji/environments/swarms/predator_prey/rewards.py @@ -46,9 +46,7 @@ class SparseRewards(RewardFn): penalising prey. Attempts to model predators `capturing` prey. """ - def __init__( - self, agent_radius: float, predator_reward: float, prey_penalty: float - ) -> None: + def __init__(self, agent_radius: float, predator_reward: float, prey_penalty: float) -> None: """ Initialise a sparse reward function. diff --git a/jumanji/environments/swarms/predator_prey/viewer.py b/jumanji/environments/swarms/predator_prey/viewer.py index b94750a90..3b3422275 100644 --- a/jumanji/environments/swarms/predator_prey/viewer.py +++ b/jumanji/environments/swarms/predator_prey/viewer.py @@ -74,9 +74,7 @@ def animate( """ if not states: raise ValueError(f"The states argument has to be non-empty, got {states}.") - fig, ax = plt.subplots( - num=f"{self._figure_name}Anim", figsize=self._figure_size - ) + fig, ax = plt.subplots(num=f"{self._figure_name}Anim", figsize=self._figure_size) fig, ax = format_plot(fig, ax) predators_quiver = draw_agents(ax, states[0].predators, self.predator_color) @@ -89,9 +87,7 @@ def make_frame(state: State) -> Any: jnp.cos(state.predators.heading), jnp.sin(state.predators.heading) ) prey_quiver.set_offsets(state.prey.pos) - prey_quiver.set_UVC( - jnp.cos(state.prey.heading), jnp.sin(state.prey.heading) - ) + prey_quiver.set_UVC(jnp.cos(state.prey.heading), jnp.sin(state.prey.heading)) return ((predators_quiver, prey_quiver),) matplotlib.rc("animation", html="jshtml") @@ -128,9 +124,7 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: ax = fig.get_axes()[0] else: fig = plt.figure(self._figure_name, figsize=self._figure_size) - fig.set_layout_engine( - layout=TightLayoutEngine(pad=False, w_pad=0.0, h_pad=0.0) - ) + fig.set_layout_engine(layout=TightLayoutEngine(pad=False, w_pad=0.0, h_pad=0.0)) if not plt.isinteractive(): fig.show() ax = fig.add_subplot() From cb6d88d3dbc546623364758541f4852cfe226b03 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:52:52 +0000 Subject: [PATCH 04/19] refactor: Consolidate predator prey type --- .../environments/swarms/predator_prey/env.py | 17 +++++++------- .../swarms/predator_prey/env_test.py | 13 +++++------ .../swarms/predator_prey/rewards.py | 12 +++++----- .../swarms/predator_prey/types.py | 22 ++++++------------- 4 files changed, 27 insertions(+), 37 deletions(-) diff --git a/jumanji/environments/swarms/predator_prey/env.py b/jumanji/environments/swarms/predator_prey/env.py index 00923cc8d..390d1da94 100644 --- a/jumanji/environments/swarms/predator_prey/env.py +++ b/jumanji/environments/swarms/predator_prey/env.py @@ -31,9 +31,8 @@ ) from jumanji.environments.swarms.predator_prey.rewards import DistanceRewards, RewardFn from jumanji.environments.swarms.predator_prey.types import ( - Actions, Observation, - Rewards, + PredatorPreyStruct, State, ) from jumanji.environments.swarms.predator_prey.viewer import PredatorPreyViewer @@ -61,7 +60,7 @@ 
class PredatorPrey(Environment): - predators: jax array (float) of shape (num_predators, 2 * num_vision) - prey: jax array (float) of shape (num_prey, 2 * num_vision) - - action: `Actions` + - action: `PredatorPreyStruct` Arrays of individual agent actions. Each agents actions rotate and accelerate/decelerate the agent as [rotation, acceleration] on the range [-1, 1]. These values are then scaled to update agent velocities within @@ -70,7 +69,7 @@ class PredatorPrey(Environment): - predators: jax array (float) of shape (num_predators, 2) - prey: jax array (float) of shape (num_prey, 2) - - reward: `Rewards` + - reward: `PredatorPreyStruct` Arrays of individual agent rewards. Rewards generally depend on proximity to other agents, and so can vary dependent on density and agent radius and vision ranges. @@ -263,7 +262,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep = restart(observation=self._state_to_observation(state)) return state, timestep - def step(self, state: State, action: Actions) -> Tuple[State, TimeStep[Observation]]: + def step(self, state: State, action: PredatorPreyStruct) -> Tuple[State, TimeStep[Observation]]: """Environment update Update agent velocities and consequently their positions, @@ -396,7 +395,7 @@ def observation_spec(self) -> specs.Spec[Observation]: ) @cached_property - def action_spec(self) -> specs.Spec[Actions]: + def action_spec(self) -> specs.Spec[PredatorPreyStruct]: """Returns the action spec. Arrays of individual agent actions. Each agents action is @@ -421,14 +420,14 @@ def action_spec(self) -> specs.Spec[Actions]: name="prey", ) return specs.Spec( - Actions, + PredatorPreyStruct, "ActionSpec", predators=predators, prey=prey, ) @cached_property - def reward_spec(self) -> specs.Spec[Rewards]: # type: ignore[override] + def reward_spec(self) -> specs.Spec[PredatorPreyStruct]: # type: ignore[override] """Returns the reward spec. 
Arrays of individual rewards for both predator and @@ -448,7 +447,7 @@ def reward_spec(self) -> specs.Spec[Rewards]: # type: ignore[override] name="prey", ) return specs.Spec( - Rewards, + PredatorPreyStruct, "rewardsSpec", predators=predators, prey=prey, diff --git a/jumanji/environments/swarms/predator_prey/env_test.py b/jumanji/environments/swarms/predator_prey/env_test.py index ca6d3a706..a75c9eaec 100644 --- a/jumanji/environments/swarms/predator_prey/env_test.py +++ b/jumanji/environments/swarms/predator_prey/env_test.py @@ -28,9 +28,8 @@ SparseRewards, ) from jumanji.environments.swarms.predator_prey.types import ( - Actions, Observation, - Rewards, + PredatorPreyStruct, State, ) from jumanji.testing.env_not_smoke import ( @@ -113,7 +112,7 @@ def step( ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: k, state = carry k, k_pred, k_prey = jax.random.split(k, num=3) - actions = Actions( + actions = PredatorPreyStruct( predators=jax.random.uniform(k_pred, (env.num_predators, 2), minval=-1.0, maxval=1.0), prey=jax.random.uniform(k_prey, (env.num_prey, 2), minval=-1.0, maxval=1.0), ) @@ -158,9 +157,9 @@ def test_env_does_not_smoke(env: PredatorPrey, sparse_rewards: bool) -> None: env.sparse_rewards = sparse_rewards env.max_steps = 10 - def select_action(action_key: chex.PRNGKey, _state: Observation) -> Actions: + def select_action(action_key: chex.PRNGKey, _state: Observation) -> PredatorPreyStruct: predator_key, prey_key = jax.random.split(action_key) - return Actions( + return PredatorPreyStruct( predators=jax.random.uniform( predator_key, (env.num_predators, 2), minval=-1.0, maxval=1.0 ), @@ -303,7 +302,7 @@ def test_sparse_rewards( reward_fn = SparseRewards(AGENT_RADIUS, PREDATOR_REWARD, PREY_PENALTY) rewards = reward_fn(state) - assert isinstance(rewards, Rewards) + assert isinstance(rewards, PredatorPreyStruct) assert rewards.predators[0] == predator_reward assert rewards.prey[0] == prey_reward @@ -344,7 +343,7 @@ def test_distance_rewards( PREDATOR_VISION_RANGE, PREY_VISION_RANGE, PREDATOR_REWARD, PREY_PENALTY ) rewards = reward_fn(state) - assert isinstance(rewards, Rewards) + assert isinstance(rewards, PredatorPreyStruct) assert jnp.isclose(rewards.predators[0], predator_reward) assert jnp.isclose(rewards.prey[0], prey_reward) diff --git a/jumanji/environments/swarms/predator_prey/rewards.py b/jumanji/environments/swarms/predator_prey/rewards.py index c51ad9b97..bb489df79 100644 --- a/jumanji/environments/swarms/predator_prey/rewards.py +++ b/jumanji/environments/swarms/predator_prey/rewards.py @@ -20,14 +20,14 @@ from esquilax.transforms import nearest_neighbour, spatial from esquilax.utils import shortest_distance -from jumanji.environments.swarms.predator_prey.types import Rewards, State +from jumanji.environments.swarms.predator_prey.types import PredatorPreyStruct, State class RewardFn(abc.ABC): """Abstract class for `PredatorPrey` rewards.""" @abc.abstractmethod - def __call__(self, state: State) -> Rewards: + def __call__(self, state: State) -> PredatorPreyStruct: """The reward function used in the `PredatorPrey` environment. 
Args: @@ -108,7 +108,7 @@ def predator_rewards( """ return self.predator_reward - def __call__(self, state: State) -> Rewards: + def __call__(self, state: State) -> PredatorPreyStruct: prey_rewards = spatial( self.prey_rewards, reduction=jnp.add, @@ -135,7 +135,7 @@ def __call__(self, state: State) -> Rewards: pos=state.predators.pos, pos_b=state.prey.pos, ) - return Rewards( + return PredatorPreyStruct( predators=predator_rewards, prey=prey_rewards, ) @@ -235,7 +235,7 @@ def predator_rewards( d = shortest_distance(predator_pos, prey_pos) / i_range return self.predator_reward * (1.0 - d) - def __call__(self, state: State) -> Rewards: + def __call__(self, state: State) -> PredatorPreyStruct: prey_rewards = spatial( self.prey_rewards, reduction=jnp.add, @@ -267,7 +267,7 @@ def __call__(self, state: State) -> Rewards: i_range=self.prey_vision_range, ) - return Rewards( + return PredatorPreyStruct( predators=predator_rewards, prey=prey_rewards, ) diff --git a/jumanji/environments/swarms/predator_prey/types.py b/jumanji/environments/swarms/predator_prey/types.py index d6680a09b..c5b06e35b 100644 --- a/jumanji/environments/swarms/predator_prey/types.py +++ b/jumanji/environments/swarms/predator_prey/types.py @@ -39,25 +39,17 @@ class State: @dataclass -class Actions: +class PredatorPreyStruct: """ - predators: Array of actions for predator agents. - prey: Array of actions for prey agents. - """ - - predators: chex.Array # (num_predators, 2) - prey: chex.Array # (num_prey, 2) - + General struct for predator prey structured data + (e.g. rewards or observations) -@dataclass -class Rewards: - """ - predators: Array of individual rewards for predator agents. - prey: Array of individual rewards for prey agents. + predators: Array of data per predator agent. + prey: Array of data per prey agents. """ - predators: chex.Array # (num_predators,) - prey: chex.Array # (num_prey,) + predators: chex.Array # (num_predators, ...) + prey: chex.Array # (num_prey, ...) class Observation(NamedTuple): From 06de3a067b646ac2e0e959b917a62a87be8293ab Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:18:47 +0000 Subject: [PATCH 05/19] feat: Implement search and rescue (#3) * feat: Prototype search and rescue environment * test: Add additional tests * docs: Update docs * refactor: Update target plot color based on status * refactor: Formatting and fix remaining typos. 
--- ...{predator_prey.md => search_and_rescue.md} | 2 +- docs/environments/predator_prey.md | 42 -- docs/environments/search_and_rescue.md | 54 ++ jumanji/environments/__init__.py | 2 +- .../environments/swarms/predator_prey/env.py | 486 ------------------ .../swarms/predator_prey/env_test.py | 374 -------------- .../swarms/predator_prey/rewards.py | 273 ---------- .../__init__.py | 2 +- .../swarms/search_and_rescue/dynamics.py | 61 +++ .../swarms/search_and_rescue/env.py | 429 ++++++++++++++++ .../swarms/search_and_rescue/env_test.py | 270 ++++++++++ .../generator.py | 45 +- .../types.py | 54 +- .../viewer.py | 46 +- mkdocs.yml | 4 +- 15 files changed, 887 insertions(+), 1257 deletions(-) rename docs/api/environments/{predator_prey.md => search_and_rescue.md} (72%) delete mode 100644 docs/environments/predator_prey.md create mode 100644 docs/environments/search_and_rescue.md delete mode 100644 jumanji/environments/swarms/predator_prey/env.py delete mode 100644 jumanji/environments/swarms/predator_prey/env_test.py delete mode 100644 jumanji/environments/swarms/predator_prey/rewards.py rename jumanji/environments/swarms/{predator_prey => search_and_rescue}/__init__.py (94%) create mode 100644 jumanji/environments/swarms/search_and_rescue/dynamics.py create mode 100644 jumanji/environments/swarms/search_and_rescue/env.py create mode 100644 jumanji/environments/swarms/search_and_rescue/env_test.py rename jumanji/environments/swarms/{predator_prey => search_and_rescue}/generator.py (50%) rename jumanji/environments/swarms/{predator_prey => search_and_rescue}/types.py (55%) rename jumanji/environments/swarms/{predator_prey => search_and_rescue}/viewer.py (75%) diff --git a/docs/api/environments/predator_prey.md b/docs/api/environments/search_and_rescue.md similarity index 72% rename from docs/api/environments/predator_prey.md rename to docs/api/environments/search_and_rescue.md index 52bf4e6e9..0748af9c2 100644 --- a/docs/api/environments/predator_prey.md +++ b/docs/api/environments/search_and_rescue.md @@ -1,4 +1,4 @@ -::: jumanji.environments.swarms.predator_prey.env.PredatorPrey +::: jumanji.environments.swarms.search_and_rescue.env.SearchAndRescue selection: members: - __init__ diff --git a/docs/environments/predator_prey.md b/docs/environments/predator_prey.md deleted file mode 100644 index 0a53ed927..000000000 --- a/docs/environments/predator_prey.md +++ /dev/null @@ -1,42 +0,0 @@ -# Predator-Prey Flock Environment - -[//]: # (TODO: Add animated plot) - -Environment modelling two competing flocks/swarms of agents: - -- Predator agents are rewarded for contacting prey agents, or for proximity to prey agents. -- Prey agents are conversely penalised for being contacted by, or for proximity to predators. - -Each set of agents can consist of multiple agents, each independently -updated, and with their own independent observations. The agents occupy a square -space with periodic boundary conditions. Agents have a limited view range, i.e. they -only partially observe their local environment (and the locations of neighbouring agents within -range). Rewards are also assigned individually to each agent dependent on their local state. - -## Observation - -Each agent generates an independent observation, an array of values -representing the distance along a ray from the agent to the nearest neighbour, with -each cell representing a ray angle (with `num_vision` rays evenly distributed over the agents -field of vision). 
Prey and prey agent types are visualised independently to allow agents -to observe both local position and type. - -- `predators`: jax array (float) of shape `(num_predators, 2 * num_vision)` in the unit interval. -- `prey`: jax array (float) of shape `(num_prey, 2 * num_vision)` in the unit interval. - -## Action - -Agents can update their velocity each step by rotating and accelerating/decelerating. Values -are clipped to the range `[-1, 1]` and then scaled by max rotation and acceleration -parameters. Agents are restricted to velocities within a fixed range of speeds. - -- `predators`: jax array (float) of shape (num_predators, 2) each corresponding to `[rotation, acceleration]`. -- `prey`: jax array (float) of shape (num_prey, 2) each corresponding to `[rotation, acceleration]`. - -## Reward - -Rewards are generated for each agent individually. They are generally dependent on proximity, so -their scale can depend on agent density and interaction ranges. - -- `predators`: jax array (float) of shape `(num_predators,)`, individual predator agent rewards. -- `prey`: jax array (float) of shape `(num_prey,)`, individual prey rewards. diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md new file mode 100644 index 000000000..a124b4470 --- /dev/null +++ b/docs/environments/search_and_rescue.md @@ -0,0 +1,54 @@ +# 🚁 Search & Rescue + +[//]: # (TODO: Add animated plot) + +Multi-agent environment, modelling a group of agents searching the environment +for multiple targets. Agents are individually rewarded for finding a target +that has not previously been detected. + +Each agent visualises a local region around it, creating a simple segmented view +of locations of other agents in the vicinity. The environment is updated in the +following sequence: + +- The velocity of searching agents are updated, and consequently their positions. +- The positions of targets are updated. +- Agents are rewarded for being within a fixed range of targets, and the target + being within its view cone. +- Targets within detection range and an agents view cone are marked as found. +- Local views of the environment are generated for each search agent. + +The agents are allotted a fixed number of steps to locate the targets. The search +space is a uniform space with unit dimensions, and wrapped at the boundaries. + +## Observations + +- `searcher_views`: jax array (float) of shape `(num_searchers, num_vision)`. Each agent + generates an independent observation, an array of values representing the distance + along a ray from the agent to the nearest neighbour, with each cell representing a + ray angle (with `num_vision` rays evenly distributed over the agents field of vision). + For example if an agent sees another agent straight ahead and `num_vision = 5` then + the observation array could be + + ``` + [1.0, 1.0, 0.5, 1.0, 1.0] + ``` + + where `1.0` indicates there is no agents along that ray, and `0.5` is the normalised + distance to the other agent. +- `target_remaining`: float in the range [0, 1]. The normalised number of targets + remaining to be detected (i.e. 1.0 when no targets have been found). +- `time_remaining`: float in the range [0, 1]. The normalised number of steps remaining + to locate the targets (i.e. 0.0 at the end of the episode). + +## Actions + +Jax array (float) of `(num_searchers, 2)` in the range [-1, 1]. Each entry in the +array represents an update of each agents velocity in the next step. 
Searching agents +update their velocity each step by rotating and accelerating/decelerating. Values +are clipped to the range `[-1, 1]` and then scaled by max rotation and acceleration +parameters. Agents are restricted to velocities within a fixed range of speeds. + +## Rewards + +Jax array (float) of `(num_searchers, 2)`. Rewards are generated for each agent individually. +Agents are rewarded 1.0 for locating a target that has not already been detected. diff --git a/jumanji/environments/__init__.py b/jumanji/environments/__init__.py index 3a4c4d58e..cb2518826 100644 --- a/jumanji/environments/__init__.py +++ b/jumanji/environments/__init__.py @@ -59,7 +59,7 @@ from jumanji.environments.routing.snake.env import Snake from jumanji.environments.routing.sokoban.env import Sokoban from jumanji.environments.routing.tsp.env import TSP -from jumanji.environments.swarms.predator_prey import PredatorPrey +from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue def is_colab() -> bool: diff --git a/jumanji/environments/swarms/predator_prey/env.py b/jumanji/environments/swarms/predator_prey/env.py deleted file mode 100644 index 390d1da94..000000000 --- a/jumanji/environments/swarms/predator_prey/env.py +++ /dev/null @@ -1,486 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from functools import cached_property -from typing import Optional, Sequence, Tuple - -import chex -import jax -import jax.numpy as jnp -from esquilax.transforms import spatial -from matplotlib.animation import FuncAnimation - -from jumanji import specs -from jumanji.env import Environment -from jumanji.environments.swarms.common.types import AgentParams -from jumanji.environments.swarms.common.updates import update_state, view -from jumanji.environments.swarms.predator_prey.generator import ( - Generator, - RandomGenerator, -) -from jumanji.environments.swarms.predator_prey.rewards import DistanceRewards, RewardFn -from jumanji.environments.swarms.predator_prey.types import ( - Observation, - PredatorPreyStruct, - State, -) -from jumanji.environments.swarms.predator_prey.viewer import PredatorPreyViewer -from jumanji.types import TimeStep, restart, termination, transition -from jumanji.viewer import Viewer - - -class PredatorPrey(Environment): - """A predator and prey flock environment - - Environment modelling two swarms of agent types, predators - who are rewarded for avoiding pre agents, and conversely - prey agent who are rewarded for touching/catching - prey agents. Both agent types can consist of a large - number of individual agents, each with individual (local) - observations, actions, and rewards. Agents interact - on a uniform space with wrapped boundaries. - - - observation: `Observation` - Arrays representing each agent's local view of the environment. - Each cell of the array represent the distance from the agent - two the nearest other agents in the environment. Each agent type - is observed independently. 
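Relating this back to the `searcher_views` observation described in the new search-and-rescue documentation above, a small illustrative snippet of reading a single agent's view array (the values are assumed for illustration, not produced by the environment):

```python
import jax.numpy as jnp

# A single searcher's view with num_vision = 5 rays: 1.0 means nothing was
# detected along that ray, smaller values are normalised distances.
view = jnp.array([1.0, 1.0, 0.5, 1.0, 1.0])

detected = view < 1.0             # boolean mask of rays that saw another agent
nearest_distance = jnp.min(view)  # 0.5: the closest detection, straight ahead
```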
- - - predators: jax array (float) of shape (num_predators, 2 * num_vision) - - prey: jax array (float) of shape (num_prey, 2 * num_vision) - - - action: `PredatorPreyStruct` - Arrays of individual agent actions. Each agents actions rotate and - accelerate/decelerate the agent as [rotation, acceleration] on the range - [-1, 1]. These values are then scaled to update agent velocities within - given parameters. - - - predators: jax array (float) of shape (num_predators, 2) - - prey: jax array (float) of shape (num_prey, 2) - - - reward: `PredatorPreyStruct` - Arrays of individual agent rewards. Rewards generally depend on - proximity to other agents, and so can vary dependent on - density and agent radius and vision ranges. - - - predators: jax array (float) of shape (num_predators,) - - prey: jax array (float) of shape (num_prey,) - - - state: `State` - - predators: `AgentState` - - pos: jax array (float) of shape (num_predators, 2) in the range [0, 1]. - - heading: jax array (float) of shape (num_predators,) in - the range [0, 2pi]. - - speed: jax array (float) of shape (num_predators,) in the - range [min_speed, max_speed]. - - prey: `AgentState` - - pos: jax array (float) of shape (num_prey, 2) in the range [0, 1]. - - heading: jax array (float) of shape (num_prey,) in - the range [0, 2pi]. - - speed: jax array (float) of shape (num_prey,) in the - range [min_speed, max_speed]. - - key: jax array (uint32) of shape (2,) - - step: int representing the current simulation step. - - - ```python - from jumanji.environments import PredatorPrey - env = PredatorPrey( - num_predators=2, - num_prey=10, - prey_vision_range=0.1, - predator_vision_range=0.1, - num_vision=10, - agent_radius=0.01, - sparse_rewards=True, - prey_penalty=0.1, - predator_rewards=0.2, - predator_max_rotate=0.1, - predator_max_accelerate=0.01, - predator_min_speed=0.01, - predator_max_speed=0.05, - predator_view_angle=0.5, - prey_max_rotate=0.1, - prey_max_accelerate=0.01, - prey_min_speed=0.01, - prey_max_speed=0.05, - prey_view_angle=0.5, - ) - key = jax.random.PRNGKey(0) - state, timestep = jax.jit(env.reset)(key) - env.render(state) - action = env.action_spec.generate_value() - state, timestep = jax.jit(env.step)(state, action) - env.render(state) - ``` - """ - - def __init__( - self, - num_predators: int, - num_prey: int, - prey_vision_range: float, - predator_vision_range: float, - num_vision: int, - agent_radius: float, - sparse_rewards: bool, - predator_max_rotate: float, - predator_max_accelerate: float, - predator_min_speed: float, - predator_max_speed: float, - predator_view_angle: float, - prey_max_rotate: float, - prey_max_accelerate: float, - prey_min_speed: float, - prey_max_speed: float, - prey_view_angle: float, - max_steps: int = 10_000, - viewer: Optional[Viewer[State]] = None, - generator: Optional[Generator] = None, - reward_fn: Optional[RewardFn] = None, - ) -> None: - """Instantiates a `PredatorPrey` environment - - Note: - The environment is square with dimensions - `[1.0, 1.0]` so parameters should be scaled - appropriately. Also note that performance is - dependent on agent vision and interaction ranges, - where larger values can lead to large number of - agent interactions. - - Args: - num_predators: Number of predator agents. - num_prey: Number of prey agents. - prey_vision_range: Prey agent vision range. - predator_vision_range: Predator agent vision range. - num_vision: Number of cells/subdivisions in agent - view models. 
Larger numbers provide a more accurate - view, at the cost of the environment, at the cost - of performance and memory usage. - agent_radius: Radius of individual agents. This - effects both agent collision range and how - large they appear to other agents. - sparse_rewards: If `True` fix rewards will be applied - when agents are within a fixed collision range. If - `False` rewards are dependent on distance to - other agents with vision range. - predator_max_rotate: Maximum rotation predator agents can - turn within a step. Should be a value from [0,1] - representing a fraction of pi radians. - predator_max_accelerate: Maximum acceleration/deceleration - a predator agent can apply within a step. - predator_min_speed: Minimum speed a predator agent can move at. - predator_max_speed: Maximum speed a predator agent can move at. - predator_view_angle: Predator agent local view angle. Should be - a value from [0,1] representing a fraction of pi radians. - The view cone of an agent goes from +- of the view angle - relative to its heading. - prey_max_rotate: Maximum rotation prey agents can - turn within a step. Should be a value from [0,1] - representing a fraction of pi radians. - prey_max_accelerate: Maximum acceleration/deceleration - a prey agent can apply within a step. - prey_min_speed: Minimum speed a prey agent can move at. - prey_max_speed: Maximum speed a prey agent can move at. - prey_view_angle: Prey agent local view angle. Should be - a value from [0,1] representing a fraction of pi radians. - The view cone of an agent goes from +- of the view angle - relative to its heading. - max_steps: Maximum number of environment steps before termination - viewer: `Viewer` used for rendering. Defaults to `PredatorPreyViewer`. - generator: Initial state generator. Defaults to `RandomGenerator`. - reward_fn: Reward function. Defaults to `DistanceRewards`. - """ - self.num_predators = num_predators - self.num_prey = num_prey - self.prey_vision_range = prey_vision_range - self.predator_vision_range = predator_vision_range - self.num_vision = num_vision - self.agent_radius = agent_radius - self.sparse_rewards = sparse_rewards - self.predator_params = AgentParams( - max_rotate=predator_max_rotate, - max_accelerate=predator_max_accelerate, - min_speed=predator_min_speed, - max_speed=predator_max_speed, - view_angle=predator_view_angle, - ) - self.prey_params = AgentParams( - max_rotate=prey_max_rotate, - max_accelerate=prey_max_accelerate, - min_speed=prey_min_speed, - max_speed=prey_max_speed, - view_angle=prey_view_angle, - ) - self.max_steps = max_steps - super().__init__() - self._viewer = viewer or PredatorPreyViewer() - self._generator = generator or RandomGenerator(num_predators, num_prey) - self._reward_fn = reward_fn or DistanceRewards( - predator_vision_range, prey_vision_range, 1.0, 1.0 - ) - - def __repr__(self) -> str: - return "\n".join( - [ - "Predator-prey flock environment:", - f" - num predators: {self.num_predators}", - f" - num prey: {self.num_prey}", - f" - prey vision range: {self.prey_vision_range}", - f" - predator vision range: {self.predator_vision_range}" - f" - num vision: {self.num_vision}" - f" - agent radius: {self.agent_radius}" - f" - sparse-rewards: {self.sparse_rewards}", - f" - generator: {self._generator.__class__.__name__}", - f" - reward-fn: {self._reward_fn.__class__.__name__}", - ] - ) - - def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: - """Randomly initialise predator and prey positions and velocities. 
- - Args: - key: Random key used to reset the environment. - - Returns: - state: Agent states. - timestep: TimeStep with individual agent local environment views. - """ - state = self._generator(key, self.predator_params, self.prey_params) - timestep = restart(observation=self._state_to_observation(state)) - return state, timestep - - def step(self, state: State, action: PredatorPreyStruct) -> Tuple[State, TimeStep[Observation]]: - """Environment update - - Update agent velocities and consequently their positions, - them generate new local views and rewards. - - Args: - state: Agent states. - action: Arrays of predator and prey individual actions. - - Returns: - state: Updated agent positions and velocities. - timestep: Transition timestep with individual agent local observations. - """ - predators = update_state(state.key, self.predator_params, state.predators, action.predators) - prey = update_state(state.key, self.prey_params, state.prey, action.prey) - - state = State(predators=predators, prey=prey, key=state.key, step=state.step + 1) - rewards = self._reward_fn(state) - observation = self._state_to_observation(state) - timestep = jax.lax.cond( - state.step >= self.max_steps, - termination, - transition, - rewards, - observation, - ) - return state, timestep - - def _state_to_observation(self, state: State) -> Observation: - prey_obs_predators = spatial( - view, - reduction=jnp.minimum, - default=jnp.ones((self.num_vision,)), - include_self=False, - i_range=self.prey_vision_range, - )( - state.key, - (self.prey_params.view_angle, self.agent_radius), - state.prey, - state.predators, - pos=state.prey.pos, - pos_b=state.predators.pos, - n_view=self.num_vision, - i_range=self.prey_vision_range, - ) - prey_obs_prey = spatial( - view, - reduction=jnp.minimum, - default=jnp.ones((self.num_vision,)), - include_self=False, - i_range=self.prey_vision_range, - )( - state.key, - (self.predator_params.view_angle, self.agent_radius), - state.prey, - state.prey, - pos=state.prey.pos, - n_view=self.num_vision, - i_range=self.prey_vision_range, - ) - predator_obs_prey = spatial( - view, - reduction=jnp.minimum, - default=jnp.ones((self.num_vision,)), - include_self=False, - i_range=self.predator_vision_range, - )( - state.key, - (self.predator_params.view_angle, self.agent_radius), - state.predators, - state.prey, - pos=state.predators.pos, - pos_b=state.prey.pos, - n_view=self.num_vision, - i_range=self.predator_vision_range, - ) - predator_obs_predator = spatial( - view, - reduction=jnp.minimum, - default=jnp.ones((self.num_vision,)), - include_self=False, - i_range=self.predator_vision_range, - )( - state.key, - (self.predator_params.view_angle, self.agent_radius), - state.predators, - state.predators, - pos=state.predators.pos, - n_view=self.num_vision, - i_range=self.predator_vision_range, - ) - - predator_obs = jnp.hstack([predator_obs_prey, predator_obs_predator]) - prey_obs = jnp.hstack([prey_obs_predators, prey_obs_prey]) - - return Observation( - predators=predator_obs, - prey=prey_obs, - ) - - @cached_property - def observation_spec(self) -> specs.Spec[Observation]: - """Returns the observation spec. - - Local predator and prey agent views representing - the distance to closest neighbours in the environment. 
- - Returns: - observation_spec: Predator-prey observation spec - """ - predators = specs.BoundedArray( - shape=(self.num_predators, 2 * self.num_vision), - minimum=0.0, - maximum=1.0, - dtype=float, - name="predators", - ) - prey = specs.BoundedArray( - shape=(self.num_prey, 2 * self.num_vision), - minimum=0.0, - maximum=1.0, - dtype=float, - name="prey", - ) - return specs.Spec( - Observation, - "ObservationSpec", - predators=predators, - prey=prey, - ) - - @cached_property - def action_spec(self) -> specs.Spec[PredatorPreyStruct]: - """Returns the action spec. - - Arrays of individual agent actions. Each agents action is - an array representing [rotation, acceleration] in the range - [-1, 1]. - - Returns: - action_spec: Predator-prey action spec - """ - predators = specs.BoundedArray( - shape=(self.num_predators, 2), - minimum=-1.0, - maximum=1.0, - dtype=float, - name="predators", - ) - prey = specs.BoundedArray( - shape=(self.num_prey, 2), - minimum=-1.0, - maximum=1.0, - dtype=float, - name="prey", - ) - return specs.Spec( - PredatorPreyStruct, - "ActionSpec", - predators=predators, - prey=prey, - ) - - @cached_property - def reward_spec(self) -> specs.Spec[PredatorPreyStruct]: # type: ignore[override] - """Returns the reward spec. - - Arrays of individual rewards for both predator and - prey types. - - Returns: - reward_spec: Predator-prey reward spec - """ - predators = specs.Array( - shape=(self.num_predators,), - dtype=float, - name="predators", - ) - prey = specs.Array( - shape=(self.num_prey,), - dtype=float, - name="prey", - ) - return specs.Spec( - PredatorPreyStruct, - "rewardsSpec", - predators=predators, - prey=prey, - ) - - def render(self, state: State) -> None: - """Render a frame of the environment for a given state using matplotlib. - - Args: - state: State object containing the current dynamics of the environment. - """ - self._viewer.render(state) - - def animate( - self, - states: Sequence[State], - interval: int = 200, - save_path: Optional[str] = None, - ) -> FuncAnimation: - """Create an animation from a sequence of environment states. - - Args: - states: sequence of environment states corresponding to consecutive - timesteps. - interval: delay between frames in milliseconds. - save_path: the path where the animation file should be saved. If it - is None, the plot will not be saved. - - Returns: - Animation that can be saved as a GIF, MP4, or rendered with HTML. - """ - return self._viewer.animate(states, interval=interval, save_path=save_path) - - def close(self) -> None: - """Perform any necessary cleanup.""" - self._viewer.close() diff --git a/jumanji/environments/swarms/predator_prey/env_test.py b/jumanji/environments/swarms/predator_prey/env_test.py deleted file mode 100644 index a75c9eaec..000000000 --- a/jumanji/environments/swarms/predator_prey/env_test.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from typing import List, Tuple - -import chex -import jax -import jax.numpy as jnp -import matplotlib -import matplotlib.pyplot as plt -import py -import pytest - -from jumanji.environments.swarms.common.types import AgentState -from jumanji.environments.swarms.predator_prey import PredatorPrey -from jumanji.environments.swarms.predator_prey.rewards import ( - DistanceRewards, - SparseRewards, -) -from jumanji.environments.swarms.predator_prey.types import ( - Observation, - PredatorPreyStruct, - State, -) -from jumanji.testing.env_not_smoke import ( - check_env_does_not_smoke, - check_env_specs_does_not_smoke, -) -from jumanji.types import StepType, TimeStep - -PREDATOR_VISION_RANGE = 0.2 -PREY_VISION_RANGE = 0.1 -PREDATOR_REWARD = 0.2 -PREY_PENALTY = 0.1 -AGENT_RADIUS = 0.05 - - -@pytest.fixture -def env() -> PredatorPrey: - return PredatorPrey( - num_predators=2, - num_prey=10, - prey_vision_range=PREY_VISION_RANGE, - predator_vision_range=PREDATOR_VISION_RANGE, - num_vision=11, - agent_radius=AGENT_RADIUS, - sparse_rewards=True, - predator_max_rotate=0.1, - predator_max_accelerate=0.01, - predator_min_speed=0.01, - predator_max_speed=0.05, - predator_view_angle=0.5, - prey_max_rotate=0.1, - prey_max_accelerate=0.01, - prey_min_speed=0.01, - prey_max_speed=0.05, - prey_view_angle=0.5, - ) - - -def test_env_init(env: PredatorPrey) -> None: - """ - Check newly initialised state has expected array shapes - and initial timestep. - """ - k = jax.random.PRNGKey(101) - state, timestep = env.reset(k) - assert isinstance(state, State) - - assert isinstance(state.predators, AgentState) - assert state.predators.pos.shape == (env.num_predators, 2) - assert state.predators.speed.shape == (env.num_predators,) - assert state.predators.speed.shape == (env.num_predators,) - - assert isinstance(state.prey, AgentState) - assert state.prey.pos.shape == (env.num_prey, 2) - assert state.prey.speed.shape == (env.num_prey,) - assert state.prey.speed.shape == (env.num_prey,) - - assert isinstance(timestep.observation, Observation) - assert timestep.observation.predators.shape == ( - env.num_predators, - 2 * env.num_vision, - ) - assert timestep.observation.prey.shape == (env.num_prey, 2 * env.num_vision) - assert timestep.step_type == StepType.FIRST - - -@pytest.mark.parametrize("sparse_rewards", [True, False]) -def test_env_step(env: PredatorPrey, sparse_rewards: bool) -> None: - """ - Run several steps of the environment with random actions and - check states (i.e. positions, heading, speeds) all fall - inside expected ranges. 
- """ - env.sparse_rewards = sparse_rewards - key = jax.random.PRNGKey(101) - n_steps = 22 - - def step( - carry: Tuple[chex.PRNGKey, State], _: None - ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: - k, state = carry - k, k_pred, k_prey = jax.random.split(k, num=3) - actions = PredatorPreyStruct( - predators=jax.random.uniform(k_pred, (env.num_predators, 2), minval=-1.0, maxval=1.0), - prey=jax.random.uniform(k_prey, (env.num_prey, 2), minval=-1.0, maxval=1.0), - ) - new_state, timestep = env.step(state, actions) - return (k, new_state), (state, timestep) - - init_state, _ = env.reset(key) - (_, final_state), (state_history, timesteps) = jax.lax.scan( - step, (key, init_state), length=n_steps - ) - - assert isinstance(state_history, State) - - assert state_history.predators.pos.shape == (n_steps, env.num_predators, 2) - assert jnp.all((0.0 <= state_history.predators.pos) & (state_history.predators.pos <= 1.0)) - assert state_history.predators.speed.shape == (n_steps, env.num_predators) - assert jnp.all( - (env.predator_params.min_speed <= state_history.predators.speed) - & (state_history.predators.speed <= env.predator_params.max_speed) - ) - assert state_history.predators.speed.shape == (n_steps, env.num_predators) - assert jnp.all( - (0.0 <= state_history.predators.heading) & (state_history.predators.heading <= 2.0 * jnp.pi) - ) - - assert state_history.prey.pos.shape == (n_steps, env.num_prey, 2) - assert jnp.all((0.0 <= state_history.prey.pos) & (state_history.prey.pos <= 1.0)) - assert state_history.prey.speed.shape == (n_steps, env.num_prey) - assert jnp.all( - (env.predator_params.min_speed <= state_history.prey.speed) - & (state_history.prey.speed <= env.predator_params.max_speed) - ) - assert state_history.prey.heading.shape == (n_steps, env.num_prey) - assert jnp.all( - (0.0 <= state_history.prey.heading) & (state_history.prey.heading <= 2.0 * jnp.pi) - ) - - -@pytest.mark.parametrize("sparse_rewards", [True, False]) -def test_env_does_not_smoke(env: PredatorPrey, sparse_rewards: bool) -> None: - """Test that we can run an episode without any errors.""" - env.sparse_rewards = sparse_rewards - env.max_steps = 10 - - def select_action(action_key: chex.PRNGKey, _state: Observation) -> PredatorPreyStruct: - predator_key, prey_key = jax.random.split(action_key) - return PredatorPreyStruct( - predators=jax.random.uniform( - predator_key, (env.num_predators, 2), minval=-1.0, maxval=1.0 - ), - prey=jax.random.uniform(prey_key, (env.num_prey, 2), minval=-1.0, maxval=1.0), - ) - - check_env_does_not_smoke(env, select_action=select_action) - - -def test_env_specs_do_not_smoke(env: PredatorPrey) -> None: - """Test that we can access specs without any errors.""" - check_env_specs_does_not_smoke(env) - - -@pytest.mark.parametrize( - "predator_pos, predator_heading, predator_view, prey_pos, prey_heading, prey_view", - [ - # Both out of view range - ([[0.8, 0.5]], [jnp.pi], [(0, 0, 1.0)], [[0.2, 0.5]], [0.0], [(0, 0, 1.0)]), - # In predator range but not prey - ([[0.35, 0.5]], [jnp.pi], [(0, 5, 0.75)], [[0.2, 0.5]], [0.0], [(0, 0, 1.0)]), - # Both view each other - ([[0.25, 0.5]], [jnp.pi], [(0, 5, 0.25)], [[0.2, 0.5]], [0.0], [(0, 5, 0.5)]), - # Prey facing wrong direction - ( - [[0.25, 0.5]], - [jnp.pi], - [(0, 5, 0.25)], - [[0.2, 0.5]], - [jnp.pi], - [(0, 0, 1.0)], - ), - # Prey sees closest predator - ( - [[0.35, 0.5], [0.25, 0.5]], - [jnp.pi, jnp.pi], - [(0, 5, 0.75), (0, 16, 0.5), (1, 5, 0.25)], - [[0.2, 0.5]], - [0.0], - [(0, 5, 0.5)], - ), - # Observed 
around wrapped edge - ( - [[0.025, 0.5]], - [jnp.pi], - [(0, 5, 0.25)], - [[0.975, 0.5]], - [0.0], - [(0, 5, 0.5)], - ), - ], -) -def test_view_observations( - env: PredatorPrey, - predator_pos: List[List[float]], - predator_heading: List[float], - predator_view: List[Tuple[int, int, float]], - prey_pos: List[List[float]], - prey_heading: List[float], - prey_view: List[Tuple[int, int, float]], -) -> None: - """ - Test view model generates expected array with different - configurations of agents. - """ - - predator_pos = jnp.array(predator_pos) - predator_heading = jnp.array(predator_heading) - predator_speed = jnp.zeros(predator_heading.shape) - - prey_pos = jnp.array(prey_pos) - prey_heading = jnp.array(prey_heading) - prey_speed = jnp.zeros(prey_heading.shape) - - state = State( - predators=AgentState(pos=predator_pos, heading=predator_heading, speed=predator_speed), - prey=AgentState(pos=prey_pos, heading=prey_heading, speed=prey_speed), - key=jax.random.PRNGKey(101), - ) - - obs = env._state_to_observation(state) - - assert isinstance(obs, Observation) - - predator_expected = jnp.ones( - ( - predator_heading.shape[0], - 2 * env.num_vision, - ) - ) - for i, idx, val in predator_view: - predator_expected = predator_expected.at[i, idx].set(val) - - assert jnp.all(jnp.isclose(obs.predators, predator_expected)) - - prey_expected = jnp.ones( - ( - prey_heading.shape[0], - 2 * env.num_vision, - ) - ) - for i, idx, val in prey_view: - prey_expected = prey_expected.at[i, idx].set(val) - - assert jnp.all(jnp.isclose(obs.prey[0], prey_expected)) - - -@pytest.mark.parametrize( - "predator_pos, predator_reward, prey_pos, prey_reward", - [ - ([0.5, 0.5], 0.0, [0.8, 0.5], 0.0), - ([0.5, 0.5], PREDATOR_REWARD, [0.5999, 0.5], -PREY_PENALTY), - ([0.5, 0.5], PREDATOR_REWARD, [0.5001, 0.5], -PREY_PENALTY), - ], -) -def test_sparse_rewards( - predator_pos: List[float], - predator_reward: float, - prey_pos: List[float], - prey_reward: float, -) -> None: - """ - Test sparse rewards are correctly assigned. - """ - - state = State( - predators=AgentState( - pos=jnp.array([predator_pos]), - heading=jnp.zeros((1,)), - speed=jnp.zeros((1,)), - ), - prey=AgentState( - pos=jnp.array([prey_pos]), - heading=jnp.zeros((1,)), - speed=jnp.zeros((1,)), - ), - key=jax.random.PRNGKey(101), - ) - - reward_fn = SparseRewards(AGENT_RADIUS, PREDATOR_REWARD, PREY_PENALTY) - rewards = reward_fn(state) - - assert isinstance(rewards, PredatorPreyStruct) - assert rewards.predators[0] == predator_reward - assert rewards.prey[0] == prey_reward - - -@pytest.mark.parametrize( - "predator_pos, predator_reward, prey_pos, prey_reward", - [ - ([0.5, 0.5], 0.0, [0.8, 0.5], 0.0), - ([0.5, 0.5], 0.5 * PREDATOR_REWARD, [0.55, 0.5], -0.5 * PREY_PENALTY), - ([0.5, 0.5], PREDATOR_REWARD, [0.5 + 1e-10, 0.5], -PREY_PENALTY), - ], -) -def test_distance_rewards( - predator_pos: List[float], - predator_reward: float, - prey_pos: List[float], - prey_reward: float, -) -> None: - """ - Test rewards scaled with distance are correctly assigned. 
- """ - - state = State( - predators=AgentState( - pos=jnp.array([predator_pos]), - heading=jnp.zeros((1,)), - speed=jnp.zeros((1,)), - ), - prey=AgentState( - pos=jnp.array([prey_pos]), - heading=jnp.zeros((1,)), - speed=jnp.zeros((1,)), - ), - key=jax.random.PRNGKey(101), - ) - - reward_fn = DistanceRewards( - PREDATOR_VISION_RANGE, PREY_VISION_RANGE, PREDATOR_REWARD, PREY_PENALTY - ) - rewards = reward_fn(state) - assert isinstance(rewards, PredatorPreyStruct) - assert jnp.isclose(rewards.predators[0], predator_reward) - assert jnp.isclose(rewards.prey[0], prey_reward) - - -def test_predator_prey_render(monkeypatch: pytest.MonkeyPatch, env: PredatorPrey) -> None: - """Check that the render method builds the figure but does not display it.""" - monkeypatch.setattr(plt, "show", lambda fig: None) - step_fn = jax.jit(env.step) - state, timestep = env.reset(jax.random.PRNGKey(0)) - action = env.action_spec.generate_value() - state, timestep = step_fn(state, action) - env.render(state) - env.close() - - -def test_snake__animation(env: PredatorPrey, tmpdir: py.path.local) -> None: - """Check that the animation method creates the animation correctly and can save to a gif.""" - step_fn = jax.jit(env.step) - state, _ = env.reset(jax.random.PRNGKey(0)) - states = [state] - action = env.action_spec.generate_value() - state, _ = step_fn(state, action) - states.append(state) - animation = env.animate(states, interval=200, save_path=None) - assert isinstance(animation, matplotlib.animation.Animation) - - path = str(tmpdir.join("/anim.gif")) - animation.save(path, writer=matplotlib.animation.PillowWriter(fps=10), dpi=60) diff --git a/jumanji/environments/swarms/predator_prey/rewards.py b/jumanji/environments/swarms/predator_prey/rewards.py deleted file mode 100644 index bb489df79..000000000 --- a/jumanji/environments/swarms/predator_prey/rewards.py +++ /dev/null @@ -1,273 +0,0 @@ -# Copyright 2022 InstaDeep Ltd. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from typing import Union - -import chex -import jax.numpy as jnp -from esquilax.transforms import nearest_neighbour, spatial -from esquilax.utils import shortest_distance - -from jumanji.environments.swarms.predator_prey.types import PredatorPreyStruct, State - - -class RewardFn(abc.ABC): - """Abstract class for `PredatorPrey` rewards.""" - - @abc.abstractmethod - def __call__(self, state: State) -> PredatorPreyStruct: - """The reward function used in the `PredatorPrey` environment. - - Args: - state: `PredatorPrey` state. - - Returns: - The reward for the current step for individual agents. - """ - - -class SparseRewards(RewardFn): - """Sparse rewards applied when agents come into contact. - - Rewards applied when predators and prey come into contact - (i.e. overlap), positively rewarding predators and negatively - penalising prey. Attempts to model predators `capturing` prey. 
- """ - - def __init__(self, agent_radius: float, predator_reward: float, prey_penalty: float) -> None: - """ - Initialise a sparse reward function. - - Args: - agent_radius: Radius of simulated agents. - predator_reward: Predator reward value. - prey_penalty: Prey penalty (this is negated when applied). - """ - self.agent_radius = agent_radius - self.prey_penalty = prey_penalty - self.predator_reward = predator_reward - - def prey_rewards( - self, - _key: chex.PRNGKey, - _params: None, - _prey: None, - _predator: None, - ) -> float: - """Penalise a prey agent if contacted by a predator agent. - - Apply a negative penalty to prey agents that collide - with a prey agent. This function is applied using an - Esquilax spatial interaction which accumulates rewards. - - Args: - _key: Dummy JAX random key . - _params: Dummy params (required by Esquilax). - _prey: Dummy agent-state (required by Esquilax). - _predator: Dummy agent-state (required by Esquilax). - - Returns: - float: Negative penalty applied to prey agent. - """ - return -self.prey_penalty - - def predator_rewards( - self, - _key: chex.PRNGKey, - _params: None, - _predator: None, - _prey: None, - ) -> float: - """Reward a predator agent if it is within range of a prey agent - (required by Esquilax) - Apply a fixed positive reward if a predator agent is within - a fixed range of a prey-agent. This function can - be used with an Esquilax spatial interaction to - apply rewards to agents in range. - - Args: - _key: Dummy JAX random key (required by Esquilax). - _params: Dummy params (required by Esquilax). - _prey: Dummy agent-state (required by Esquilax). - _predator: Dummy agent-state (required by Esquilax). - - Returns: - float: Predator agent reward. - """ - return self.predator_reward - - def __call__(self, state: State) -> PredatorPreyStruct: - prey_rewards = spatial( - self.prey_rewards, - reduction=jnp.add, - default=0.0, - include_self=False, - i_range=2 * self.agent_radius, - )( - state.key, - None, - None, - None, - pos=state.prey.pos, - pos_b=state.predators.pos, - ) - predator_rewards = nearest_neighbour( - self.predator_rewards, - default=0.0, - i_range=2 * self.agent_radius, - )( - state.key, - None, - None, - None, - pos=state.predators.pos, - pos_b=state.prey.pos, - ) - return PredatorPreyStruct( - predators=predator_rewards, - prey=prey_rewards, - ) - - -class DistanceRewards(RewardFn): - """Rewards based on proximity to other agents. - - Rewards generated based on an agents proximity to other - agents within their vision range. Predator rewards increase - as they get closer to prey, and prey rewards become - increasingly negative as they get closer to predators. - Rewards are summed over all other agents within range of - an agent. - """ - - def __init__( - self, - predator_vision_range: float, - prey_vision_range: float, - predator_reward: float, - prey_penalty: float, - ) -> None: - """ - Initialise a distance reward function - - Args: - predator_vision_range: Predator agent vision range. - prey_vision_range: Prey agent vision range. - predator_reward: Max reward value applied to - predator agents. - prey_penalty: Max reward value applied to prey agents - (this value is negated when applied). 
- """ - self.predator_vision_range = predator_vision_range - self.prey_vision_range = prey_vision_range - self.prey_penalty = prey_penalty - self.predator_reward = predator_reward - - def prey_rewards( - self, - _key: chex.PRNGKey, - _params: None, - prey_pos: chex.Array, - predator_pos: chex.Array, - *, - i_range: float, - ) -> Union[float, chex.Array]: - """Penalise a prey agent based on distance from a predator agent. - - Apply a negative penalty based on a distance between - agents. The penalty is a linear function of distance, - 0 at max distance up to `-penalty` at 0 distance. This function - can be used with an Esquilax spatial interaction to accumulate - rewards between agents. - - Args: - _key: Dummy JAX random key (required by Esquilax). - _params: Dummy params (required by Esquilax). - prey_pos: Prey positions. - predator_pos: Predator positions. - i_range: Static interaction range. - - Returns: - float: Agent rewards. - """ - d = shortest_distance(prey_pos, predator_pos) / i_range - return self.prey_penalty * (d - 1.0) - - def predator_rewards( - self, - _key: chex.PRNGKey, - _params: None, - predator_pos: chex.Array, - prey_pos: chex.Array, - *, - i_range: float, - ) -> Union[float, chex.Array]: - """Reward a predator agent based on distance from a prey agent. - - Apply a positive reward based on the linear distance between - a predator and prey agent. Rewards are zero at the max - interaction distance, and maximal at 0 range. This function - can be used with an Esquilax spatial interaction to accumulate - rewards between agents. - - Args: - _key: Dummy JAX random key (required by Esquilax). - _params: Dummy parameters (required by Esquilax). - predator_pos: Predator position. - prey_pos: Prey position. - i_range: Static interaction range. - - Returns: - float: Predator agent reward. - """ - d = shortest_distance(predator_pos, prey_pos) / i_range - return self.predator_reward * (1.0 - d) - - def __call__(self, state: State) -> PredatorPreyStruct: - prey_rewards = spatial( - self.prey_rewards, - reduction=jnp.add, - default=0.0, - include_self=False, - i_range=self.prey_vision_range, - )( - state.key, - None, - state.prey.pos, - state.predators.pos, - pos=state.prey.pos, - pos_b=state.predators.pos, - i_range=self.prey_vision_range, - ) - predator_rewards = spatial( - self.predator_rewards, - reduction=jnp.add, - default=0.0, - include_self=False, - i_range=self.predator_vision_range, - )( - state.key, - None, - state.predators.pos, - state.prey.pos, - pos=state.predators.pos, - pos_b=state.prey.pos, - i_range=self.prey_vision_range, - ) - - return PredatorPreyStruct( - predators=predator_rewards, - prey=prey_rewards, - ) diff --git a/jumanji/environments/swarms/predator_prey/__init__.py b/jumanji/environments/swarms/search_and_rescue/__init__.py similarity index 94% rename from jumanji/environments/swarms/predator_prey/__init__.py rename to jumanji/environments/swarms/search_and_rescue/__init__.py index 4fef030a7..f4959771a 100644 --- a/jumanji/environments/swarms/predator_prey/__init__.py +++ b/jumanji/environments/swarms/search_and_rescue/__init__.py @@ -11,4 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from .env import PredatorPrey +from .env import SearchAndRescue diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py new file mode 100644 index 000000000..d9abe5b19 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py @@ -0,0 +1,61 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax + + +class TargetDynamics(abc.ABC): + @abc.abstractmethod + def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array: + """Interface for target position update function. + + Args: + key: random key. + target_pos: Current target positions. + + Returns: + Updated target positions. + """ + + +class RandomWalk(TargetDynamics): + def __init__(self, step_size: float): + """ + Random walk target dynamics. + + Target positions are updated with random + steps, sampled uniformly from the range + [-step-size, step-size]. + + Args: + step_size: Maximum random step-size + """ + self.step_size = step_size + + def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array: + """Update target positions. + + Args: + key: random key. + target_pos: Current target positions. + + Returns: + Updated target positions. + """ + d_pos = jax.random.uniform(key, target_pos.shape) + d_pos = self.step_size * 2.0 * (d_pos - 0.5) + return target_pos + d_pos diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py new file mode 100644 index 000000000..a1fa15498 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -0,0 +1,429 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
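The `TargetDynamics` interface defined in `dynamics.py` above only requires a callable that maps a random key and the current target positions to updated positions, so alternative dynamics can be dropped in alongside `RandomWalk`. As a rough sketch (not part of this PR), a hypothetical `DriftingWalk` could add a constant drift to the random step; the class name and the `drift` parameter are illustrative assumptions:

```python
import chex
import jax

from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics


class DriftingWalk(TargetDynamics):
    """Hypothetical target dynamics: a random walk plus a constant drift."""

    def __init__(self, step_size: float, drift: chex.Array):
        self.step_size = step_size
        self.drift = drift

    def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array:
        # Uniform random step in [-step_size, step_size] per coordinate.
        d_pos = self.step_size * 2.0 * (jax.random.uniform(key, target_pos.shape) - 0.5)
        # No wrapping needed here: the environment's step function wraps
        # target positions back onto the unit square.
        return target_pos + d_pos + self.drift
```

An instance could then be passed to the environment via the `target_dynamics` argument, e.g. `SearchAndRescue(..., target_dynamics=DriftingWalk(0.01, jnp.array([1e-3, 0.0])))`.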
+ +from functools import cached_property +from typing import Optional, Sequence, Tuple + +import chex +import jax +import jax.numpy as jnp +from esquilax.transforms import spatial +from esquilax.utils import shortest_vector +from matplotlib.animation import FuncAnimation + +from jumanji import specs +from jumanji.env import Environment +from jumanji.environments.swarms.common.types import AgentParams, AgentState +from jumanji.environments.swarms.common.updates import update_state, view +from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics +from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator +from jumanji.environments.swarms.search_and_rescue.types import Observation, State, TargetState +from jumanji.environments.swarms.search_and_rescue.viewer import SearchAndRescueViewer +from jumanji.types import TimeStep, restart, termination, transition +from jumanji.viewer import Viewer + + +def _inner_found_check( + _key: chex.PRNGKey, + searcher_view_angle: float, + target_pos: chex.Array, + searcher: AgentState, +) -> chex.Array: + """ + Return True if searcher can view the target. + """ + dx = shortest_vector(searcher.pos, target_pos) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) + return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + + +def _inner_reward_check( + _key: chex.PRNGKey, + searcher_view_angle: float, + searcher: AgentState, + target: TargetState, +) -> chex.Array: + """ + Return +1.0 reward if the target is within the searcher view angle. + """ + dx = shortest_vector(searcher.pos, target.pos) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) + can_see = (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + return jax.lax.cond( + # (not target.found) & can_see, + can_see, + lambda: 1.0, + lambda: 0.0, + ) + + +class SearchAndRescue(Environment): + """A multi-agent search environment + + Environment modelling a collection of agents collectively searching + for a set of targets on a 2d environment. Agents are rewarded + (individually) for coming within a fixed range of a target that has + not already been detected. Agents visualise their local environment + (i.e. the location of other agents) via a simple segmented view model. + The environment consists of a uniform space with wrapped boundaries. + + - observation: `Observation` + searcher_views: jax array (float) of shape (num_searchers, num_vision) + individual local views of positions of other searching agents. + target_remaining: (float) Number of targets remaining to be found from + the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates + all the targets are still to be found). + time_remaining: (float) Steps remaining to find agents, scaled to the + range [0,1] (i.e. the value is 0 when time runs out). + + - action: jax array (float) of shape (num_searchers, 2) + Array of individual agent actions. Each agents actions rotate and + accelerate/decelerate the agent as [rotation, acceleration] on the range + [-1, 1]. These values are then scaled to update agent velocities within + given parameters. + + - reward: jax array (float) of shape (num_searchers,) + Arrays of individual agent rewards. Rewards are granted when an agent + comes into contact range with a target that has not yet been found, and + that agent is within the searchers view cone. 
+ + - state: `State` + - searchers: `AgentState` + - pos: jax array (float) of shape (num_searchers, 2) in the range [0, 1]. + - heading: jax array (float) of shape (num_searcher,) in + the range [0, 2pi]. + - speed: jax array (float) of shape (num_searchers,) in the + range [min_speed, max_speed]. + - targets: `TargetState` + - pos: jax array (float) of shape (num_targets, 2) in the range [0, 1]. + - found: jax array (bool) of shape (num_targets,) flag indicating if + target has been located by an agent. + - key: jax array (uint32) of shape (2,) + - step: int representing the current simulation step. + + + ```python + from jumanji.environments import SearchAndRescue + env = SearchAndRescue( + num_searchers=10, + num_targets=20, + searcher_vision_range=0.1, + target_contact_range=0.01, + num_vision=40, + agent_radius0.01, + searcher_max_rotate=0.1, + searcher_max_accelerate=0.01, + searcher_min_speed=0.01, + searcher_max_speed=0.05, + searcher_view_angle=0.5, + ) + key = jax.random.PRNGKey(0) + state, timestep = jax.jit(env.reset)(key) + env.render(state) + action = env.action_spec.generate_value() + state, timestep = jax.jit(env.step)(state, action) + env.render(state) + ``` + """ + + def __init__( + self, + num_searchers: int, + num_targets: int, + searcher_vision_range: float, + target_contact_range: float, + num_vision: int, + agent_radius: float, + searcher_max_rotate: float, + searcher_max_accelerate: float, + searcher_min_speed: float, + searcher_max_speed: float, + searcher_view_angle: float, + max_steps: int = 400, + viewer: Optional[Viewer[State]] = None, + target_dynamics: Optional[TargetDynamics] = None, + generator: Optional[Generator] = None, + ) -> None: + """Instantiates a `SearchAndRescue` environment + + Note: + The environment is square with dimensions + `[1.0, 1.0]` so parameters should be scaled + appropriately. Also note that performance is + dependent on agent vision and interaction ranges, + where larger values can lead to large number of + agent interactions. + + Args: + num_searchers: Number of searching agents. + num_targets: Number of search targets. + searcher_vision_range: Search agent vision range. + target_contact_range: Range at which a searcher can 'find' a target. + num_vision: Number of cells/subdivisions in agent + view models. Larger numbers provide a more accurate + view, at the cost of the environment, at the cost + of performance and memory usage. + agent_radius: Radius of individual agents. This + effects how large they appear to other agents. + searcher_max_rotate: Maximum rotation searcher agents can + turn within a step. Should be a value from [0,1] + representing a fraction of pi radians. + searcher_max_accelerate: Maximum acceleration/deceleration + a searcher agent can apply within a step. + searcher_min_speed: Minimum speed a searcher agent can move at. + searcher_max_speed: Maximum speed a searcher agent can move at. + searcher_view_angle: Predator agent local view angle. Should be + a value from [0,1] representing a fraction of pi radians. + The view cone of an agent goes from +- of the view angle + relative to its heading. + max_steps: Maximum number of environment steps allowed for search. + viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`. + target_dynamics: + target_dynamics: Target object dynamics model, implemented as a + `TargetDynamics` interface. Defaults to `RandomWalk`. + generator: Initial state `Generator` instance. Defaults to `RandomGenerator`. 
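        As a rough numeric illustration of the angle parameterisation described
        above (the values here are examples only, not defaults):

        ```python
        import jax.numpy as jnp

        searcher_view_angle = 0.5  # example: view cone spans +/- 0.5 * pi around the heading
        searcher_max_rotate = 0.1  # example: at most 0.1 * pi radians of turn per step

        half_cone = searcher_view_angle * jnp.pi  # ~1.571 rad, i.e. a 180 degree total field of view
        max_turn = searcher_max_rotate * jnp.pi   # ~0.314 rad (18 degrees) of heading change per step
        ```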
+ """ + self.num_searchers = num_searchers + self.num_targets = num_targets + self.searcher_vision_range = searcher_vision_range + self.target_contact_range = target_contact_range + self.num_vision = num_vision + self.agent_radius = agent_radius + self.searcher_params = AgentParams( + max_rotate=searcher_max_rotate, + max_accelerate=searcher_max_accelerate, + min_speed=searcher_min_speed, + max_speed=searcher_max_speed, + view_angle=searcher_view_angle, + ) + self.max_steps = max_steps + super().__init__() + self._viewer = viewer or SearchAndRescueViewer() + self._target_dynamics = target_dynamics or RandomWalk(0.01) + self._generator = generator or RandomGenerator(num_searchers, num_targets) + + def __repr__(self) -> str: + return "\n".join( + [ + "Search & rescue multi-agent environment:", + f" - num searchers: {self.num_searchers}", + f" - num targets: {self.num_targets}", + f" - search vision range: {self.searcher_vision_range}", + f" - target contact range: {self.target_contact_range}", + f" - num vision: {self.num_vision}", + f" - agent radius: {self.agent_radius}", + f" - max steps: {self.max_steps}," + f" - target dynamics: {self._target_dynamics.__class__.__name__}", + f" - generator: {self._generator.__class__.__name__}", + ] + ) + + def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: + """Initialise searcher positions and velocities, and target positions. + + Args: + key: Random key used to reset the environment. + + Returns: + state: Initial environment state. + timestep: TimeStep with individual search agent views. + """ + state = self._generator(key, self.searcher_params) + timestep = restart(observation=self._state_to_observation(state)) + return state, timestep + + def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Observation]]: + """Environment update. + + Update searcher velocities and consequently their positions, + mark found targets, and generate rewards and local observations. + + Args: + state: Environment state. + actions: Arrays of searcher steering actions. + + Returns: + state: Updated searcher and target positions and velocities. + timestep: Transition timestep with individual agent local observations. 
+ """ + searchers = update_state(state.key, self.searcher_params, state.searchers, actions) + key, target_key = jax.random.split(state.key, num=2) + # Ensure target positions are wrapped + target_pos = self._target_dynamics(target_key, state.targets.pos) % 1.0 + # Grant searchers rewards if in range and not already detected + rewards = spatial( + _inner_reward_check, + reduction=jnp.add, + default=0.0, + i_range=self.target_contact_range, + )( + key, + self.searcher_params.view_angle, + searchers, + state.targets, + pos=searchers.pos, + pos_b=target_pos, + ) + # Mark targets as found if with contact range and view angle of a searcher + targets_found = spatial( + _inner_found_check, + reduction=jnp.logical_or, + default=False, + i_range=self.target_contact_range, + )( + key, + self.searcher_params.view_angle, + state.targets.pos, + searchers, + pos=target_pos, + pos_b=searchers.pos, + ) + # Targets need to remain found if they already have been + targets_found = jnp.logical_or(targets_found, state.targets.found) + state = State( + searchers=searchers, + targets=TargetState(pos=target_pos, found=targets_found), + key=key, + step=state.step + 1, + ) + observation = self._state_to_observation(state) + timestep = jax.lax.cond( + state.step >= self.max_steps, + termination, + transition, + rewards, + observation, + ) + return state, timestep + + def _state_to_observation(self, state: State) -> Observation: + searcher_views = spatial( + view, + reduction=jnp.minimum, + default=jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.searcher_vision_range, + )( + state.key, + (self.searcher_params.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.searcher_vision_range, + ) + + return Observation( + searcher_views=searcher_views, + target_remaining=1.0 - jnp.sum(state.targets.found) / self.num_targets, + time_remaining=1.0 - state.step / (self.max_steps + 1), + ) + + @cached_property + def observation_spec(self) -> specs.Spec[Observation]: + """Returns the observation spec. + + Local searcher agent views representing + the distance to the closest neighbouring agents in the + environment. + + Returns: + observation_spec: Search-and-rescue observation spec + """ + searcher_views = specs.BoundedArray( + shape=(self.num_searchers, self.num_vision), + minimum=0.0, + maximum=1.0, + dtype=float, + name="searcher_views", + ) + return specs.Spec( + Observation, + "ObservationSpec", + searcher_views=searcher_views, + target_remaining=specs.BoundedArray( + shape=(), minimum=0.0, maximum=1.0, name="target_remaining", dtype=float + ), + time_remaining=specs.BoundedArray( + shape=(), minimum=0.0, maximum=1.0, name="time_remaining", dtype=float + ), + ) + + @cached_property + def action_spec(self) -> specs.BoundedArray: + """Returns the action spec. + + 2d array of individual agent actions. Each agents action is + an array representing [rotation, acceleration] in the range + [-1, 1]. + + Returns: + action_spec: Action array spec + """ + return specs.BoundedArray( + shape=(self.num_searchers, 2), + minimum=-1.0, + maximum=1.0, + dtype=float, + ) + + @cached_property + def reward_spec(self) -> specs.BoundedArray: + """Returns the reward spec. + + Array of individual rewards for each agent. + + Returns: + reward_spec: Reward array spec. 
+ """ + return specs.BoundedArray( + shape=(self.num_searchers,), + minimum=0.0, + maximum=float(self.num_targets), + dtype=float, + ) + + def render(self, state: State) -> None: + """Render a frame of the environment for a given state using matplotlib. + + Args: + state: State object containing the current dynamics of the environment. + """ + self._viewer.render(state) + + def animate( + self, + states: Sequence[State], + interval: int = 200, + save_path: Optional[str] = None, + ) -> FuncAnimation: + """Create an animation from a sequence of environment states. + + Args: + states: sequence of environment states corresponding to consecutive + timesteps. + interval: delay between frames in milliseconds. + save_path: the path where the animation file should be saved. If it + is None, the plot will not be saved. + + Returns: + Animation that can be saved as a GIF, MP4, or rendered with HTML. + """ + return self._viewer.animate(states, interval=interval, save_path=save_path) + + def close(self) -> None: + """Perform any necessary cleanup.""" + self._viewer.close() diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py new file mode 100644 index 000000000..8f0de699f --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -0,0 +1,270 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Tuple + +import chex +import jax +import jax.numpy as jnp +import matplotlib +import matplotlib.pyplot as plt +import py +import pytest + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk +from jumanji.environments.swarms.search_and_rescue.types import ( + Observation, + State, + TargetState, +) +from jumanji.testing.env_not_smoke import check_env_does_not_smoke, check_env_specs_does_not_smoke +from jumanji.types import StepType, TimeStep + +SEARCHER_VISION_RANGE = 0.2 +TARGET_CONTACT_RANGE = 0.05 +AGENT_RADIUS = 0.05 + + +@pytest.fixture +def env() -> SearchAndRescue: + return SearchAndRescue( + num_searchers=10, + num_targets=20, + searcher_vision_range=SEARCHER_VISION_RANGE, + target_contact_range=TARGET_CONTACT_RANGE, + num_vision=11, + agent_radius=AGENT_RADIUS, + searcher_max_rotate=0.2, + searcher_max_accelerate=0.01, + searcher_min_speed=0.01, + searcher_max_speed=0.05, + searcher_view_angle=0.5, + max_steps=25, + ) + + +def test_env_init(env: SearchAndRescue) -> None: + """ + Check newly initialised state has expected array shapes + and initial timestep. 
+ """ + k = jax.random.PRNGKey(101) + state, timestep = env.reset(k) + assert isinstance(state, State) + + assert isinstance(state.searchers, AgentState) + assert state.searchers.pos.shape == (env.num_searchers, 2) + assert state.searchers.speed.shape == (env.num_searchers,) + assert state.searchers.speed.shape == (env.num_searchers,) + + assert isinstance(state.targets, TargetState) + assert state.targets.pos.shape == (env.num_targets, 2) + assert state.targets.found.shape == (env.num_targets,) + assert jnp.array_equal(state.targets.found, jnp.full((env.num_targets,), False, dtype=bool)) + assert state.step == 0 + + assert isinstance(timestep.observation, Observation) + assert timestep.observation.searcher_views.shape == ( + env.num_searchers, + env.num_vision, + ) + assert timestep.step_type == StepType.FIRST + + +def test_env_step(env: SearchAndRescue) -> None: + """ + Run several steps of the environment with random actions and + check states (i.e. positions, heading, speeds) all fall + inside expected ranges. + """ + key = jax.random.PRNGKey(101) + n_steps = 22 + + def step( + carry: Tuple[chex.PRNGKey, State], _: None + ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: + k, state = carry + k, k_search = jax.random.split(k) + actions = jax.random.uniform(k_search, (env.num_searchers, 2), minval=-1.0, maxval=1.0) + new_state, timestep = env.step(state, actions) + return (k, new_state), (state, timestep) + + init_state, _ = env.reset(key) + (_, final_state), (state_history, timesteps) = jax.lax.scan( + step, (key, init_state), length=n_steps + ) + + assert isinstance(state_history, State) + + assert state_history.searchers.pos.shape == (n_steps, env.num_searchers, 2) + assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= 1.0)) + assert state_history.searchers.speed.shape == (n_steps, env.num_searchers) + assert jnp.all( + (env.searcher_params.min_speed <= state_history.searchers.speed) + & (state_history.searchers.speed <= env.searcher_params.max_speed) + ) + assert state_history.searchers.speed.shape == (n_steps, env.num_searchers) + assert jnp.all( + (0.0 <= state_history.searchers.heading) & (state_history.searchers.heading <= 2.0 * jnp.pi) + ) + + assert state_history.targets.pos.shape == (n_steps, env.num_targets, 2) + assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= 1.0)) + + +def test_env_does_not_smoke(env: SearchAndRescue) -> None: + """Test that we can run an episode without any errors.""" + env.max_steps = 10 + + def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: + return jax.random.uniform(action_key, (env.num_searchers, 2), minval=-1.0, maxval=1.0) + + check_env_does_not_smoke(env, select_action=select_action) + + +def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None: + """Test that we can access specs without any errors.""" + check_env_specs_does_not_smoke(env) + + +@pytest.mark.parametrize( + "searcher_positions, searcher_headings, view_updates", + [ + # Both out of view range + ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], []), + # Both view each other + ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], [(0, 5, 0.25), (1, 5, 0.25)]), + # One facing wrong direction + ( + [[0.25, 0.5], [0.2, 0.5]], + [jnp.pi, jnp.pi], + [(0, 5, 0.25)], + ), + # Only see closest neighbour + ( + [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], + [jnp.pi, 0.0, 0.0], + [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], + ), + # Observed around wrapped edge + ( + [[0.025, 0.5], [0.975, 0.5]], + 
[jnp.pi, 0.0], + [(0, 5, 0.25), (1, 5, 0.25)], + ), + ], +) +def test_searcher_view( + env: SearchAndRescue, + searcher_positions: List[List[float]], + searcher_headings: List[float], + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test view model generates expected array with different + configurations of agents. + """ + + searcher_positions = jnp.array(searcher_positions) + searcher_headings = jnp.array(searcher_headings) + searcher_speed = jnp.zeros(searcher_headings.shape) + + state = State( + searchers=AgentState( + pos=searcher_positions, heading=searcher_headings, speed=searcher_speed + ), + targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)), + key=jax.random.PRNGKey(101), + ) + + obs = env._state_to_observation(state) + + assert isinstance(obs, Observation) + + expected = jnp.ones((searcher_headings.shape[0], env.num_vision)) + + for i, idx, val in view_updates: + expected = expected.at[i, idx].set(val) + + assert jnp.all(jnp.isclose(obs.searcher_views, expected)) + + +def test_target_detection(env: SearchAndRescue) -> None: + # Keep targets in one location + env._target_dynamics = RandomWalk(step_size=0.0) + + # Agent facing wrong direction should not see target + state = State( + searchers=AgentState( + pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([jnp.pi]), speed=jnp.array([0.0]) + ), + targets=TargetState(pos=jnp.array([[0.54, 0.5]]), found=jnp.array([False])), + key=jax.random.PRNGKey(101), + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert not state.targets.found[0] + assert timestep.reward[0] == 0 + + # Rotated agent should detect target + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([0.0]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert timestep.reward[0] == 1 + + # Once detected should remain detected + state = State( + searchers=AgentState( + pos=jnp.array([[0.0, 0.0]]), + heading=state.searchers.heading, + speed=state.searchers.speed, + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert timestep.reward[0] == 0 + + +def test_search_and_rescue_render(monkeypatch: pytest.MonkeyPatch, env: SearchAndRescue) -> None: + """Check that the render method builds the figure but does not display it.""" + monkeypatch.setattr(plt, "show", lambda fig: None) + step_fn = jax.jit(env.step) + state, timestep = env.reset(jax.random.PRNGKey(0)) + action = env.action_spec.generate_value() + state, timestep = step_fn(state, action) + env.render(state) + env.close() + + +def test_search_and_rescue__animation(env: SearchAndRescue, tmpdir: py.path.local) -> None: + """Check that the animation method creates the animation correctly and can save to a gif.""" + step_fn = jax.jit(env.step) + state, _ = env.reset(jax.random.PRNGKey(0)) + states = [state] + action = env.action_spec.generate_value() + state, _ = step_fn(state, action) + states.append(state) + animation = env.animate(states, interval=200, save_path=None) + assert isinstance(animation, matplotlib.animation.Animation) + + path = str(tmpdir.join("/anim.gif")) + animation.save(path, writer=matplotlib.animation.PillowWriter(fps=10), dpi=60) diff --git a/jumanji/environments/swarms/predator_prey/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py similarity index 50% rename from 
jumanji/environments/swarms/predator_prey/generator.py rename to jumanji/environments/swarms/search_and_rescue/generator.py index 8c5549f9e..fb33dc870 100644 --- a/jumanji/environments/swarms/predator_prey/generator.py +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -16,33 +16,31 @@ import chex import jax +import jax.numpy as jnp from jumanji.environments.swarms.common.types import AgentParams from jumanji.environments.swarms.common.updates import init_state -from jumanji.environments.swarms.predator_prey.types import State +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState class Generator(abc.ABC): - def __init__(self, num_predators: int, num_prey: int) -> None: - """Interface for instance generation for the `PredatorPrey` environment. + def __init__(self, num_searchers: int, num_targets: int) -> None: + """Interface for instance generation for the `SearchAndRescue` environment. Args: - num_predators: Number of predator agents - num_prey: Number of prey agents + num_searchers: Number of searcher agents + num_targets: Number of search targets """ - self.num_predators = num_predators - self.num_prey = num_prey + self.num_searchers = num_searchers + self.num_targets = num_targets @abc.abstractmethod - def __call__( - self, key: chex.PRNGKey, predator_params: AgentParams, prey_params: AgentParams - ) -> State: + def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: """Generate initial agent positions and velocities. Args: key: random key. - predator_params: Predator `AgentParams`. - prey_params: Prey `AgentParams`. + searcher_params: Searcher `AgentParams`. Returns: Initial agent `State`. @@ -50,21 +48,24 @@ def __call__( class RandomGenerator(Generator): - def __call__( - self, key: chex.PRNGKey, predator_params: AgentParams, prey_params: AgentParams - ) -> State: - """Generate random initial agent positions and velocities. + def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: + """Generate random initial agent positions and velocities, and random target positions. Args: key: random key. - predator_params: Predator `AgentParams`. - prey_params: Prey `AgentParams`. + searcher_params: Searcher `AgentParams`. Returns: state: the generated state. 
""" - key, predator_key, prey_key = jax.random.split(key, num=3) - predator_state = init_state(self.num_predators, predator_params, predator_key) - prey_state = init_state(self.num_prey, prey_params, prey_key) - state = State(predators=predator_state, prey=prey_state, key=key) + key, searcher_key, target_key = jax.random.split(key, num=3) + searcher_state = init_state(self.num_searchers, searcher_params, searcher_key) + target_pos = jax.random.uniform(target_key, (self.num_targets, 2)) + state = State( + searchers=searcher_state, + targets=TargetState( + pos=target_pos, found=jnp.full((self.num_targets,), False, dtype=bool) + ), + key=key, + ) return state diff --git a/jumanji/environments/swarms/predator_prey/types.py b/jumanji/environments/swarms/search_and_rescue/types.py similarity index 55% rename from jumanji/environments/swarms/predator_prey/types.py rename to jumanji/environments/swarms/search_and_rescue/types.py index c5b06e35b..d24603d7e 100644 --- a/jumanji/environments/swarms/predator_prey/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -23,67 +23,49 @@ from jumanji.environments.swarms.common.types import AgentState +@dataclass +class TargetState: + pos: chex.Array + found: chex.Array + + @dataclass class State: """ - predators: Predator agent states. - prey: Prey agent states. + searchers: Searcher agent states. + targets: Search target state. key: JAX random key. step: Environment step number """ - predators: AgentState - prey: AgentState + searchers: AgentState + targets: TargetState key: chex.PRNGKey step: int = 0 -@dataclass -class PredatorPreyStruct: - """ - General struct for predator prey structured data - (e.g. rewards or observations) - - predators: Array of data per predator agent. - prey: Array of data per prey agents. - """ - - predators: chex.Array # (num_predators, ...) - prey: chex.Array # (num_prey, ...) - - class Observation(NamedTuple): """ - Individual observations for predator and prey agents. + Individual observations for searching agents and information + on number of remaining time and targets to be found. Each agent generates an independent observation, an array of values representing the distance along a ray from the agent to the nearest neighbour, with each cell representing a ray angle (with `num_vision` rays evenly distributed over the agents - field of vision). Prey and prey agent types are visualised - independently to allow agents to observe both local position and type. + field of vision). - For example if a prey agent sees a predator straight ahead and + For example if an agent sees another agent straight ahead and `num_vision = 5` then the observation array could be ``` - [1.0, 1.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] - ``` - - or if it observes another prey agent - - ``` - [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 1.0, 1.0] + [1.0, 1.0, 0.5, 1.0, 1.0] ``` where `1.0` indicates there is no agents along that ray, and `0.5` is the normalised distance to the other agent. - - - `predators`: jax array (float) of shape `(num_predators, 2 * num_vision)` - in the unit interval. - - `prey`: jax array (float) of shape `(num_prey, 2 * num_vision)` in the - unit interval. 
""" - predators: chex.Array # (num_predators, num_vision) - prey: chex.Array # (num_prey, num_vision) + searcher_views: chex.Array # (num_searchers, num_vision) + target_remaining: chex.Array # () + time_remaining: chex.Array # () diff --git a/jumanji/environments/swarms/predator_prey/viewer.py b/jumanji/environments/swarms/search_and_rescue/viewer.py similarity index 75% rename from jumanji/environments/swarms/predator_prey/viewer.py rename to jumanji/environments/swarms/search_and_rescue/viewer.py index 3b3422275..72470e687 100644 --- a/jumanji/environments/swarms/predator_prey/viewer.py +++ b/jumanji/environments/swarms/search_and_rescue/viewer.py @@ -17,24 +17,26 @@ import jax.numpy as jnp import matplotlib.animation import matplotlib.pyplot as plt +import numpy as np from matplotlib.layout_engine import TightLayoutEngine import jumanji import jumanji.environments from jumanji.environments.swarms.common.viewer import draw_agents, format_plot -from jumanji.environments.swarms.predator_prey.types import State +from jumanji.environments.swarms.search_and_rescue.types import State from jumanji.viewer import Viewer -class PredatorPreyViewer(Viewer): +class SearchAndRescueViewer(Viewer): def __init__( self, - figure_name: str = "PredatorPrey", + figure_name: str = "SearchAndRescue", figure_size: Tuple[float, float] = (6.0, 6.0), - predator_color: str = "red", - prey_color: str = "green", + searcher_color: str = "blue", + target_found_color: str = "green", + target_lost_color: str = "red", ) -> None: - """Viewer for the `PredatorPrey` environment. + """Viewer for the `SearchAndRescue` environment. Args: figure_name: the window name to be used when initialising the window. @@ -42,8 +44,8 @@ def __init__( """ self._figure_name = figure_name self._figure_size = figure_size - self.predator_color = predator_color - self.prey_color = prey_color + self.searcher_color = searcher_color + self.target_colors = np.array([target_lost_color, target_found_color]) self._animation: Optional[matplotlib.animation.Animation] = None def render(self, state: State) -> None: @@ -77,18 +79,21 @@ def animate( fig, ax = plt.subplots(num=f"{self._figure_name}Anim", figsize=self._figure_size) fig, ax = format_plot(fig, ax) - predators_quiver = draw_agents(ax, states[0].predators, self.predator_color) - prey_quiver = draw_agents(ax, states[0].prey, self.prey_color) + searcher_quiver = draw_agents(ax, states[0].searchers, self.searcher_color) + target_scatter = ax.scatter( + states[0].targets.pos[:, 0], states[0].targets.pos[:, 1], marker="o" + ) def make_frame(state: State) -> Any: - # Rather than redraw just update the quivers properties - predators_quiver.set_offsets(state.predators.pos) - predators_quiver.set_UVC( - jnp.cos(state.predators.heading), jnp.sin(state.predators.heading) + # Rather than redraw just update the quivers and scatter properties + searcher_quiver.set_offsets(state.searchers.pos) + searcher_quiver.set_UVC( + jnp.cos(state.searchers.heading), jnp.sin(state.searchers.heading) ) - prey_quiver.set_offsets(state.prey.pos) - prey_quiver.set_UVC(jnp.cos(state.prey.heading), jnp.sin(state.prey.heading)) - return ((predators_quiver, prey_quiver),) + target_colors = self.target_colors[state.targets.found.astype(jnp.int32)] + target_scatter.set_offsets(state.targets.pos) + target_scatter.set_color(target_colors) + return ((searcher_quiver, target_scatter),) matplotlib.rc("animation", html="jshtml") self._animation = matplotlib.animation.FuncAnimation( @@ -114,8 +119,11 @@ def close(self) -> None: def 
_draw(self, ax: plt.Axes, state: State) -> None: ax.clear() - draw_agents(ax, state.predators, self.predator_color) - draw_agents(ax, state.prey, self.prey_color) + draw_agents(ax, state.searchers, self.searcher_color) + target_colors = self.target_colors[state.targets.found.astype(jnp.int32)] + ax.scatter( + state.targets.pos[:, 0], state.targets.pos[:, 1], marker="o", color=target_colors + ) def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: exists = plt.fignum_exists(self._figure_name) diff --git a/mkdocs.yml b/mkdocs.yml index 83ccfb049..8e2ca2f87 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -44,7 +44,7 @@ nav: - Snake: environments/snake.md - TSP: environments/tsp.md - Swarms: - - PredatorPrey: environments/predator_prey.md + - SearchAndRescue: environments/search_and_rescue.md - User Guides: - Advanced Usage: guides/advanced_usage.md - Registration: guides/registration.md @@ -80,7 +80,7 @@ nav: - Snake: api/environments/snake.md - TSP: api/environments/tsp.md - Swarms: - - PredatorPrey: api/environments/predator_prey.md + - SearchAndRescue: api/environments/search_and_rescue.md - Wrappers: api/wrappers.md - Types: api/types.md From 34beab67acad44ab948be3163712c8347216738b Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:04:11 +0000 Subject: [PATCH 06/19] fix: PR fixes (#4) * refactor: Rename to targets_remaining * docs: Formatting and expand docs * refactor: Move target and reward checks into utils module * fix: Set agent and target numbers via generator * refactor: Terminate episode if all targets found * test: Add swarms.common tests * refactor: Move agent initialisation into generator * test: Add environment utility tests --- docs/environments/search_and_rescue.md | 27 ++- .../environments/swarms/common/test_common.py | 173 ++++++++++++++++++ jumanji/environments/swarms/common/updates.py | 25 --- .../swarms/search_and_rescue/env.py | 89 +++------ .../swarms/search_and_rescue/env_test.py | 41 +++-- .../swarms/search_and_rescue/generator.py | 23 ++- .../swarms/search_and_rescue/test_utils.py | 107 +++++++++++ .../swarms/search_and_rescue/types.py | 2 +- .../swarms/search_and_rescue/utils.py | 88 +++++++++ 9 files changed, 459 insertions(+), 116 deletions(-) create mode 100644 jumanji/environments/swarms/common/test_common.py create mode 100644 jumanji/environments/swarms/search_and_rescue/test_utils.py create mode 100644 jumanji/environments/swarms/search_and_rescue/utils.py diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md index a124b4470..65c98292b 100644 --- a/docs/environments/search_and_rescue.md +++ b/docs/environments/search_and_rescue.md @@ -35,20 +35,33 @@ space is a uniform space with unit dimensions, and wrapped at the boundaries. where `1.0` indicates there is no agents along that ray, and `0.5` is the normalised distance to the other agent. -- `target_remaining`: float in the range [0, 1]. The normalised number of targets +- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets remaining to be detected (i.e. 1.0 when no targets have been found). -- `time_remaining`: float in the range [0, 1]. The normalised number of steps remaining +- `time_remaining`: float in the range `[0, 1]`. The normalised number of steps remaining to locate the targets (i.e. 0.0 at the end of the episode). ## Actions -Jax array (float) of `(num_searchers, 2)` in the range [-1, 1]. 
Each entry in the +array represents an update of each agent's velocity in the next step. Searching agents -update their velocity each step by rotating and accelerating/decelerating. Values -are clipped to the range `[-1, 1]` and then scaled by max rotation and acceleration -parameters. Agents are restricted to velocities within a fixed range of speeds. +update their velocity each step by rotating and accelerating/decelerating, where the +values are `[rotation, acceleration]`. Values are clipped to the range `[-1, 1]` +and then scaled by max rotation and acceleration parameters, i.e. the new values each +step are given by + +``` +heading = heading + max_rotation * action[0] +``` + +and speed + +``` +speed = speed + max_acceleration * action[1] +``` + +Once applied, agent speeds are clipped to a fixed range of speeds. ## Rewards -Jax array (float) of `(num_searchers, 2)`. Rewards are generated for each agent individually. +Jax array (float) of `(num_searchers,)`. Rewards are generated for each agent individually. Agents are rewarded 1.0 for locating a target that has not already been detected. diff --git a/jumanji/environments/swarms/common/test_common.py b/jumanji/environments/swarms/common/test_common.py new file mode 100644 index 000000000..93128360b --- /dev/null +++ b/jumanji/environments/swarms/common/test_common.py @@ -0,0 +1,173 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from typing import List, Tuple + +import jax +import jax.numpy as jnp +import matplotlib +import matplotlib.pyplot as plt +import pytest + +from jumanji.environments.swarms.common import types, updates, viewer + + +@pytest.fixture +def params() -> types.AgentParams: + return types.AgentParams( + max_rotate=0.5, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.5, + ) + + +@pytest.mark.parametrize( + "heading, speed, actions, expected", + [ + [0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)], + [0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)], + [jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)], + [jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)], + [1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)], + [0.0, 0.01, [0.0, 1.0], (0.0, 0.02)], + [0.0, 0.01, [0.0, -1.0], (0.0, 0.01)], + [0.0, 0.02, [0.0, -1.0], (0.0, 0.01)], + [0.0, 0.05, [0.0, -1.0], (0.0, 0.04)], + [0.0, 0.05, [0.0, 1.0], (0.0, 0.05)], + ], +) +def test_velocity_update( + params: types.AgentParams, + heading: float, + speed: float, + actions: List[float], + expected: Tuple[float, float], +) -> None: + key = jax.random.PRNGKey(101) + + state = types.AgentState( + pos=jnp.zeros((1, 2)), + heading=jnp.array([heading]), + speed=jnp.array([speed]), + ) + actions = jnp.array([actions]) + + new_heading, new_speed = updates.update_velocity(key, params, (actions, state)) + + assert jnp.isclose(new_heading[0], expected[0]) + assert jnp.isclose(new_speed[0], expected[1]) + + +@pytest.mark.parametrize( + "pos, heading, speed, expected", + [ + [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5]], + [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5]], + [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1]], + [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9]], + ], +) +def test_move(pos: List[float], heading: float, speed: float, expected: List[float]) -> None: + pos = jnp.array(pos) + new_pos = updates.move(pos, heading, speed) + + assert jnp.allclose(new_pos, jnp.array(expected)) + + +@pytest.mark.parametrize( + "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed", + [ + [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01], + [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01], + [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01], + [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02], + [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01], + [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05], + ], +) +def test_state_update( + params: types.AgentParams, + pos: List[float], + heading: float, + speed: float, + actions: List[float], + expected_pos: List[float], + expected_heading: float, + expected_speed: float, +) -> None: + key = jax.random.PRNGKey(101) + + state = types.AgentState( + pos=jnp.array([pos]), + heading=jnp.array([heading]), + speed=jnp.array([speed]), + ) + actions = jnp.array([actions]) + + new_state = updates.update_state(key, params, state, actions) + + assert isinstance(new_state, types.AgentState) + assert jnp.allclose(new_state.pos, jnp.array([expected_pos])) + assert jnp.allclose(new_state.heading, jnp.array([expected_heading])) + assert jnp.allclose(new_state.speed, jnp.array([expected_speed])) + + +@pytest.mark.parametrize( + "pos, view_angle, expected", + [ + [[0.05, 0.0], 0.5, [1.0, 1.0, 0.5, 1.0, 1.0]], + [[0.0, 0.05], 0.5, [0.5, 1.0, 1.0, 1.0, 1.0]], + [[0.0, 0.95], 0.5, [1.0, 1.0, 1.0, 1.0, 0.5]], + [[0.95, 0.0], 0.5, [1.0, 1.0, 1.0, 1.0, 1.0]], + [[0.05, 0.0], 0.25, [1.0, 1.0, 0.5, 1.0, 1.0]], + [[0.0, 0.05], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]], 
+ [[0.0, 0.95], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]], + [[0.01, 0.0], 0.5, [1.0, 1.0, 0.1, 1.0, 1.0]], + ], +) +def test_view(pos: List[float], view_angle: float, expected: List[float]) -> None: + state_a = types.AgentState( + pos=jnp.zeros((2,)), + heading=0.0, + speed=0.0, + ) + + state_b = types.AgentState( + pos=jnp.array(pos), + heading=0.0, + speed=0.0, + ) + + obs = updates.view(None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1) + assert jnp.allclose(obs, jnp.array(expected)) + + +def test_viewer_utils() -> None: + f, ax = plt.subplots() + f, ax = viewer.format_plot(f, ax) + + assert isinstance(f, matplotlib.figure.Figure) + assert isinstance(ax, matplotlib.axes.Axes) + + state = types.AgentState( + pos=jnp.zeros((3, 2)), + heading=jnp.zeros((3,)), + speed=jnp.zeros((3,)), + ) + + quiver = viewer.draw_agents(ax, state, "red") + + assert isinstance(quiver, matplotlib.quiver.Quiver) diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index 353fdcf9e..665f44e34 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -70,31 +70,6 @@ def move(pos: chex.Array, heading: chex.Array, speed: chex.Array) -> chex.Array: return (pos + d_pos) % 1.0 -def init_state(n: int, params: types.AgentParams, key: chex.PRNGKey) -> types.AgentState: - """ - Randomly initialise state of a group of agents - - Args: - n: Number of agents to initialise. - params: Agent parameters. - key: JAX random key. - - Returns: - AgentState: Random agent states (i.e. position, headings, and speeds) - """ - k1, k2, k3 = jax.random.split(key, 3) - - positions = jax.random.uniform(k1, (n, 2)) - speeds = jax.random.uniform(k2, (n,), minval=params.min_speed, maxval=params.max_speed) - headings = jax.random.uniform(k3, (n,), minval=0.0, maxval=2.0 * jnp.pi) - - return types.AgentState( - pos=positions, - speed=speeds, - heading=headings, - ) - - def update_state( key: chex.PRNGKey, params: types.AgentParams, diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index a1fa15498..12ac5a341 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -19,13 +19,13 @@ import jax import jax.numpy as jnp from esquilax.transforms import spatial -from esquilax.utils import shortest_vector from matplotlib.animation import FuncAnimation from jumanji import specs from jumanji.env import Environment -from jumanji.environments.swarms.common.types import AgentParams, AgentState +from jumanji.environments.swarms.common.types import AgentParams from jumanji.environments.swarms.common.updates import update_state, view +from jumanji.environments.swarms.search_and_rescue import utils from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator from jumanji.environments.swarms.search_and_rescue.types import Observation, State, TargetState @@ -34,42 +34,6 @@ from jumanji.viewer import Viewer -def _inner_found_check( - _key: chex.PRNGKey, - searcher_view_angle: float, - target_pos: chex.Array, - searcher: AgentState, -) -> chex.Array: - """ - Return True if searcher can view the target. 
- """ - dx = shortest_vector(searcher.pos, target_pos) - phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) - dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) - return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) - - -def _inner_reward_check( - _key: chex.PRNGKey, - searcher_view_angle: float, - searcher: AgentState, - target: TargetState, -) -> chex.Array: - """ - Return +1.0 reward if the target is within the searcher view angle. - """ - dx = shortest_vector(searcher.pos, target.pos) - phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) - dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) - can_see = (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) - return jax.lax.cond( - # (not target.found) & can_see, - can_see, - lambda: 1.0, - lambda: 0.0, - ) - - class SearchAndRescue(Environment): """A multi-agent search environment @@ -83,7 +47,7 @@ class SearchAndRescue(Environment): - observation: `Observation` searcher_views: jax array (float) of shape (num_searchers, num_vision) individual local views of positions of other searching agents. - target_remaining: (float) Number of targets remaining to be found from + targets_remaining: (float) Number of targets remaining to be found from the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates all the targets are still to be found). time_remaining: (float) Steps remaining to find agents, scaled to the @@ -118,8 +82,6 @@ class SearchAndRescue(Environment): ```python from jumanji.environments import SearchAndRescue env = SearchAndRescue( - num_searchers=10, - num_targets=20, searcher_vision_range=0.1, target_contact_range=0.01, num_vision=40, @@ -141,8 +103,6 @@ class SearchAndRescue(Environment): def __init__( self, - num_searchers: int, - num_targets: int, searcher_vision_range: float, target_contact_range: float, num_vision: int, @@ -168,8 +128,6 @@ def __init__( agent interactions. Args: - num_searchers: Number of searching agents. - num_targets: Number of search targets. searcher_vision_range: Search agent vision range. target_contact_range: Range at which a searcher can 'find' a target. num_vision: Number of cells/subdivisions in agent @@ -194,10 +152,9 @@ def __init__( target_dynamics: target_dynamics: Target object dynamics model, implemented as a `TargetDynamics` interface. Defaults to `RandomWalk`. - generator: Initial state `Generator` instance. Defaults to `RandomGenerator`. + generator: Initial state `Generator` instance. Defaults to `RandomGenerator` + with 20 targets and 10 searchers. 
""" - self.num_searchers = num_searchers - self.num_targets = num_targets self.searcher_vision_range = searcher_vision_range self.target_contact_range = target_contact_range self.num_vision = num_vision @@ -210,24 +167,24 @@ def __init__( view_angle=searcher_view_angle, ) self.max_steps = max_steps - super().__init__() self._viewer = viewer or SearchAndRescueViewer() self._target_dynamics = target_dynamics or RandomWalk(0.01) - self._generator = generator or RandomGenerator(num_searchers, num_targets) + self.generator = generator or RandomGenerator(num_targets=20, num_searchers=10) + super().__init__() def __repr__(self) -> str: return "\n".join( [ "Search & rescue multi-agent environment:", - f" - num searchers: {self.num_searchers}", - f" - num targets: {self.num_targets}", + f" - num searchers: {self.generator.num_searchers}", + f" - num targets: {self.generator.num_targets}", f" - search vision range: {self.searcher_vision_range}", f" - target contact range: {self.target_contact_range}", f" - num vision: {self.num_vision}", f" - agent radius: {self.agent_radius}", f" - max steps: {self.max_steps}," f" - target dynamics: {self._target_dynamics.__class__.__name__}", - f" - generator: {self._generator.__class__.__name__}", + f" - generator: {self.generator.__class__.__name__}", ] ) @@ -241,7 +198,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: state: Initial environment state. timestep: TimeStep with individual search agent views. """ - state = self._generator(key, self.searcher_params) + state = self.generator(key, self.searcher_params) timestep = restart(observation=self._state_to_observation(state)) return state, timestep @@ -264,8 +221,10 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser # Ensure target positions are wrapped target_pos = self._target_dynamics(target_key, state.targets.pos) % 1.0 # Grant searchers rewards if in range and not already detected + # spatial maps the has_found_target function over all pair of targets and + # searchers within range of each other and sums rewards per agent. 
rewards = spatial( - _inner_reward_check, + utils.has_found_target, reduction=jnp.add, default=0.0, i_range=self.target_contact_range, @@ -278,8 +237,10 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser pos_b=target_pos, ) # Mark targets as found if with contact range and view angle of a searcher + # spatial maps the has_been_found function over all pair of targets and + # searchers within range of each other targets_found = spatial( - _inner_found_check, + utils.has_been_found, reduction=jnp.logical_or, default=False, i_range=self.target_contact_range, @@ -301,7 +262,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser ) observation = self._state_to_observation(state) timestep = jax.lax.cond( - state.step >= self.max_steps, + state.step >= self.max_steps | jnp.all(targets_found), termination, transition, rewards, @@ -328,7 +289,7 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( searcher_views=searcher_views, - target_remaining=1.0 - jnp.sum(state.targets.found) / self.num_targets, + targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, time_remaining=1.0 - state.step / (self.max_steps + 1), ) @@ -344,7 +305,7 @@ def observation_spec(self) -> specs.Spec[Observation]: observation_spec: Search-and-rescue observation spec """ searcher_views = specs.BoundedArray( - shape=(self.num_searchers, self.num_vision), + shape=(self.generator.num_searchers, self.num_vision), minimum=0.0, maximum=1.0, dtype=float, @@ -354,8 +315,8 @@ def observation_spec(self) -> specs.Spec[Observation]: Observation, "ObservationSpec", searcher_views=searcher_views, - target_remaining=specs.BoundedArray( - shape=(), minimum=0.0, maximum=1.0, name="target_remaining", dtype=float + targets_remaining=specs.BoundedArray( + shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=float ), time_remaining=specs.BoundedArray( shape=(), minimum=0.0, maximum=1.0, name="time_remaining", dtype=float @@ -374,7 +335,7 @@ def action_spec(self) -> specs.BoundedArray: action_spec: Action array spec """ return specs.BoundedArray( - shape=(self.num_searchers, 2), + shape=(self.generator.num_searchers, 2), minimum=-1.0, maximum=1.0, dtype=float, @@ -390,9 +351,9 @@ def reward_spec(self) -> specs.BoundedArray: reward_spec: Reward array spec. 
""" return specs.BoundedArray( - shape=(self.num_searchers,), + shape=(self.generator.num_searchers,), minimum=0.0, - maximum=float(self.num_targets), + maximum=float(self.generator.num_targets), dtype=float, ) diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 8f0de699f..92bb34144 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -40,8 +40,6 @@ @pytest.fixture def env() -> SearchAndRescue: return SearchAndRescue( - num_searchers=10, - num_targets=20, searcher_vision_range=SEARCHER_VISION_RANGE, target_contact_range=TARGET_CONTACT_RANGE, num_vision=11, @@ -65,19 +63,21 @@ def test_env_init(env: SearchAndRescue) -> None: assert isinstance(state, State) assert isinstance(state.searchers, AgentState) - assert state.searchers.pos.shape == (env.num_searchers, 2) - assert state.searchers.speed.shape == (env.num_searchers,) - assert state.searchers.speed.shape == (env.num_searchers,) + assert state.searchers.pos.shape == (env.generator.num_searchers, 2) + assert state.searchers.speed.shape == (env.generator.num_searchers,) + assert state.searchers.speed.shape == (env.generator.num_searchers,) assert isinstance(state.targets, TargetState) - assert state.targets.pos.shape == (env.num_targets, 2) - assert state.targets.found.shape == (env.num_targets,) - assert jnp.array_equal(state.targets.found, jnp.full((env.num_targets,), False, dtype=bool)) + assert state.targets.pos.shape == (env.generator.num_targets, 2) + assert state.targets.found.shape == (env.generator.num_targets,) + assert jnp.array_equal( + state.targets.found, jnp.full((env.generator.num_targets,), False, dtype=bool) + ) assert state.step == 0 assert isinstance(timestep.observation, Observation) assert timestep.observation.searcher_views.shape == ( - env.num_searchers, + env.generator.num_searchers, env.num_vision, ) assert timestep.step_type == StepType.FIRST @@ -97,7 +97,9 @@ def step( ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: k, state = carry k, k_search = jax.random.split(k) - actions = jax.random.uniform(k_search, (env.num_searchers, 2), minval=-1.0, maxval=1.0) + actions = jax.random.uniform( + k_search, (env.generator.num_searchers, 2), minval=-1.0, maxval=1.0 + ) new_state, timestep = env.step(state, actions) return (k, new_state), (state, timestep) @@ -108,19 +110,19 @@ def step( assert isinstance(state_history, State) - assert state_history.searchers.pos.shape == (n_steps, env.num_searchers, 2) + assert state_history.searchers.pos.shape == (n_steps, env.generator.num_searchers, 2) assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= 1.0)) - assert state_history.searchers.speed.shape == (n_steps, env.num_searchers) + assert state_history.searchers.speed.shape == (n_steps, env.generator.num_searchers) assert jnp.all( (env.searcher_params.min_speed <= state_history.searchers.speed) & (state_history.searchers.speed <= env.searcher_params.max_speed) ) - assert state_history.searchers.speed.shape == (n_steps, env.num_searchers) + assert state_history.searchers.speed.shape == (n_steps, env.generator.num_searchers) assert jnp.all( (0.0 <= state_history.searchers.heading) & (state_history.searchers.heading <= 2.0 * jnp.pi) ) - assert state_history.targets.pos.shape == (n_steps, env.num_targets, 2) + assert state_history.targets.pos.shape == (n_steps, env.generator.num_targets, 2) assert 
jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= 1.0)) @@ -129,7 +131,9 @@ def test_env_does_not_smoke(env: SearchAndRescue) -> None: env.max_steps = 10 def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: - return jax.random.uniform(action_key, (env.num_searchers, 2), minval=-1.0, maxval=1.0) + return jax.random.uniform( + action_key, (env.generator.num_searchers, 2), minval=-1.0, maxval=1.0 + ) check_env_does_not_smoke(env, select_action=select_action) @@ -229,7 +233,12 @@ def test_target_detection(env: SearchAndRescue) -> None: assert state.targets.found[0] assert timestep.reward[0] == 1 - # Once detected should remain detected + # Searcher should only get rewards once + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert timestep.reward[0] == 0 + + # Once detected target should remain detected if agent moves away state = State( searchers=AgentState( pos=jnp.array([[0.0, 0.0]]), diff --git a/jumanji/environments/swarms/search_and_rescue/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py index fb33dc870..9de816086 100644 --- a/jumanji/environments/swarms/search_and_rescue/generator.py +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -18,8 +18,7 @@ import jax import jax.numpy as jnp -from jumanji.environments.swarms.common.types import AgentParams -from jumanji.environments.swarms.common.updates import init_state +from jumanji.environments.swarms.common.types import AgentParams, AgentState from jumanji.environments.swarms.search_and_rescue.types import State, TargetState @@ -59,8 +58,26 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: state: the generated state. """ key, searcher_key, target_key = jax.random.split(key, num=3) - searcher_state = init_state(self.num_searchers, searcher_params, searcher_key) + + k_pos, k_head, k_speed = jax.random.split(searcher_key, 3) + positions = jax.random.uniform(k_pos, (self.num_searchers, 2)) + headings = jax.random.uniform( + k_head, (self.num_searchers,), minval=0.0, maxval=2.0 * jnp.pi + ) + speeds = jax.random.uniform( + k_speed, + (self.num_searchers,), + minval=searcher_params.min_speed, + maxval=searcher_params.max_speed, + ) + searcher_state = AgentState( + pos=positions, + speed=speeds, + heading=headings, + ) + target_pos = jax.random.uniform(target_key, (self.num_targets, 2)) + state = State( searchers=searcher_state, targets=TargetState( diff --git a/jumanji/environments/swarms/search_and_rescue/test_utils.py b/jumanji/environments/swarms/search_and_rescue/test_utils.py new file mode 100644 index 000000000..c2f18b9f2 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/test_utils.py @@ -0,0 +1,107 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List + +import jax +import jax.numpy as jnp +import pytest + +from jumanji.environments.swarms.common.types import AgentParams, AgentState +from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics +from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState +from jumanji.environments.swarms.search_and_rescue.utils import has_been_found, has_found_target + + +def test_random_generator() -> None: + key = jax.random.PRNGKey(101) + params = AgentParams( + max_rotate=0.5, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.5, + ) + generator = RandomGenerator(num_searchers=100, num_targets=101) + + assert isinstance(generator, Generator) + + state = generator(key, params) + + assert isinstance(state, State) + assert state.searchers.pos.shape == (generator.num_searchers, 2) + assert jnp.all(0.0 <= state.searchers.pos) and jnp.all(state.searchers.pos <= 1.0) + assert state.targets.pos.shape == (generator.num_targets, 2) + assert jnp.all(0.0 <= state.targets.pos) and jnp.all(state.targets.pos <= 1.0) + assert not jnp.any(state.targets.found) + assert state.step == 0 + + +def test_random_walk_dynamics() -> None: + n_targets = 50 + key = jax.random.PRNGKey(101) + s0 = jnp.full((n_targets, 2), 0.5) + + dynamics = RandomWalk(0.1) + assert isinstance(dynamics, TargetDynamics) + s1 = dynamics(key, s0) + + assert s1.shape == (n_targets, 2) + assert jnp.all(jnp.abs(s0 - s1) < 0.1) + + +@pytest.mark.parametrize( + "pos, heading, view_angle, target_state, expected", + [ + ([0.1, 0.0], 0.0, 0.5, False, False), + ([0.1, 0.0], jnp.pi, 0.5, False, True), + ([0.1, 0.0], jnp.pi, 0.5, True, True), + ([0.9, 0.0], jnp.pi, 0.5, False, False), + ([0.9, 0.0], 0.0, 0.5, False, True), + ([0.9, 0.0], 0.0, 0.5, True, True), + ([0.0, 0.1], 1.5 * jnp.pi, 0.5, True, True), + ([0.1, 0.0], 0.5 * jnp.pi, 0.5, False, True), + ([0.1, 0.0], 0.5 * jnp.pi, 0.4, False, False), + ], +) +def test_target_found( + pos: List[float], + heading: float, + view_angle: float, + target_state: bool, + expected: bool, +) -> None: + target = TargetState( + pos=jnp.zeros((2,)), + found=target_state, + ) + + searcher = AgentState( + pos=jnp.array(pos), + heading=heading, + speed=0.0, + ) + + found = has_been_found(None, view_angle, target.pos, searcher) + reward = has_found_target(None, view_angle, searcher, target) + + assert found == expected + + if found and target_state: + assert reward == 0.0 + elif found and not target_state: + assert reward == 1.0 + elif not found: + assert reward == 0.0 diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py index d24603d7e..8505f58c4 100644 --- a/jumanji/environments/swarms/search_and_rescue/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -67,5 +67,5 @@ class Observation(NamedTuple): """ searcher_views: chex.Array # (num_searchers, num_vision) - target_remaining: chex.Array # () + targets_remaining: chex.Array # () time_remaining: chex.Array # () diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py new file mode 100644 index 000000000..262ba569b --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -0,0 +1,88 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax +import jax.numpy as jnp +from esquilax.utils import shortest_vector + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.search_and_rescue.types import TargetState + + +def has_been_found( + _key: chex.PRNGKey, + searcher_view_angle: float, + target_pos: chex.Array, + searcher: AgentState, +) -> chex.Array: + """ + Returns True if a target has been found. + + Return true if a target is within detection range + and within the view cone of a searcher. Used + to mark targets as found. + + Args: + _key: Dummy random key (required by Esquilax). + searcher_view_angle: View angle of searching agents + representing a fraction of pi from the agents heading. + target_pos: jax array (float) of shape (2,) representing + the position of the target. + searcher: Searcher agent state (i.e. position and heading). + + Returns: + is-found: `bool` True if the target had been found/detected. + """ + dx = shortest_vector(searcher.pos, target_pos) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) + searcher_view_angle = searcher_view_angle * jnp.pi + return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + + +def has_found_target( + _key: chex.PRNGKey, + searcher_view_angle: float, + searcher: AgentState, + target: TargetState, +) -> chex.Array: + """ + Return +1.0 reward if the agent has detected a target. + + Generate rewards for agents if a target is inside the + searcher's view cone, and has not already been detected. + + Args: + _key: Dummy random key (required by Esquilax). + searcher_view_angle: View angle of searching agents + representing a fraction of pi from the agents heading. + searcher: State of the searching agent (i.e. the agent + position and heading) + target: State of the target (i.e. its position and + search status). + + Returns: + reward: +1.0 reward if the agent detects a new target.
+ """ + dx = shortest_vector(searcher.pos, target.pos) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) + searcher_view_angle = searcher_view_angle * jnp.pi + can_see = (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + return jax.lax.cond( + ~target.found & can_see, + lambda: 1.0, + lambda: 0.0, + ) From 072db189f011af260ade1a8d256c55390189c1e4 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Tue, 19 Nov 2024 20:27:54 +0000 Subject: [PATCH 07/19] refactor: PR fixes (#5) * refactor: Set -1.0 as default view value * refactor: Restructure tests * refactor: Pull out common functionality and fix formatting * refactor: Better function names --- docs/environments/search_and_rescue.md | 4 +- .../environments/swarms/common/test_common.py | 23 +++++---- jumanji/environments/swarms/common/updates.py | 43 +++++++++++++---- .../swarms/search_and_rescue/conftest.py | 40 ++++++++++++++++ .../swarms/search_and_rescue/env.py | 21 ++++---- .../swarms/search_and_rescue/env_test.py | 37 ++++---------- .../search_and_rescue/test_generator.py | 42 ++++++++++++++++ .../swarms/search_and_rescue/test_utils.py | 43 ++++------------- .../swarms/search_and_rescue/types.py | 16 +++++-- .../swarms/search_and_rescue/utils.py | 48 ++++++++++++------- 10 files changed, 206 insertions(+), 111 deletions(-) create mode 100644 jumanji/environments/swarms/search_and_rescue/conftest.py create mode 100644 jumanji/environments/swarms/search_and_rescue/test_generator.py diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md index 65c98292b..8ee4ce783 100644 --- a/docs/environments/search_and_rescue.md +++ b/docs/environments/search_and_rescue.md @@ -30,10 +30,10 @@ space is a uniform space with unit dimensions, and wrapped at the boundaries. the observation array could be ``` - [1.0, 1.0, 0.5, 1.0, 1.0] + [-1.0, -1.0, 0.5, -1.0, -1.0] ``` - where `1.0` indicates there is no agents along that ray, and `0.5` is the normalised + where `-1.0` indicates there is no agents along that ray, and `0.5` is the normalised distance to the other agent. - `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets remaining to be detected (i.e. 1.0 when no targets have been found). 
diff --git a/jumanji/environments/swarms/common/test_common.py b/jumanji/environments/swarms/common/test_common.py index 93128360b..a3a20367c 100644 --- a/jumanji/environments/swarms/common/test_common.py +++ b/jumanji/environments/swarms/common/test_common.py @@ -125,17 +125,24 @@ def test_state_update( assert jnp.allclose(new_state.speed, jnp.array([expected_speed])) +def test_view_reduction() -> None: + view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5]) + view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2]) + result = updates.view_reduction(view_a, view_b) + assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2])) + + @pytest.mark.parametrize( "pos, view_angle, expected", [ - [[0.05, 0.0], 0.5, [1.0, 1.0, 0.5, 1.0, 1.0]], - [[0.0, 0.05], 0.5, [0.5, 1.0, 1.0, 1.0, 1.0]], - [[0.0, 0.95], 0.5, [1.0, 1.0, 1.0, 1.0, 0.5]], - [[0.95, 0.0], 0.5, [1.0, 1.0, 1.0, 1.0, 1.0]], - [[0.05, 0.0], 0.25, [1.0, 1.0, 0.5, 1.0, 1.0]], - [[0.0, 0.05], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]], - [[0.0, 0.95], 0.25, [1.0, 1.0, 1.0, 1.0, 1.0]], - [[0.01, 0.0], 0.5, [1.0, 1.0, 0.1, 1.0, 1.0]], + [[0.05, 0.0], 0.5, [-1.0, -1.0, 0.5, -1.0, -1.0]], + [[0.0, 0.05], 0.5, [0.5, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.95], 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]], + [[0.95, 0.0], 0.5, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.05, 0.0], 0.25, [-1.0, -1.0, 0.5, -1.0, -1.0]], + [[0.0, 0.05], 0.25, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.95], 0.25, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.01, 0.0], 0.5, [-1.0, -1.0, 0.1, -1.0, -1.0]], ], ) def test_view(pos: List[float], view_angle: float, expected: List[float]) -> None: diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index 665f44e34..671a6fce6 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -100,11 +100,36 @@ def update_state( ) +def view_reduction(view_a: chex.Array, view_b: chex.Array) -> chex.Array: + """ + Binary view reduction function. + + Handles reduction where a value of -1.0 indicates no + agent in view-range. Returns the min value of they + are both positive, but the max value if one or both of + the values is -1.0. + + Args: + view_a: View vector. + view_b: View vector. + + Returns: + jax array (float32): View vector indicating the + shortest distance to the nearest neighbour or + -1.0 if no agent is present along a ray. + """ + return jnp.where( + jnp.logical_or(view_a < 0.0, view_b < 0.0), + jnp.maximum(view_a, view_b), + jnp.minimum(view_a, view_b), + ) + + def view( _key: chex.PRNGKey, params: Tuple[float, float], - a: types.AgentState, - b: types.AgentState, + viewing_agent: types.AgentState, + viewed_agent: types.AgentState, *, n_view: int, i_range: float, @@ -115,16 +140,16 @@ def view( Simple view model where the agents view angle is subdivided into an array of values representing the distance from the agent along a rays from the agent, with rays evenly distributed. - across the agents field of view. The limit of vision is set at 1.0, - which is also the default value if no object is within range. + across the agents field of view. The limit of vision is set at 1.0. + The default value if no object is within range is -1.0. Currently, this model assumes the viewed objects are circular. Args: _key: Dummy JAX random key, required by esquilax API, but not used during the interaction. params: Tuple containing agent view angle and view-radius. - a: Viewing agent state. - b: State of agent being viewed. + viewing_agent: Viewing agent state. 
+ viewed_agent: State of agent being viewed. n_view: Static number of view rays/subdivisions (i.e. how many cells the resulting array contains). i_range: Static agent view/interaction range. @@ -140,14 +165,14 @@ def view( n_view, endpoint=True, ) - dx = esquilax.utils.shortest_vector(a.pos, b.pos) + dx = esquilax.utils.shortest_vector(viewing_agent.pos, viewed_agent.pos) d = jnp.sqrt(jnp.sum(dx * dx)) / i_range phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) - dh = esquilax.utils.shortest_vector(phi, a.heading, 2 * jnp.pi) + dh = esquilax.utils.shortest_vector(phi, viewing_agent.heading, 2 * jnp.pi) angular_width = jnp.arctan2(radius, d) left = dh - angular_width right = dh + angular_width - obs = jnp.where(jnp.logical_and(left < rays, rays < right), d, 1.0) + obs = jnp.where(jnp.logical_and(left < rays, rays < right), d, -1.0) return obs diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py new file mode 100644 index 000000000..ddb82e2b8 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/conftest.py @@ -0,0 +1,40 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax.random +import pytest + +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue + + +@pytest.fixture +def env() -> SearchAndRescue: + return SearchAndRescue( + searcher_vision_range=0.2, + target_contact_range=0.05, + num_vision=11, + agent_radius=0.05, + searcher_max_rotate=0.2, + searcher_max_accelerate=0.01, + searcher_min_speed=0.01, + searcher_max_speed=0.05, + searcher_view_angle=0.5, + max_steps=25, + ) + + +@pytest.fixture +def key() -> chex.PRNGKey: + return jax.random.PRNGKey(101) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 12ac5a341..09392d71d 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -24,7 +24,7 @@ from jumanji import specs from jumanji.env import Environment from jumanji.environments.swarms.common.types import AgentParams -from jumanji.environments.swarms.common.updates import update_state, view +from jumanji.environments.swarms.common.updates import update_state, view, view_reduction from jumanji.environments.swarms.search_and_rescue import utils from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator @@ -47,6 +47,9 @@ class SearchAndRescue(Environment): - observation: `Observation` searcher_views: jax array (float) of shape (num_searchers, num_vision) individual local views of positions of other searching agents. + Each entry in the view indicates the distant to the nearest neighbour + along a ray from the agent, and is -1.0 if no agent is in range + along the ray. 
targets_remaining: (float) Number of targets remaining to be found from the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates all the targets are still to be found). @@ -216,15 +219,17 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser state: Updated searcher and target positions and velocities. timestep: Transition timestep with individual agent local observations. """ - searchers = update_state(state.key, self.searcher_params, state.searchers, actions) + # Note: only one new key is needed for the targets, as all other + # keys are just dummy values required by Esquilax key, target_key = jax.random.split(state.key, num=2) + searchers = update_state(key, self.searcher_params, state.searchers, actions) # Ensure target positions are wrapped target_pos = self._target_dynamics(target_key, state.targets.pos) % 1.0 # Grant searchers rewards if in range and not already detected # spatial maps the has_found_target function over all pair of targets and # searchers within range of each other and sums rewards per agent. rewards = spatial( - utils.has_found_target, + utils.reward_if_found_target, reduction=jnp.add, default=0.0, i_range=self.target_contact_range, @@ -240,7 +245,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser # spatial maps the has_been_found function over all pair of targets and # searchers within range of each other targets_found = spatial( - utils.has_been_found, + utils.target_has_been_found, reduction=jnp.logical_or, default=False, i_range=self.target_contact_range, @@ -262,7 +267,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser ) observation = self._state_to_observation(state) timestep = jax.lax.cond( - state.step >= self.max_steps | jnp.all(targets_found), + jnp.logical_or(state.step >= self.max_steps, jnp.all(targets_found)), termination, transition, rewards, @@ -273,8 +278,8 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser def _state_to_observation(self, state: State) -> Observation: searcher_views = spatial( view, - reduction=jnp.minimum, - default=jnp.ones((self.num_vision,)), + reduction=view_reduction, + default=-jnp.ones((self.num_vision,)), include_self=False, i_range=self.searcher_vision_range, )( @@ -306,7 +311,7 @@ def observation_spec(self) -> specs.Spec[Observation]: """ searcher_views = specs.BoundedArray( shape=(self.generator.num_searchers, self.num_vision), - minimum=0.0, + minimum=-1.0, maximum=1.0, dtype=float, name="searcher_views", diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 92bb34144..edd78785c 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -32,34 +32,13 @@ from jumanji.testing.env_not_smoke import check_env_does_not_smoke, check_env_specs_does_not_smoke from jumanji.types import StepType, TimeStep -SEARCHER_VISION_RANGE = 0.2 -TARGET_CONTACT_RANGE = 0.05 -AGENT_RADIUS = 0.05 - - -@pytest.fixture -def env() -> SearchAndRescue: - return SearchAndRescue( - searcher_vision_range=SEARCHER_VISION_RANGE, - target_contact_range=TARGET_CONTACT_RANGE, - num_vision=11, - agent_radius=AGENT_RADIUS, - searcher_max_rotate=0.2, - searcher_max_accelerate=0.01, - searcher_min_speed=0.01, - searcher_max_speed=0.05, - searcher_view_angle=0.5, - max_steps=25, - ) - -def test_env_init(env: SearchAndRescue) -> None: +def test_env_init(env: 
SearchAndRescue, key: chex.PRNGKey) -> None: """ Check newly initialised state has expected array shapes and initial timestep. """ - k = jax.random.PRNGKey(101) - state, timestep = env.reset(k) + state, timestep = env.reset(key) assert isinstance(state, State) assert isinstance(state.searchers, AgentState) @@ -83,13 +62,12 @@ def test_env_init(env: SearchAndRescue) -> None: assert timestep.step_type == StepType.FIRST -def test_env_step(env: SearchAndRescue) -> None: +def test_env_step(env: SearchAndRescue, key: chex.PRNGKey) -> None: """ Run several steps of the environment with random actions and check states (i.e. positions, heading, speeds) all fall inside expected ranges. """ - key = jax.random.PRNGKey(101) n_steps = 22 def step( @@ -172,6 +150,7 @@ def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None: ) def test_searcher_view( env: SearchAndRescue, + key: chex.PRNGKey, searcher_positions: List[List[float]], searcher_headings: List[float], view_updates: List[Tuple[int, int, float]], @@ -190,14 +169,14 @@ def test_searcher_view( pos=searcher_positions, heading=searcher_headings, speed=searcher_speed ), targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)), - key=jax.random.PRNGKey(101), + key=key, ) obs = env._state_to_observation(state) assert isinstance(obs, Observation) - expected = jnp.ones((searcher_headings.shape[0], env.num_vision)) + expected = jnp.full((searcher_headings.shape[0], env.num_vision), -1.0) for i, idx, val in view_updates: expected = expected.at[i, idx].set(val) @@ -205,7 +184,7 @@ def test_searcher_view( assert jnp.all(jnp.isclose(obs.searcher_views, expected)) -def test_target_detection(env: SearchAndRescue) -> None: +def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: # Keep targets in one location env._target_dynamics = RandomWalk(step_size=0.0) @@ -215,7 +194,7 @@ def test_target_detection(env: SearchAndRescue) -> None: pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([jnp.pi]), speed=jnp.array([0.0]) ), targets=TargetState(pos=jnp.array([[0.54, 0.5]]), found=jnp.array([False])), - key=jax.random.PRNGKey(101), + key=key, ) state, timestep = env.step(state, jnp.zeros((1, 2))) assert not state.targets.found[0] diff --git a/jumanji/environments/swarms/search_and_rescue/test_generator.py b/jumanji/environments/swarms/search_and_rescue/test_generator.py new file mode 100644 index 000000000..de0a7dfc6 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/test_generator.py @@ -0,0 +1,42 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import chex +import jax.numpy as jnp + +from jumanji.environments.swarms.common.types import AgentParams +from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator +from jumanji.environments.swarms.search_and_rescue.types import State + + +def test_random_generator(key: chex.PRNGKey) -> None: + params = AgentParams( + max_rotate=0.5, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.5, + ) + generator = RandomGenerator(num_searchers=100, num_targets=101) + + assert isinstance(generator, Generator) + + state = generator(key, params) + + assert isinstance(state, State) + assert state.searchers.pos.shape == (generator.num_searchers, 2) + assert jnp.all(0.0 <= state.searchers.pos) and jnp.all(state.searchers.pos <= 1.0) + assert state.targets.pos.shape == (generator.num_targets, 2) + assert jnp.all(0.0 <= state.targets.pos) and jnp.all(state.targets.pos <= 1.0) + assert not jnp.any(state.targets.found) + assert state.step == 0 diff --git a/jumanji/environments/swarms/search_and_rescue/test_utils.py b/jumanji/environments/swarms/search_and_rescue/test_utils.py index c2f18b9f2..f5f419b52 100644 --- a/jumanji/environments/swarms/search_and_rescue/test_utils.py +++ b/jumanji/environments/swarms/search_and_rescue/test_utils.py @@ -14,44 +14,21 @@ from typing import List -import jax +import chex import jax.numpy as jnp import pytest -from jumanji.environments.swarms.common.types import AgentParams, AgentState +from jumanji.environments.swarms.common.types import AgentState from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics -from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator -from jumanji.environments.swarms.search_and_rescue.types import State, TargetState -from jumanji.environments.swarms.search_and_rescue.utils import has_been_found, has_found_target - - -def test_random_generator() -> None: - key = jax.random.PRNGKey(101) - params = AgentParams( - max_rotate=0.5, - max_accelerate=0.01, - min_speed=0.01, - max_speed=0.05, - view_angle=0.5, - ) - generator = RandomGenerator(num_searchers=100, num_targets=101) - - assert isinstance(generator, Generator) - - state = generator(key, params) - - assert isinstance(state, State) - assert state.searchers.pos.shape == (generator.num_searchers, 2) - assert jnp.all(0.0 <= state.searchers.pos) and jnp.all(state.searchers.pos <= 1.0) - assert state.targets.pos.shape == (generator.num_targets, 2) - assert jnp.all(0.0 <= state.targets.pos) and jnp.all(state.targets.pos <= 1.0) - assert not jnp.any(state.targets.found) - assert state.step == 0 +from jumanji.environments.swarms.search_and_rescue.types import TargetState +from jumanji.environments.swarms.search_and_rescue.utils import ( + reward_if_found_target, + target_has_been_found, +) -def test_random_walk_dynamics() -> None: +def test_random_walk_dynamics(key: chex.PRNGKey) -> None: n_targets = 50 - key = jax.random.PRNGKey(101) s0 = jnp.full((n_targets, 2), 0.5) dynamics = RandomWalk(0.1) @@ -94,8 +71,8 @@ def test_target_found( speed=0.0, ) - found = has_been_found(None, view_angle, target.pos, searcher) - reward = has_found_target(None, view_angle, searcher, target) + found = target_has_been_found(None, view_angle, target.pos, searcher) + reward = reward_if_found_target(None, view_angle, searcher, target) assert found == expected diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py index 
8505f58c4..cb924bd60 100644 --- a/jumanji/environments/swarms/search_and_rescue/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -25,8 +25,16 @@ @dataclass class TargetState: - pos: chex.Array - found: chex.Array + """ + The state for the rescue targets. + + pos: 2d position of the target agents + found: Boolean flag indicating if the + target has been located by a searcher. + """ + + pos: chex.Array # (num_targets, 2) + found: chex.Array # (num_targets,) @dataclass @@ -59,10 +67,10 @@ class Observation(NamedTuple): `num_vision = 5` then the observation array could be ``` - [1.0, 1.0, 0.5, 1.0, 1.0] + [-1.0, -1.0, 0.5, -1.0, -1.0] ``` - where `1.0` indicates there is no agents along that ray, + where `-1.0` indicates there is no agents along that ray, and `0.5` is the normalised distance to the other agent. """ diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py index 262ba569b..d359cc75b 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils.py +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -13,7 +13,6 @@ # limitations under the License. import chex -import jax import jax.numpy as jnp from esquilax.utils import shortest_vector @@ -21,7 +20,32 @@ from jumanji.environments.swarms.search_and_rescue.types import TargetState -def has_been_found( +def _check_target_in_view( + searcher_pos: chex.Array, + target_pos: chex.Array, + searcher_heading: chex.Array, + searcher_view_angle: float, +) -> chex.Array: + """ + Check if a target is inside the view-cone of a searcher. + + Args: + searcher_pos: Searcher position + target_pos: Target position + searcher_heading: Searcher heading angle + searcher_view_angle: Searcher view angle + + Returns: + bool: Flag indicating if a target is within view. + """ + dx = shortest_vector(searcher_pos, target_pos) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = shortest_vector(phi, searcher_heading, 2 * jnp.pi) + searcher_view_angle = searcher_view_angle * jnp.pi + return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + + +def target_has_been_found( _key: chex.PRNGKey, searcher_view_angle: float, target_pos: chex.Array, @@ -45,14 +69,10 @@ def has_been_found( Returns: is-found: `bool` True if the target had been found/detected. """ - dx = shortest_vector(searcher.pos, target_pos) - phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) - dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) - searcher_view_angle = searcher_view_angle * jnp.pi - return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) + return _check_target_in_view(searcher.pos, target_pos, searcher.heading, searcher_view_angle) -def has_found_target( +def reward_if_found_target( _key: chex.PRNGKey, searcher_view_angle: float, searcher: AgentState, @@ -76,13 +96,5 @@ def has_found_target( Returns: reward: +1.0 reward if the agent detects a new target. 
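The view-cone check factored out above reduces to testing that the bearing from searcher to target, taken on the wrapped square, lies within plus or minus `view_angle * pi` of the searcher heading. A standalone sketch of that geometry, using a local periodic-difference helper in place of `esquilax.utils.shortest_vector` (positions, headings and the 0.5 view angle are illustrative only):

```python
# Standalone sketch of the geometry behind `_check_target_in_view`, using a
# local periodic-difference helper rather than esquilax.utils.shortest_vector.
import jax.numpy as jnp


def periodic_vector(a, b, length=1.0):
    # Shortest vector from a to b on a square with wrapped boundaries.
    return (b - a + 0.5 * length) % length - 0.5 * length


def in_view_cone(searcher_pos, target_pos, heading, view_angle, env_size=1.0):
    dx = periodic_vector(searcher_pos, target_pos, env_size)
    phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi)
    # Signed angular difference between the bearing to the target and the heading
    dh = (heading - phi + jnp.pi) % (2 * jnp.pi) - jnp.pi
    half_cone = view_angle * jnp.pi
    return (dh >= -half_cone) & (dh <= half_cone)


# A searcher at the centre, heading pi (facing -x), sees a target just in
# front of it but not one directly behind it.
print(in_view_cone(jnp.array([0.5, 0.5]), jnp.array([0.46, 0.5]), jnp.pi, 0.5))  # True
print(in_view_cone(jnp.array([0.5, 0.5]), jnp.array([0.54, 0.5]), jnp.pi, 0.5))  # False
```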
""" - dx = shortest_vector(searcher.pos, target.pos) - phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) - dh = shortest_vector(phi, searcher.heading, 2 * jnp.pi) - searcher_view_angle = searcher_view_angle * jnp.pi - can_see = (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) - return jax.lax.cond( - ~target.found & can_see, - lambda: 1.0, - lambda: 0.0, - ) + can_see = _check_target_in_view(searcher.pos, target.pos, searcher.heading, searcher_view_angle) + return (~target.found & can_see).astype(float) From 162a74d87706c8b18de31640007f58c33707c0e0 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Tue, 19 Nov 2024 23:23:55 +0000 Subject: [PATCH 08/19] feat: Allow variable environment dimensions (#6) --- .../environments/swarms/common/test_common.py | 67 +++++++++++-------- jumanji/environments/swarms/common/updates.py | 19 ++++-- jumanji/environments/swarms/common/viewer.py | 9 ++- .../swarms/search_and_rescue/env.py | 18 ++++- .../swarms/search_and_rescue/env_test.py | 26 +++++-- .../swarms/search_and_rescue/generator.py | 12 +++- .../search_and_rescue/test_generator.py | 10 +-- .../swarms/search_and_rescue/test_utils.py | 27 ++++---- .../swarms/search_and_rescue/utils.py | 18 ++++- .../swarms/search_and_rescue/viewer.py | 8 ++- requirements/requirements.txt | 2 +- 11 files changed, 144 insertions(+), 72 deletions(-) diff --git a/jumanji/environments/swarms/common/test_common.py b/jumanji/environments/swarms/common/test_common.py index a3a20367c..cc3ebd67e 100644 --- a/jumanji/environments/swarms/common/test_common.py +++ b/jumanji/environments/swarms/common/test_common.py @@ -72,30 +72,38 @@ def test_velocity_update( @pytest.mark.parametrize( - "pos, heading, speed, expected", + "pos, heading, speed, expected, env_size", [ - [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5]], - [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5]], - [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1]], - [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9]], + [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5], 1.0], + [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5], 1.0], + [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1], 1.0], + [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9], 1.0], + [[0.4, 0.2], 0.0, 0.2, [0.1, 0.2], 0.5], + [[0.1, 0.2], jnp.pi, 0.2, [0.4, 0.2], 0.5], + [[0.2, 0.4], 0.5 * jnp.pi, 0.2, [0.2, 0.1], 0.5], + [[0.2, 0.1], 1.5 * jnp.pi, 0.2, [0.2, 0.4], 0.5], ], ) -def test_move(pos: List[float], heading: float, speed: float, expected: List[float]) -> None: +def test_move( + pos: List[float], heading: float, speed: float, expected: List[float], env_size: float +) -> None: pos = jnp.array(pos) - new_pos = updates.move(pos, heading, speed) + new_pos = updates.move(pos, heading, speed, env_size) assert jnp.allclose(new_pos, jnp.array(expected)) @pytest.mark.parametrize( - "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed", + "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed, env_size", [ - [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01], - [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01], - [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01], - [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02], - [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01], - [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05], + [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01, 1.0], + [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01, 1.0], + [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], 
[0.5, 0.99], 1.5 * jnp.pi, 0.01, 1.0], + [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02, 1.0], + [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01, 1.0], + [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05, 1.0], + [[0.495, 0.25], 0.0, 0.01, [0.0, 0.0], [0.005, 0.25], 0.0, 0.01, 0.5], + [[0.25, 0.005], 1.5 * jnp.pi, 0.01, [0.0, 0.0], [0.25, 0.495], 1.5 * jnp.pi, 0.01, 0.5], ], ) def test_state_update( @@ -107,6 +115,7 @@ def test_state_update( expected_pos: List[float], expected_heading: float, expected_speed: float, + env_size: float, ) -> None: key = jax.random.PRNGKey(101) @@ -117,7 +126,7 @@ def test_state_update( ) actions = jnp.array([actions]) - new_state = updates.update_state(key, params, state, actions) + new_state = updates.update_state(key, env_size, params, state, actions) assert isinstance(new_state, types.AgentState) assert jnp.allclose(new_state.pos, jnp.array([expected_pos])) @@ -133,19 +142,21 @@ def test_view_reduction() -> None: @pytest.mark.parametrize( - "pos, view_angle, expected", + "pos, view_angle, env_size, expected", [ - [[0.05, 0.0], 0.5, [-1.0, -1.0, 0.5, -1.0, -1.0]], - [[0.0, 0.05], 0.5, [0.5, -1.0, -1.0, -1.0, -1.0]], - [[0.0, 0.95], 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]], - [[0.95, 0.0], 0.5, [-1.0, -1.0, -1.0, -1.0, -1.0]], - [[0.05, 0.0], 0.25, [-1.0, -1.0, 0.5, -1.0, -1.0]], - [[0.0, 0.05], 0.25, [-1.0, -1.0, -1.0, -1.0, -1.0]], - [[0.0, 0.95], 0.25, [-1.0, -1.0, -1.0, -1.0, -1.0]], - [[0.01, 0.0], 0.5, [-1.0, -1.0, 0.1, -1.0, -1.0]], + [[0.05, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]], + [[0.0, 0.05], 0.5, 1.0, [0.5, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.95], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, 0.5]], + [[0.95, 0.0], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]], + [[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], + [[0.01, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.1, -1.0, -1.0]], + [[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]], + [[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]], ], ) -def test_view(pos: List[float], view_angle: float, expected: List[float]) -> None: +def test_view(pos: List[float], view_angle: float, env_size: float, expected: List[float]) -> None: state_a = types.AgentState( pos=jnp.zeros((2,)), heading=0.0, @@ -158,13 +169,15 @@ def test_view(pos: List[float], view_angle: float, expected: List[float]) -> Non speed=0.0, ) - obs = updates.view(None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1) + obs = updates.view( + None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size + ) assert jnp.allclose(obs, jnp.array(expected)) def test_viewer_utils() -> None: f, ax = plt.subplots() - f, ax = viewer.format_plot(f, ax) + f, ax = viewer.format_plot(f, ax, (1.0, 1.0)) assert isinstance(f, matplotlib.figure.Figure) assert isinstance(ax, matplotlib.axes.Axes) diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index 671a6fce6..fc7d3f5c1 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -54,24 +54,26 @@ def update_velocity( return new_heading, new_speeds -def move(pos: chex.Array, heading: chex.Array, speed: chex.Array) -> chex.Array: +def move(pos: chex.Array, heading: chex.Array, speed: chex.Array, env_size: float) -> chex.Array: """ Get updated agent positions from current speed and heading Args: - pos: Agent position + pos: 
Agent position. heading: Agent heading (angle). - speed: Agent speed + speed: Agent speed. + env_size: Size of the environment. Returns: - jax array (float32): Updated agent position + jax array (float32): Updated agent position. """ d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)]) - return (pos + d_pos) % 1.0 + return (pos + d_pos) % env_size def update_state( key: chex.PRNGKey, + env_size: float, params: types.AgentParams, state: types.AgentState, actions: chex.Array, @@ -81,6 +83,7 @@ def update_state( Args: key: Dummy JAX random key. + env_size: Size of the environment. params: Agent parameters. state: Current agent states. actions: Agent actions, i.e. a 2D array of action for each agent. @@ -91,7 +94,7 @@ def update_state( """ actions = jnp.clip(actions, min=-1.0, max=1.0) headings, speeds = update_velocity(key, params, (actions, state)) - positions = jax.vmap(move)(state.pos, headings, speeds) + positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size) return types.AgentState( pos=positions, @@ -133,6 +136,7 @@ def view( *, n_view: int, i_range: float, + env_size: float, ) -> chex.Array: """ Simple agent view model @@ -153,6 +157,7 @@ def view( n_view: Static number of view rays/subdivisions (i.e. how many cells the resulting array contains). i_range: Static agent view/interaction range. + env_size: Size of the environment. Returns: jax array (float32): 1D array representing the distance @@ -165,7 +170,7 @@ def view( n_view, endpoint=True, ) - dx = esquilax.utils.shortest_vector(viewing_agent.pos, viewed_agent.pos) + dx = esquilax.utils.shortest_vector(viewing_agent.pos, viewed_agent.pos, length=env_size) d = jnp.sqrt(jnp.sum(dx * dx)) / i_range phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) dh = esquilax.utils.shortest_vector(phi, viewing_agent.heading, 2 * jnp.pi) diff --git a/jumanji/environments/swarms/common/viewer.py b/jumanji/environments/swarms/common/viewer.py index 7a5f029b5..16bb197e6 100644 --- a/jumanji/environments/swarms/common/viewer.py +++ b/jumanji/environments/swarms/common/viewer.py @@ -45,12 +45,15 @@ def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver: return q -def format_plot(fig: Figure, ax: Axes, border: float = 0.01) -> Tuple[Figure, Axes]: +def format_plot( + fig: Figure, ax: Axes, env_dims: Tuple[float, float], border: float = 0.01 +) -> Tuple[Figure, Axes]: """Format a flock/swarm plot, remove ticks and bound to the unit interval Args: fig: Matplotlib figure. ax: Matplotlib axes. + env_dims: Environment dimensions (i.e. its boundaries). border: Border padding to apply around plot. 
Returns: @@ -67,7 +70,7 @@ def format_plot(fig: Figure, ax: Axes, border: float = 0.01) -> Tuple[Figure, Ax ) ax.set_xticks([]) ax.set_yticks([]) - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) + ax.set_xlim(0, env_dims[0]) + ax.set_ylim(0, env_dims[1]) return fig, ax diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 09392d71d..4bb9b35bd 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -170,9 +170,12 @@ def __init__( view_angle=searcher_view_angle, ) self.max_steps = max_steps - self._viewer = viewer or SearchAndRescueViewer() self._target_dynamics = target_dynamics or RandomWalk(0.01) self.generator = generator or RandomGenerator(num_targets=20, num_searchers=10) + self._viewer = viewer or SearchAndRescueViewer() + # Needed to set environment boundaries for plots + if isinstance(self._viewer, SearchAndRescueViewer): + self._viewer.env_size = (self.generator.env_size, self.generator.env_size) super().__init__() def __repr__(self) -> str: @@ -186,6 +189,7 @@ def __repr__(self) -> str: f" - num vision: {self.num_vision}", f" - agent radius: {self.agent_radius}", f" - max steps: {self.max_steps}," + f" - env size: {self.generator.env_size}" f" - target dynamics: {self._target_dynamics.__class__.__name__}", f" - generator: {self.generator.__class__.__name__}", ] @@ -222,9 +226,11 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser # Note: only one new key is needed for the targets, as all other # keys are just dummy values required by Esquilax key, target_key = jax.random.split(state.key, num=2) - searchers = update_state(key, self.searcher_params, state.searchers, actions) + searchers = update_state( + key, self.generator.env_size, self.searcher_params, state.searchers, actions + ) # Ensure target positions are wrapped - target_pos = self._target_dynamics(target_key, state.targets.pos) % 1.0 + target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size # Grant searchers rewards if in range and not already detected # spatial maps the has_found_target function over all pair of targets and # searchers within range of each other and sums rewards per agent. 
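For reference, the reward reduction performed by this `spatial` call can be written densely in plain JAX: each searcher accumulates a pairwise reward over targets within the contact range. The sketch below omits the view-cone and already-found checks for brevity, and its helper names are illustrative; the esquilax transform avoids the dense O(N x M) pairing by using a spatial hash.

```python
# Dense plain-JAX illustration of the per-searcher reward reduction in `step`.
import jax
import jax.numpy as jnp


def dense_rewards(searcher_pos, target_pos, i_range, env_size=1.0):
    def pair_reward(s_pos, t_pos):
        # Shortest displacement on the wrapped square
        dx = (t_pos - s_pos + 0.5 * env_size) % env_size - 0.5 * env_size
        in_range = jnp.sqrt(jnp.sum(dx * dx)) < i_range
        return jnp.where(in_range, 1.0, 0.0)

    per_pair = jax.vmap(
        lambda s: jax.vmap(pair_reward, in_axes=(None, 0))(s, target_pos)
    )(searcher_pos)
    # reduction=jnp.add with default=0.0 corresponds to a per-searcher sum
    return per_pair.sum(axis=1)


searchers = jnp.array([[0.5, 0.5], [0.1, 0.1]])
targets = jnp.array([[0.52, 0.5], [0.9, 0.9]])
print(dense_rewards(searchers, targets, i_range=0.05))  # [1. 0.]
```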
@@ -233,6 +239,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser reduction=jnp.add, default=0.0, i_range=self.target_contact_range, + dims=self.generator.env_size, )( key, self.searcher_params.view_angle, @@ -240,6 +247,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser state.targets, pos=searchers.pos, pos_b=target_pos, + env_size=self.generator.env_size, ) # Mark targets as found if with contact range and view angle of a searcher # spatial maps the has_been_found function over all pair of targets and @@ -249,6 +257,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser reduction=jnp.logical_or, default=False, i_range=self.target_contact_range, + dims=self.generator.env_size, )( key, self.searcher_params.view_angle, @@ -256,6 +265,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser searchers, pos=target_pos, pos_b=searchers.pos, + env_size=self.generator.env_size, ) # Targets need to remain found if they already have been targets_found = jnp.logical_or(targets_found, state.targets.found) @@ -282,6 +292,7 @@ def _state_to_observation(self, state: State) -> Observation: default=-jnp.ones((self.num_vision,)), include_self=False, i_range=self.searcher_vision_range, + dims=self.generator.env_size, )( state.key, (self.searcher_params.view_angle, self.agent_radius), @@ -290,6 +301,7 @@ def _state_to_observation(self, state: State) -> Observation: pos=state.searchers.pos, n_view=self.num_vision, i_range=self.searcher_vision_range, + env_size=self.generator.env_size, ) return Observation( diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index edd78785c..03d7f0eca 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -62,13 +62,15 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert timestep.step_type == StepType.FIRST -def test_env_step(env: SearchAndRescue, key: chex.PRNGKey) -> None: +@pytest.mark.parametrize("env_size", [1.0, 0.2]) +def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> None: """ Run several steps of the environment with random actions and check states (i.e. positions, heading, speeds) all fall inside expected ranges. 
""" n_steps = 22 + env.generator.env_size = env_size def step( carry: Tuple[chex.PRNGKey, State], _: None @@ -89,7 +91,7 @@ def step( assert isinstance(state_history, State) assert state_history.searchers.pos.shape == (n_steps, env.generator.num_searchers, 2) - assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= 1.0)) + assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= env_size)) assert state_history.searchers.speed.shape == (n_steps, env.generator.num_searchers) assert jnp.all( (env.searcher_params.min_speed <= state_history.searchers.speed) @@ -101,7 +103,7 @@ def step( ) assert state_history.targets.pos.shape == (n_steps, env.generator.num_targets, 2) - assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= 1.0)) + assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= env_size)) def test_env_does_not_smoke(env: SearchAndRescue) -> None: @@ -122,28 +124,38 @@ def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None: @pytest.mark.parametrize( - "searcher_positions, searcher_headings, view_updates", + "searcher_positions, searcher_headings, env_size, view_updates", [ # Both out of view range - ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], []), + ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), # Both view each other - ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], [(0, 5, 0.25), (1, 5, 0.25)]), + ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), # One facing wrong direction ( [[0.25, 0.5], [0.2, 0.5]], [jnp.pi, jnp.pi], + 1.0, [(0, 5, 0.25)], ), # Only see closest neighbour ( [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0, 0.0], + 1.0, [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], ), # Observed around wrapped edge ( [[0.025, 0.5], [0.975, 0.5]], [jnp.pi, 0.0], + 1.0, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [[0.025, 0.25], [0.475, 0.25]], + [jnp.pi, 0.0], + 0.5, [(0, 5, 0.25), (1, 5, 0.25)], ), ], @@ -153,12 +165,14 @@ def test_searcher_view( key: chex.PRNGKey, searcher_positions: List[List[float]], searcher_headings: List[float], + env_size: float, view_updates: List[Tuple[int, int, float]], ) -> None: """ Test view model generates expected array with different configurations of agents. """ + env.generator.env_size = env_size searcher_positions = jnp.array(searcher_positions) searcher_headings = jnp.array(searcher_headings) diff --git a/jumanji/environments/swarms/search_and_rescue/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py index 9de816086..3a3c85251 100644 --- a/jumanji/environments/swarms/search_and_rescue/generator.py +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -23,15 +23,17 @@ class Generator(abc.ABC): - def __init__(self, num_searchers: int, num_targets: int) -> None: + def __init__(self, num_searchers: int, num_targets: int, env_size: float = 1.0) -> None: """Interface for instance generation for the `SearchAndRescue` environment. Args: num_searchers: Number of searcher agents num_targets: Number of search targets + env_size: Size (dimensions of the environment), default 1.0. 
""" self.num_searchers = num_searchers self.num_targets = num_targets + self.env_size = env_size @abc.abstractmethod def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: @@ -60,7 +62,9 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: key, searcher_key, target_key = jax.random.split(key, num=3) k_pos, k_head, k_speed = jax.random.split(searcher_key, 3) - positions = jax.random.uniform(k_pos, (self.num_searchers, 2)) + positions = jax.random.uniform( + k_pos, (self.num_searchers, 2), minval=0.0, maxval=self.env_size + ) headings = jax.random.uniform( k_head, (self.num_searchers,), minval=0.0, maxval=2.0 * jnp.pi ) @@ -76,7 +80,9 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: heading=headings, ) - target_pos = jax.random.uniform(target_key, (self.num_targets, 2)) + target_pos = jax.random.uniform( + target_key, (self.num_targets, 2), minval=0.0, maxval=self.env_size + ) state = State( searchers=searcher_state, diff --git a/jumanji/environments/swarms/search_and_rescue/test_generator.py b/jumanji/environments/swarms/search_and_rescue/test_generator.py index de0a7dfc6..e6552359c 100644 --- a/jumanji/environments/swarms/search_and_rescue/test_generator.py +++ b/jumanji/environments/swarms/search_and_rescue/test_generator.py @@ -13,13 +13,15 @@ # limitations under the License. import chex import jax.numpy as jnp +import pytest from jumanji.environments.swarms.common.types import AgentParams from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator from jumanji.environments.swarms.search_and_rescue.types import State -def test_random_generator(key: chex.PRNGKey) -> None: +@pytest.mark.parametrize("env_size", [1.0, 0.5]) +def test_random_generator(key: chex.PRNGKey, env_size: float) -> None: params = AgentParams( max_rotate=0.5, max_accelerate=0.01, @@ -27,7 +29,7 @@ def test_random_generator(key: chex.PRNGKey) -> None: max_speed=0.05, view_angle=0.5, ) - generator = RandomGenerator(num_searchers=100, num_targets=101) + generator = RandomGenerator(num_searchers=100, num_targets=101, env_size=env_size) assert isinstance(generator, Generator) @@ -35,8 +37,8 @@ def test_random_generator(key: chex.PRNGKey) -> None: assert isinstance(state, State) assert state.searchers.pos.shape == (generator.num_searchers, 2) - assert jnp.all(0.0 <= state.searchers.pos) and jnp.all(state.searchers.pos <= 1.0) + assert jnp.all(0.0 <= state.searchers.pos) and jnp.all(state.searchers.pos <= env_size) assert state.targets.pos.shape == (generator.num_targets, 2) - assert jnp.all(0.0 <= state.targets.pos) and jnp.all(state.targets.pos <= 1.0) + assert jnp.all(0.0 <= state.targets.pos) and jnp.all(state.targets.pos <= env_size) assert not jnp.any(state.targets.found) assert state.step == 0 diff --git a/jumanji/environments/swarms/search_and_rescue/test_utils.py b/jumanji/environments/swarms/search_and_rescue/test_utils.py index f5f419b52..01ff5a9cf 100644 --- a/jumanji/environments/swarms/search_and_rescue/test_utils.py +++ b/jumanji/environments/swarms/search_and_rescue/test_utils.py @@ -40,17 +40,19 @@ def test_random_walk_dynamics(key: chex.PRNGKey) -> None: @pytest.mark.parametrize( - "pos, heading, view_angle, target_state, expected", + "pos, heading, view_angle, target_state, expected, env_size", [ - ([0.1, 0.0], 0.0, 0.5, False, False), - ([0.1, 0.0], jnp.pi, 0.5, False, True), - ([0.1, 0.0], jnp.pi, 0.5, True, True), - ([0.9, 0.0], jnp.pi, 0.5, False, False), - ([0.9, 0.0], 0.0, 0.5, False, 
True), - ([0.9, 0.0], 0.0, 0.5, True, True), - ([0.0, 0.1], 1.5 * jnp.pi, 0.5, True, True), - ([0.1, 0.0], 0.5 * jnp.pi, 0.5, False, True), - ([0.1, 0.0], 0.5 * jnp.pi, 0.4, False, False), + ([0.1, 0.0], 0.0, 0.5, False, False, 1.0), + ([0.1, 0.0], jnp.pi, 0.5, False, True, 1.0), + ([0.1, 0.0], jnp.pi, 0.5, True, True, 1.0), + ([0.9, 0.0], jnp.pi, 0.5, False, False, 1.0), + ([0.9, 0.0], 0.0, 0.5, False, True, 1.0), + ([0.9, 0.0], 0.0, 0.5, True, True, 1.0), + ([0.0, 0.1], 1.5 * jnp.pi, 0.5, True, True, 1.0), + ([0.1, 0.0], 0.5 * jnp.pi, 0.5, False, True, 1.0), + ([0.1, 0.0], 0.5 * jnp.pi, 0.4, False, False, 1.0), + ([0.4, 0.0], 0.0, 0.5, False, False, 1.0), + ([0.4, 0.0], 0.0, 0.5, False, True, 0.5), ], ) def test_target_found( @@ -59,6 +61,7 @@ def test_target_found( view_angle: float, target_state: bool, expected: bool, + env_size: float, ) -> None: target = TargetState( pos=jnp.zeros((2,)), @@ -71,8 +74,8 @@ def test_target_found( speed=0.0, ) - found = target_has_been_found(None, view_angle, target.pos, searcher) - reward = reward_if_found_target(None, view_angle, searcher, target) + found = target_has_been_found(None, view_angle, target.pos, searcher, env_size=env_size) + reward = reward_if_found_target(None, view_angle, searcher, target, env_size=env_size) assert found == expected diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py index d359cc75b..ea477c270 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils.py +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -25,6 +25,7 @@ def _check_target_in_view( target_pos: chex.Array, searcher_heading: chex.Array, searcher_view_angle: float, + env_size: float, ) -> chex.Array: """ Check if a target is inside the view-cone of a searcher. @@ -34,11 +35,12 @@ def _check_target_in_view( target_pos: Target position searcher_heading: Searcher heading angle searcher_view_angle: Searcher view angle + env_size: Size of the environment Returns: bool: Flag indicating if a target is within view. """ - dx = shortest_vector(searcher_pos, target_pos) + dx = shortest_vector(searcher_pos, target_pos, length=env_size) phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) dh = shortest_vector(phi, searcher_heading, 2 * jnp.pi) searcher_view_angle = searcher_view_angle * jnp.pi @@ -50,6 +52,8 @@ def target_has_been_found( searcher_view_angle: float, target_pos: chex.Array, searcher: AgentState, + *, + env_size: float, ) -> chex.Array: """ Returns True a target has been found. @@ -65,11 +69,14 @@ def target_has_been_found( target_pos: jax array (float) if shape (2,) representing the position of the target. searcher: Searcher agent state (i.e. position and heading). + env_size: size of the environment. Returns: is-found: `bool` True if the target had been found/detected. """ - return _check_target_in_view(searcher.pos, target_pos, searcher.heading, searcher_view_angle) + return _check_target_in_view( + searcher.pos, target_pos, searcher.heading, searcher_view_angle, env_size + ) def reward_if_found_target( @@ -77,6 +84,8 @@ def reward_if_found_target( searcher_view_angle: float, searcher: AgentState, target: TargetState, + *, + env_size: float, ) -> chex.Array: """ Return +1.0 reward if the agent has detected an agent. @@ -92,9 +101,12 @@ def reward_if_found_target( position and heading) target: State of the target (i.e. its position and search status). + env_size: size of the environment. Returns: reward: +1.0 reward if the agent detects a new target. 
""" - can_see = _check_target_in_view(searcher.pos, target.pos, searcher.heading, searcher_view_angle) + can_see = _check_target_in_view( + searcher.pos, target.pos, searcher.heading, searcher_view_angle, env_size + ) return (~target.found & can_see).astype(float) diff --git a/jumanji/environments/swarms/search_and_rescue/viewer.py b/jumanji/environments/swarms/search_and_rescue/viewer.py index 72470e687..5f2257292 100644 --- a/jumanji/environments/swarms/search_and_rescue/viewer.py +++ b/jumanji/environments/swarms/search_and_rescue/viewer.py @@ -27,7 +27,7 @@ from jumanji.viewer import Viewer -class SearchAndRescueViewer(Viewer): +class SearchAndRescueViewer(Viewer[State]): def __init__( self, figure_name: str = "SearchAndRescue", @@ -35,6 +35,7 @@ def __init__( searcher_color: str = "blue", target_found_color: str = "green", target_lost_color: str = "red", + env_size: Tuple[float, float] = (1.0, 1.0), ) -> None: """Viewer for the `SearchAndRescue` environment. @@ -47,6 +48,7 @@ def __init__( self.searcher_color = searcher_color self.target_colors = np.array([target_lost_color, target_found_color]) self._animation: Optional[matplotlib.animation.Animation] = None + self.env_size = env_size def render(self, state: State) -> None: """Render a frame of the environment for a given state using matplotlib. @@ -77,7 +79,7 @@ def animate( if not states: raise ValueError(f"The states argument has to be non-empty, got {states}.") fig, ax = plt.subplots(num=f"{self._figure_name}Anim", figsize=self._figure_size) - fig, ax = format_plot(fig, ax) + fig, ax = format_plot(fig, ax, self.env_size) searcher_quiver = draw_agents(ax, states[0].searchers, self.searcher_color) target_scatter = ax.scatter( @@ -137,7 +139,7 @@ def _get_fig_ax(self) -> Tuple[plt.Figure, plt.Axes]: fig.show() ax = fig.add_subplot() - fig, ax = format_plot(fig, ax) + fig, ax = format_plot(fig, ax, self.env_size) return fig, ax def _update_display(self, fig: plt.Figure) -> None: diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 2e4054474..0b6cdbd76 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ chex>=0.1.3 dm-env>=1.5 -esquilax>=1.0.2 +esquilax>=1.0.3 gym>=0.22.0 huggingface-hub jax>=0.2.26 From 6322f61ccb15d4f73bc9deba0beff7a2ab974725 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Sat, 23 Nov 2024 01:02:38 +0000 Subject: [PATCH 09/19] fix: Locate targets in single pass (#8) * Set plot range in viewer only * Detect targets in a single pass --- .../swarms/search_and_rescue/env.py | 39 +++-------- .../swarms/search_and_rescue/env_test.py | 68 ++++++++++++++++++- .../swarms/search_and_rescue/test_utils.py | 31 ++++----- .../swarms/search_and_rescue/utils.py | 52 ++++---------- .../swarms/search_and_rescue/viewer.py | 4 ++ 5 files changed, 109 insertions(+), 85 deletions(-) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 4bb9b35bd..a495f621e 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -173,9 +173,6 @@ def __init__( self._target_dynamics = target_dynamics or RandomWalk(0.01) self.generator = generator or RandomGenerator(num_targets=20, num_searchers=10) self._viewer = viewer or SearchAndRescueViewer() - # Needed to set environment boundaries for plots - if isinstance(self._viewer, SearchAndRescueViewer): - self._viewer.env_size = 
(self.generator.env_size, self.generator.env_size) super().__init__() def __repr__(self) -> str: @@ -231,44 +228,30 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser ) # Ensure target positions are wrapped target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size - # Grant searchers rewards if in range and not already detected - # spatial maps the has_found_target function over all pair of targets and - # searchers within range of each other and sums rewards per agent. - rewards = spatial( - utils.reward_if_found_target, + # Searchers return an array of flags of any targets they are in range of, + # and that have not already been locating, result shape here is (n-searcher, n-targets) + targets_found = spatial( + utils.searcher_detect_targets, reduction=jnp.add, - default=0.0, + default=jnp.zeros((target_pos.shape[0],), dtype=bool), i_range=self.target_contact_range, dims=self.generator.env_size, )( key, self.searcher_params.view_angle, searchers, - state.targets, + (jnp.arange(target_pos.shape[0]), state.targets), pos=searchers.pos, pos_b=target_pos, env_size=self.generator.env_size, + n_targets=target_pos.shape[0], ) - # Mark targets as found if with contact range and view angle of a searcher - # spatial maps the has_been_found function over all pair of targets and - # searchers within range of each other - targets_found = spatial( - utils.target_has_been_found, - reduction=jnp.logical_or, - default=False, - i_range=self.target_contact_range, - dims=self.generator.env_size, - )( - key, - self.searcher_params.view_angle, - state.targets.pos, - searchers, - pos=target_pos, - pos_b=searchers.pos, - env_size=self.generator.env_size, - ) + + rewards = jnp.sum(targets_found, axis=1) + targets_found = jnp.any(targets_found, axis=0) # Targets need to remain found if they already have been targets_found = jnp.logical_or(targets_found, state.targets.found) + state = State( searchers=searchers, targets=TargetState(pos=target_pos, found=targets_found), diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 03d7f0eca..00955eb70 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -21,7 +21,7 @@ import py import pytest -from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.common.types import AgentParams, AgentState from jumanji.environments.swarms.search_and_rescue import SearchAndRescue from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk from jumanji.environments.swarms.search_and_rescue.types import ( @@ -246,6 +246,72 @@ def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert timestep.reward[0] == 0 +def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: + # Keep targets in one location + env._target_dynamics = RandomWalk(step_size=0.0) + env.searcher_params = AgentParams( + max_rotate=0.1, + max_accelerate=0.01, + min_speed=0.01, + max_speed=0.05, + view_angle=0.25, + ) + + # Agent facing wrong direction should not see target + state = State( + searchers=AgentState( + pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([0.5 * jnp.pi]), speed=jnp.array([0.0]) + ), + targets=TargetState( + pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), found=jnp.array([False, False]) + ), + key=key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert not 
state.targets.found[0] + assert not state.targets.found[1] + assert timestep.reward[0] == 0 + + # Rotated agent should detect first target + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([0.0]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert not state.targets.found[1] + assert timestep.reward[0] == 1 + + # Rotated agent should not detect another agent + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([1.5 * jnp.pi]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert not state.targets.found[1] + assert timestep.reward[0] == 0 + + # Rotated agent again should see second agent + state = State( + searchers=AgentState( + pos=state.searchers.pos, heading=jnp.array([jnp.pi]), speed=state.searchers.speed + ), + targets=state.targets, + key=state.key, + ) + state, timestep = env.step(state, jnp.zeros((1, 2))) + assert state.targets.found[0] + assert state.targets.found[1] + assert timestep.reward[0] == 1 + + def test_search_and_rescue_render(monkeypatch: pytest.MonkeyPatch, env: SearchAndRescue) -> None: """Check that the render method builds the figure but does not display it.""" monkeypatch.setattr(plt, "show", lambda fig: None) diff --git a/jumanji/environments/swarms/search_and_rescue/test_utils.py b/jumanji/environments/swarms/search_and_rescue/test_utils.py index 01ff5a9cf..dfe81ae38 100644 --- a/jumanji/environments/swarms/search_and_rescue/test_utils.py +++ b/jumanji/environments/swarms/search_and_rescue/test_utils.py @@ -12,19 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
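The single-pass detection exercised by the test above produces a boolean matrix of shape `(num_searchers, num_targets)`; per-searcher rewards and per-target found flags then come from reductions over its two axes. A small sketch with made-up flag values:

```python
# Sketch of how the (num_searchers, num_targets) detection matrix from the
# single-pass update is reduced into rewards and target status. Flag values
# here are made up.
import jax.numpy as jnp

detections = jnp.array(
    [
        [True, False, False],  # searcher 0 detects target 0
        [True, False, False],  # searcher 1 also detects target 0
        [False, True, False],  # searcher 2 detects target 1
    ]
)
already_found = jnp.array([False, False, True])

rewards = jnp.sum(detections, axis=1)                                # per-searcher
found = jnp.logical_or(jnp.any(detections, axis=0), already_found)   # per-target
print(rewards)  # [1 1 1]
print(found)    # [ True  True  True]
```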
+from functools import partial from typing import List import chex +import jax import jax.numpy as jnp import pytest from jumanji.environments.swarms.common.types import AgentState from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.types import TargetState -from jumanji.environments.swarms.search_and_rescue.utils import ( - reward_if_found_target, - target_has_been_found, -) +from jumanji.environments.swarms.search_and_rescue.utils import searcher_detect_targets def test_random_walk_dynamics(key: chex.PRNGKey) -> None: @@ -44,11 +43,11 @@ def test_random_walk_dynamics(key: chex.PRNGKey) -> None: [ ([0.1, 0.0], 0.0, 0.5, False, False, 1.0), ([0.1, 0.0], jnp.pi, 0.5, False, True, 1.0), - ([0.1, 0.0], jnp.pi, 0.5, True, True, 1.0), + ([0.1, 0.0], jnp.pi, 0.5, True, False, 1.0), ([0.9, 0.0], jnp.pi, 0.5, False, False, 1.0), ([0.9, 0.0], 0.0, 0.5, False, True, 1.0), - ([0.9, 0.0], 0.0, 0.5, True, True, 1.0), - ([0.0, 0.1], 1.5 * jnp.pi, 0.5, True, True, 1.0), + ([0.9, 0.0], 0.0, 0.5, True, False, 1.0), + ([0.0, 0.1], 1.5 * jnp.pi, 0.5, True, False, 1.0), ([0.1, 0.0], 0.5 * jnp.pi, 0.5, False, True, 1.0), ([0.1, 0.0], 0.5 * jnp.pi, 0.4, False, False, 1.0), ([0.4, 0.0], 0.0, 0.5, False, False, 1.0), @@ -74,14 +73,12 @@ def test_target_found( speed=0.0, ) - found = target_has_been_found(None, view_angle, target.pos, searcher, env_size=env_size) - reward = reward_if_found_target(None, view_angle, searcher, target, env_size=env_size) - - assert found == expected + found = jax.jit(partial(searcher_detect_targets, env_size=env_size, n_targets=1))( + None, + view_angle, + searcher, + (jnp.arange(1), target), + ) - if found and target_state: - assert reward == 0.0 - elif found and not target_state: - assert reward == 1.0 - elif not found: - assert reward == 0.0 + assert found.shape == (1,) + assert found[0] == expected diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py index ea477c270..c6f581bdd 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils.py +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Tuple + import chex import jax.numpy as jnp from esquilax.utils import shortest_vector @@ -47,50 +49,19 @@ def _check_target_in_view( return (dh >= -searcher_view_angle) & (dh <= searcher_view_angle) -def target_has_been_found( - _key: chex.PRNGKey, - searcher_view_angle: float, - target_pos: chex.Array, - searcher: AgentState, - *, - env_size: float, -) -> chex.Array: - """ - Returns True a target has been found. - - Return true if a target is within detection range - and within the view cone of a searcher. Used - to mark targets as found. - - Args: - _key: Dummy random key (required by Esquilax). - searcher_view_angle: View angle of searching agents - representing a fraction of pi from the agents heading. - target_pos: jax array (float) if shape (2,) representing - the position of the target. - searcher: Searcher agent state (i.e. position and heading). - env_size: size of the environment. - - Returns: - is-found: `bool` True if the target had been found/detected. 
- """ - return _check_target_in_view( - searcher.pos, target_pos, searcher.heading, searcher_view_angle, env_size - ) - - -def reward_if_found_target( +def searcher_detect_targets( _key: chex.PRNGKey, searcher_view_angle: float, searcher: AgentState, - target: TargetState, + target: Tuple[chex.Array, TargetState], *, env_size: float, + n_targets: int, ) -> chex.Array: """ - Return +1.0 reward if the agent has detected an agent. + Return array of flags indicating if a target has been located - Generate rewards for agents if a target is inside the + Sets the flag at the target index if the target is within the searchers view cone, and had not already been detected. Args: @@ -99,14 +70,17 @@ def reward_if_found_target( representing a fraction of pi from the agents heading. searcher: State of the searching agent (i.e. the agent position and heading) - target: State of the target (i.e. its position and + target: Index and State of the target (i.e. its position and search status). env_size: size of the environment. + n_targets: Number of search targets (static). Returns: - reward: +1.0 reward if the agent detects a new target. + array of boolean flags, set if a target at the index has been found. """ + target_idx, target = target + target_found = jnp.zeros((n_targets,), dtype=bool) can_see = _check_target_in_view( searcher.pos, target.pos, searcher.heading, searcher_view_angle, env_size ) - return (~target.found & can_see).astype(float) + return target_found.at[target_idx].set(jnp.logical_and(~target.found, can_see)) diff --git a/jumanji/environments/swarms/search_and_rescue/viewer.py b/jumanji/environments/swarms/search_and_rescue/viewer.py index 5f2257292..9a48103f7 100644 --- a/jumanji/environments/swarms/search_and_rescue/viewer.py +++ b/jumanji/environments/swarms/search_and_rescue/viewer.py @@ -42,6 +42,10 @@ def __init__( Args: figure_name: the window name to be used when initialising the window. figure_size: tuple (height, width) of the matplotlib figure window. + searcher_color: Color of searcher agent markers (arrows). + target_found_color: Color of target markers when they have been found. + target_lost_color: Color of target markers when they are still to be found. + env_size: Tuple environment spatial dimensions, used to set the plot region. 
""" self._figure_name = figure_name self._figure_size = figure_size From 9a654b93caf7cb7a9026af281cfa7c795aa52e76 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Sat, 7 Dec 2024 16:29:47 +0000 Subject: [PATCH 10/19] feat: training and customisable observations (#7) * Prototype search-and-rescue network and training * Refactor and docstrings * Share rewards if found by multiple agents * Use Distrax for normal distribution * Add bijector for continuous action space * Reshape returned from actor-critic * Prototype tanh bijector w clipping * Fix random agent and cleanup * Customisable reward aggregation * Cleanup * Configurable vision model * Docstrings and cleanup params --- jumanji/__init__.py | 7 + .../environments/swarms/common/test_common.py | 2 +- jumanji/environments/swarms/common/updates.py | 51 +++- .../swarms/search_and_rescue/conftest.py | 3 - .../swarms/search_and_rescue/env.py | 148 +++++----- .../swarms/search_and_rescue/env_test.py | 79 +----- .../swarms/search_and_rescue/observations.py | 255 +++++++++++++++++ .../search_and_rescue/observations_test.py | 265 ++++++++++++++++++ .../swarms/search_and_rescue/reward.py | 63 +++++ .../swarms/search_and_rescue/reward_test.py | 33 +++ .../swarms/search_and_rescue/test_utils.py | 4 +- .../configs/env/search_and_rescue.yaml | 24 ++ jumanji/training/networks/__init__.py | 4 + jumanji/training/networks/distribution.py | 25 ++ .../networks/parametric_distribution.py | 67 ++++- jumanji/training/networks/postprocessor.py | 18 ++ .../networks/search_and_rescue/__init__.py | 13 + .../search_and_rescue/actor_critic.py | 104 +++++++ .../networks/search_and_rescue/random.py | 51 ++++ jumanji/training/setup_train.py | 14 +- pyproject.toml | 1 + requirements/requirements-train.txt | 1 + 22 files changed, 1052 insertions(+), 180 deletions(-) create mode 100644 jumanji/environments/swarms/search_and_rescue/observations.py create mode 100644 jumanji/environments/swarms/search_and_rescue/observations_test.py create mode 100644 jumanji/environments/swarms/search_and_rescue/reward.py create mode 100644 jumanji/environments/swarms/search_and_rescue/reward_test.py create mode 100644 jumanji/training/configs/env/search_and_rescue.yaml create mode 100644 jumanji/training/networks/search_and_rescue/__init__.py create mode 100644 jumanji/training/networks/search_and_rescue/actor_critic.py create mode 100644 jumanji/training/networks/search_and_rescue/random.py diff --git a/jumanji/__init__.py b/jumanji/__init__.py index 60ba1da88..ec88111c3 100644 --- a/jumanji/__init__.py +++ b/jumanji/__init__.py @@ -142,3 +142,10 @@ # LevelBasedForaging with a random generator with 8 grid size, # 2 agents and 2 food items and the maximum agent's level is 2. 
register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging") + +### +# Swarm Environments +### + +# Search-and-Rescue environment +register(id="SearchAndRescue-v0", entry_point="jumanji.environments:SearchAndRescue") diff --git a/jumanji/environments/swarms/common/test_common.py b/jumanji/environments/swarms/common/test_common.py index cc3ebd67e..a04449df5 100644 --- a/jumanji/environments/swarms/common/test_common.py +++ b/jumanji/environments/swarms/common/test_common.py @@ -151,7 +151,7 @@ def test_view_reduction() -> None: [[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]], [[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], [[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]], - [[0.01, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.1, -1.0, -1.0]], + [[0.01, 0.0], 0.5, 1.0, [-1.0, 0.1, 0.1, 0.1, -1.0]], [[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]], [[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]], ], diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index fc7d3f5c1..6df21755f 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -128,6 +128,38 @@ def view_reduction(view_a: chex.Array, view_b: chex.Array) -> chex.Array: ) +def angular_width( + viewing_pos: chex.Array, + viewed_pos: chex.Array, + viewing_heading: chex.Array, + i_range: float, + agent_radius: float, + env_size: float, +) -> Tuple[chex.Array, chex.Array, chex.Array]: + """ + Get the normalised distance, and left and right angles to another agent. + + Args: + viewing_pos: Co-ordinates of the viewing agent + viewed_pos: Co-ordinates of the viewed agent + viewing_heading: Heading of viewing agent + i_range: Interaction range + agent_radius: Agent visual radius + env_size: Environment size + + Returns: + Normalised distance between agents, and the left and right + angles to the edges of the agent. + """ + dx = esquilax.utils.shortest_vector(viewing_pos, viewed_pos, length=env_size) + dist = jnp.sqrt(jnp.sum(dx * dx)) + phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) + dh = esquilax.utils.shortest_vector(phi, viewing_heading, 2 * jnp.pi) + a_width = jnp.arctan2(agent_radius, dist) + norm_dist = dist / i_range + return norm_dist, dh - a_width, dh + a_width + + def view( _key: chex.PRNGKey, params: Tuple[float, float], @@ -163,21 +195,20 @@ def view( jax array (float32): 1D array representing the distance along a ray from the agent to another agent. 
""" - view_angle, radius = params + view_angle, agent_radius = params rays = jnp.linspace( -view_angle * jnp.pi, view_angle * jnp.pi, n_view, endpoint=True, ) - dx = esquilax.utils.shortest_vector(viewing_agent.pos, viewed_agent.pos, length=env_size) - d = jnp.sqrt(jnp.sum(dx * dx)) / i_range - phi = jnp.arctan2(dx[1], dx[0]) % (2 * jnp.pi) - dh = esquilax.utils.shortest_vector(phi, viewing_agent.heading, 2 * jnp.pi) - - angular_width = jnp.arctan2(radius, d) - left = dh - angular_width - right = dh + angular_width - + d, left, right = angular_width( + viewing_agent.pos, + viewed_agent.pos, + viewing_agent.heading, + i_range, + agent_radius, + env_size, + ) obs = jnp.where(jnp.logical_and(left < rays, rays < right), d, -1.0) return obs diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py index ddb82e2b8..70cb1b907 100644 --- a/jumanji/environments/swarms/search_and_rescue/conftest.py +++ b/jumanji/environments/swarms/search_and_rescue/conftest.py @@ -22,10 +22,7 @@ @pytest.fixture def env() -> SearchAndRescue: return SearchAndRescue( - searcher_vision_range=0.2, target_contact_range=0.05, - num_vision=11, - agent_radius=0.05, searcher_max_rotate=0.2, searcher_max_accelerate=0.01, searcher_min_speed=0.01, diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index a495f621e..267975854 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -24,10 +24,15 @@ from jumanji import specs from jumanji.env import Environment from jumanji.environments.swarms.common.types import AgentParams -from jumanji.environments.swarms.common.updates import update_state, view, view_reduction +from jumanji.environments.swarms.common.updates import update_state from jumanji.environments.swarms.search_and_rescue import utils from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator +from jumanji.environments.swarms.search_and_rescue.observations import ( + AgentAndTargetObservationFn, + ObservationFn, +) +from jumanji.environments.swarms.search_and_rescue.reward import RewardFn, SharedRewardFn from jumanji.environments.swarms.search_and_rescue.types import Observation, State, TargetState from jumanji.environments.swarms.search_and_rescue.viewer import SearchAndRescueViewer from jumanji.types import TimeStep, restart, termination, transition @@ -44,12 +49,16 @@ class SearchAndRescue(Environment): (i.e. the location of other agents) via a simple segmented view model. The environment consists of a uniform space with wrapped boundaries. + An episode will terminate if all targets have been located by the team of + searching agents. + - observation: `Observation` - searcher_views: jax array (float) of shape (num_searchers, num_vision) - individual local views of positions of other searching agents. - Each entry in the view indicates the distant to the nearest neighbour - along a ray from the agent, and is -1.0 if no agent is in range - along the ray. + searcher_views: jax array (float) of shape (num_searchers, channels, num_vision) + Individual local views of positions of other agents and targets, where + channels can be used to differentiate between agents and targets. 
+ Each entry in the view indicates the distant to another agent/target + along a ray from the agent, and is -1.0 if nothing is in range along the ray. + The view model can be customised using an `ObservationFn` implementation. targets_remaining: (float) Number of targets remaining to be found from the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates all the targets are still to be found). @@ -63,9 +72,11 @@ class SearchAndRescue(Environment): given parameters. - reward: jax array (float) of shape (num_searchers,) - Arrays of individual agent rewards. Rewards are granted when an agent + Arrays of individual agent rewards. A reward of +1 is granted when an agent comes into contact range with a target that has not yet been found, and - that agent is within the searchers view cone. + the target is within the searchers view cone. Rewards can be shared + between agents if a target is simultaneously detected by multiple agents, + or each can be provided the full reward individually. - state: `State` - searchers: `AgentState` @@ -81,20 +92,10 @@ class SearchAndRescue(Environment): - key: jax array (uint32) of shape (2,) - step: int representing the current simulation step. - ```python from jumanji.environments import SearchAndRescue - env = SearchAndRescue( - searcher_vision_range=0.1, - target_contact_range=0.01, - num_vision=40, - agent_radius0.01, - searcher_max_rotate=0.1, - searcher_max_accelerate=0.01, - searcher_min_speed=0.01, - searcher_max_speed=0.05, - searcher_view_angle=0.5, - ) + + env = SearchAndRescue() key = jax.random.PRNGKey(0) state, timestep = jax.jit(env.reset)(key) env.render(state) @@ -106,39 +107,23 @@ class SearchAndRescue(Environment): def __init__( self, - searcher_vision_range: float, - target_contact_range: float, - num_vision: int, - agent_radius: float, - searcher_max_rotate: float, - searcher_max_accelerate: float, - searcher_min_speed: float, - searcher_max_speed: float, - searcher_view_angle: float, + target_contact_range: float = 0.04, + searcher_max_rotate: float = 0.25, + searcher_max_accelerate: float = 0.005, + searcher_min_speed: float = 0.01, + searcher_max_speed: float = 0.02, + searcher_view_angle: float = 0.75, max_steps: int = 400, viewer: Optional[Viewer[State]] = None, target_dynamics: Optional[TargetDynamics] = None, generator: Optional[Generator] = None, + reward_fn: Optional[RewardFn] = None, + observation: Optional[ObservationFn] = None, ) -> None: """Instantiates a `SearchAndRescue` environment - Note: - The environment is square with dimensions - `[1.0, 1.0]` so parameters should be scaled - appropriately. Also note that performance is - dependent on agent vision and interaction ranges, - where larger values can lead to large number of - agent interactions. - Args: - searcher_vision_range: Search agent vision range. - target_contact_range: Range at which a searcher can 'find' a target. - num_vision: Number of cells/subdivisions in agent - view models. Larger numbers provide a more accurate - view, at the cost of the environment, at the cost - of performance and memory usage. - agent_radius: Radius of individual agents. This - effects how large they appear to other agents. + target_contact_range: Range at which a searchers will 'find' a target. searcher_max_rotate: Maximum rotation searcher agents can turn within a step. Should be a value from [0,1] representing a fraction of pi radians. @@ -146,10 +131,11 @@ def __init__( a searcher agent can apply within a step. searcher_min_speed: Minimum speed a searcher agent can move at. 
searcher_max_speed: Maximum speed a searcher agent can move at. - searcher_view_angle: Predator agent local view angle. Should be + searcher_view_angle: Searcher agent local view angle. Should be a value from [0,1] representing a fraction of pi radians. The view cone of an agent goes from +- of the view angle - relative to its heading. + relative to its heading, e.g. 0.5 would mean searchers have a + 90° view angle in total. max_steps: Maximum number of environment steps allowed for search. viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`. target_dynamics: @@ -157,11 +143,12 @@ def __init__( `TargetDynamics` interface. Defaults to `RandomWalk`. generator: Initial state `Generator` instance. Defaults to `RandomGenerator` with 20 targets and 10 searchers. + reward_fn: Reward aggregation function. Defaults to `SharedRewardFn` where + agents share rewards if they locate a target simultaneously. """ - self.searcher_vision_range = searcher_vision_range + # self.searcher_vision_range = searcher_vision_range self.target_contact_range = target_contact_range - self.num_vision = num_vision - self.agent_radius = agent_radius + self.searcher_params = AgentParams( max_rotate=searcher_max_rotate, max_accelerate=searcher_max_accelerate, @@ -170,9 +157,17 @@ def __init__( view_angle=searcher_view_angle, ) self.max_steps = max_steps - self._target_dynamics = target_dynamics or RandomWalk(0.01) - self.generator = generator or RandomGenerator(num_targets=20, num_searchers=10) + self._target_dynamics = target_dynamics or RandomWalk(0.001) + self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2) self._viewer = viewer or SearchAndRescueViewer() + self._reward_fn = reward_fn or SharedRewardFn() + self._observation = observation or AgentAndTargetObservationFn( + num_vision=64, + vision_range=0.1, + view_angle=searcher_view_angle, + agent_radius=0.01, + env_size=self.generator.env_size, + ) super().__init__() def __repr__(self) -> str: @@ -181,14 +176,16 @@ def __repr__(self) -> str: "Search & rescue multi-agent environment:", f" - num searchers: {self.generator.num_searchers}", f" - num targets: {self.generator.num_targets}", - f" - search vision range: {self.searcher_vision_range}", + f" - search vision range: {self._observation.vision_range}", f" - target contact range: {self.target_contact_range}", - f" - num vision: {self.num_vision}", - f" - agent radius: {self.agent_radius}", + f" - num vision: {self._observation.num_vision}", + f" - agent radius: {self._observation.agent_radius}", f" - max steps: {self.max_steps}," f" - env size: {self.generator.env_size}" f" - target dynamics: {self._target_dynamics.__class__.__name__}", f" - generator: {self.generator.__class__.__name__}", + f" - reward fn: {self._reward_fn.__class__.__name__}", + f" - observation fn: {self._observation.__class__.__name__}", ] ) @@ -229,25 +226,27 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser # Ensure target positions are wrapped target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size # Searchers return an array of flags of any targets they are in range of, - # and that have not already been locating, result shape here is (n-searcher, n-targets) + # and that have not already been located, result shape here is (n-searcher, n-targets) + n_targets = target_pos.shape[0] targets_found = spatial( utils.searcher_detect_targets, - reduction=jnp.add, - default=jnp.zeros((target_pos.shape[0],), dtype=bool), + 
reduction=jnp.logical_or, + default=jnp.zeros((n_targets,), dtype=bool), i_range=self.target_contact_range, dims=self.generator.env_size, )( key, self.searcher_params.view_angle, searchers, - (jnp.arange(target_pos.shape[0]), state.targets), + (jnp.arange(n_targets), state.targets), pos=searchers.pos, pos_b=target_pos, env_size=self.generator.env_size, - n_targets=target_pos.shape[0], + n_targets=n_targets, ) - rewards = jnp.sum(targets_found, axis=1) + rewards = self._reward_fn(targets_found) + targets_found = jnp.any(targets_found, axis=0) # Targets need to remain found if they already have been targets_found = jnp.logical_or(targets_found, state.targets.found) @@ -259,6 +258,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser step=state.step + 1, ) observation = self._state_to_observation(state) + observation = jax.lax.stop_gradient(observation) timestep = jax.lax.cond( jnp.logical_or(state.step >= self.max_steps, jnp.all(targets_found)), termination, @@ -269,24 +269,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser return state, timestep def _state_to_observation(self, state: State) -> Observation: - searcher_views = spatial( - view, - reduction=view_reduction, - default=-jnp.ones((self.num_vision,)), - include_self=False, - i_range=self.searcher_vision_range, - dims=self.generator.env_size, - )( - state.key, - (self.searcher_params.view_angle, self.agent_radius), - state.searchers, - state.searchers, - pos=state.searchers.pos, - n_view=self.num_vision, - i_range=self.searcher_vision_range, - env_size=self.generator.env_size, - ) - + searcher_views = self._observation(state) return Observation( searcher_views=searcher_views, targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, @@ -297,15 +280,14 @@ def _state_to_observation(self, state: State) -> Observation: def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. - Local searcher agent views representing - the distance to the closest neighbouring agents in the - environment. + Local searcher agent views representing the distance to the + closest neighbouring agents and targets in the environment. Returns: observation_spec: Search-and-rescue observation spec """ searcher_views = specs.BoundedArray( - shape=(self.generator.num_searchers, self.num_vision), + shape=(self.generator.num_searchers, *self._observation.view_shape), minimum=-1.0, maximum=1.0, dtype=float, diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 00955eb70..998045c9b 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
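The reworked `step` above wires the pluggable observation and reward functions into the environment loop. As a quick illustration (not part of the patch), the sketch below drives the environment with uniform random actions, assuming the default constructor arguments introduced above and the per-agent `[rotation, acceleration]` action layout described in the docstrings; shapes shown in the asserts follow the observation and reward descriptions earlier in this diff.

```python
# Minimal random-action rollout sketch; assumes default SearchAndRescue
# parameters and a (num_searchers, 2) action array as described above.
import jax

from jumanji.environments import SearchAndRescue

env = SearchAndRescue()
key = jax.random.PRNGKey(0)
state, timestep = jax.jit(env.reset)(key)

step_fn = jax.jit(env.step)
for _ in range(5):
    key, action_key = jax.random.split(key)
    # Per-agent [rotation, acceleration] values in [-1, 1]
    actions = jax.random.uniform(
        action_key, (env.generator.num_searchers, 2), minval=-1.0, maxval=1.0
    )
    state, timestep = step_fn(state, actions)
    # Individual rewards per searcher, views of shape (searchers, channels, vision)
    assert timestep.reward.shape == (env.generator.num_searchers,)
    assert timestep.observation.searcher_views.ndim == 3
```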
-from typing import List, Tuple +from typing import Tuple import chex import jax @@ -57,7 +57,7 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert isinstance(timestep.observation, Observation) assert timestep.observation.searcher_views.shape == ( env.generator.num_searchers, - env.num_vision, + *env._observation.view_shape, ) assert timestep.step_type == StepType.FIRST @@ -123,81 +123,6 @@ def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None: check_env_specs_does_not_smoke(env) -@pytest.mark.parametrize( - "searcher_positions, searcher_headings, env_size, view_updates", - [ - # Both out of view range - ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), - # Both view each other - ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), - # One facing wrong direction - ( - [[0.25, 0.5], [0.2, 0.5]], - [jnp.pi, jnp.pi], - 1.0, - [(0, 5, 0.25)], - ), - # Only see closest neighbour - ( - [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], - [jnp.pi, 0.0, 0.0], - 1.0, - [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], - ), - # Observed around wrapped edge - ( - [[0.025, 0.5], [0.975, 0.5]], - [jnp.pi, 0.0], - 1.0, - [(0, 5, 0.25), (1, 5, 0.25)], - ), - # Observed around wrapped edge of smaller env - ( - [[0.025, 0.25], [0.475, 0.25]], - [jnp.pi, 0.0], - 0.5, - [(0, 5, 0.25), (1, 5, 0.25)], - ), - ], -) -def test_searcher_view( - env: SearchAndRescue, - key: chex.PRNGKey, - searcher_positions: List[List[float]], - searcher_headings: List[float], - env_size: float, - view_updates: List[Tuple[int, int, float]], -) -> None: - """ - Test view model generates expected array with different - configurations of agents. - """ - env.generator.env_size = env_size - - searcher_positions = jnp.array(searcher_positions) - searcher_headings = jnp.array(searcher_headings) - searcher_speed = jnp.zeros(searcher_headings.shape) - - state = State( - searchers=AgentState( - pos=searcher_positions, heading=searcher_headings, speed=searcher_speed - ), - targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)), - key=key, - ) - - obs = env._state_to_observation(state) - - assert isinstance(obs, Observation) - - expected = jnp.full((searcher_headings.shape[0], env.num_vision), -1.0) - - for i, idx, val in view_updates: - expected = expected.at[i, idx].set(val) - - assert jnp.all(jnp.isclose(obs.searcher_views, expected)) - - def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: # Keep targets in one location env._target_dynamics = RandomWalk(step_size=0.0) diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py new file mode 100644 index 000000000..a23498ed5 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -0,0 +1,255 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import abc +from typing import Tuple + +import chex +import jax.numpy as jnp +from esquilax.transforms import spatial + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.common.updates import angular_width, view, view_reduction +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState + + +class ObservationFn(abc.ABC): + def __init__( + self, + view_shape: Tuple[int, ...], + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Base class for observation function mapping state to individual agent views. + + Args: + view_shape: Individual agent view shape. + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of pi). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + self.view_shape = view_shape + self.num_vision = num_vision + self.vision_range = vision_range + self.view_angle = view_angle + self.agent_radius = agent_radius + self.env_size = env_size + + @abc.abstractmethod + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views (n-agents, n-channels, n-vision). + """ + + +class AgentObservationFn(ObservationFn): + def __init__( + self, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Observation that only visualises other search agents in proximity. + + Args: + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of pi). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + super().__init__( + (1, num_vision), + num_vision, + vision_range, + view_angle, + agent_radius, + env_size, + ) + + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views of shape + (n-agents, 1, n-vision). + """ + searcher_views = spatial( + view, + reduction=view_reduction, + default=-jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + state.key, + (self.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + return searcher_views[:, jnp.newaxis] + + +def target_view( + _key: chex.PRNGKey, + params: Tuple[float, float], + searcher: AgentState, + target: TargetState, + *, + n_view: int, + i_range: float, + env_size: float, +) -> chex.Array: + """ + Return view of a target, dependent on target status. + + This function is intended to be mapped over agents-targets + by Esquilax. + + Args: + _key: Dummy random key (required by Esquilax). + params: View angle and target visual radius. + searcher: Searcher agent state + target: Target state + n_view: Number of value sin view array. + i_range: Vision range + env_size: Environment size + + Returns: + Segmented agent view of target. 
+ """ + view_angle, agent_radius = params + rays = jnp.linspace( + -view_angle * jnp.pi, + view_angle * jnp.pi, + n_view, + endpoint=True, + ) + d, left, right = angular_width( + searcher.pos, + target.pos, + searcher.heading, + i_range, + agent_radius, + env_size, + ) + checks = jnp.logical_and(target.found, jnp.logical_and(left < rays, rays < right)) + obs = jnp.where(checks, d, -1.0) + return obs + + +class AgentAndTargetObservationFn(ObservationFn): + def __init__( + self, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Vision model that contains other agents, and found targets. + + Searchers and targets are visualised as individual channels. + Targets are only included if they have been located already. + + Args: + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of pi). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + self.vision_range = vision_range + self.view_angle = view_angle + self.agent_radius = agent_radius + self.env_size = env_size + super().__init__( + (2, num_vision), + num_vision, + vision_range, + view_angle, + agent_radius, + env_size, + ) + + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views + """ + searcher_views = spatial( + view, + reduction=view_reduction, + default=-jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + state.key, + (self.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + target_views = spatial( + target_view, + reduction=view_reduction, + default=-jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + state.key, + (self.view_angle, self.agent_radius), + state.searchers, + state.targets, + pos=state.searchers.pos, + pos_b=state.targets.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + return jnp.hstack([searcher_views[:, jnp.newaxis], target_views[:, jnp.newaxis]]) diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py new file mode 100644 index 000000000..32e74512d --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -0,0 +1,265 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
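`AgentAndTargetObservationFn` builds its two-channel observation by giving each per-agent view array a channel axis and concatenating along it. A small standalone sketch of that stacking (not part of the diff), using dummy arrays in place of the Esquilax spatial reductions above:

```python
# Illustrative only: how two (n_agents, num_vision) view arrays become a
# (n_agents, 2, num_vision) observation via a new channel axis and hstack.
import jax.numpy as jnp

n_agents, num_vision = 4, 11
searcher_views = jnp.full((n_agents, num_vision), -1.0)  # channel 0: other searchers
target_views = jnp.full((n_agents, num_vision), -1.0)    # channel 1: located targets

obs = jnp.hstack([searcher_views[:, jnp.newaxis], target_views[:, jnp.newaxis]])
assert obs.shape == (n_agents, 2, num_vision)
```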
+ +from typing import List, Tuple + +import chex +import jax.numpy as jnp +import pytest + +from jumanji.environments.swarms.common.types import AgentState +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue, observations +from jumanji.environments.swarms.search_and_rescue.types import State, TargetState + +VISION_RANGE = 0.2 +VIEW_ANGLE = 0.5 + + +@pytest.mark.parametrize( + "searcher_positions, searcher_headings, env_size, view_updates", + [ + # Both out of view range + ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), + # Both view each other + ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), + # One facing wrong direction + ( + [[0.25, 0.5], [0.2, 0.5]], + [jnp.pi, jnp.pi], + 1.0, + [(0, 5, 0.25)], + ), + # Only see closest neighbour + ( + [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], + [jnp.pi, 0.0, 0.0], + 1.0, + [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], + ), + # Observed around wrapped edge + ( + [[0.025, 0.5], [0.975, 0.5]], + [jnp.pi, 0.0], + 1.0, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [[0.025, 0.25], [0.475, 0.25]], + [jnp.pi, 0.0], + 0.5, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + ], +) +def test_searcher_view( + key: chex.PRNGKey, + # env: SearchAndRescue, + searcher_positions: List[List[float]], + searcher_headings: List[float], + env_size: float, + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test agent-only view model generates expected array with different + configurations of agents. + """ + + searcher_positions = jnp.array(searcher_positions) + searcher_headings = jnp.array(searcher_headings) + searcher_speed = jnp.zeros(searcher_headings.shape) + + state = State( + searchers=AgentState( + pos=searcher_positions, heading=searcher_headings, speed=searcher_speed + ), + targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)), + key=key, + ) + + observe_fn = observations.AgentObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + + expected = jnp.full((searcher_headings.shape[0], 1, observe_fn.num_vision), -1.0) + + for i, idx, val in view_updates: + expected = expected.at[i, 0, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) + + +@pytest.mark.parametrize( + "searcher_positions, searcher_headings, env_size, view_updates", + [ + # Both out of view range + ([[0.8, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, []), + # Both view each other + ([[0.25, 0.5], [0.2, 0.5]], [jnp.pi, 0.0], 1.0, [(0, 5, 0.25), (1, 5, 0.25)]), + # One facing wrong direction + ( + [[0.25, 0.5], [0.2, 0.5]], + [jnp.pi, jnp.pi], + 1.0, + [(0, 5, 0.25)], + ), + # Only see closest neighbour + ( + [[0.35, 0.5], [0.25, 0.5], [0.2, 0.5]], + [jnp.pi, 0.0, 0.0], + 1.0, + [(0, 5, 0.5), (1, 5, 0.5), (2, 5, 0.25)], + ), + # Observed around wrapped edge + ( + [[0.025, 0.5], [0.975, 0.5]], + [jnp.pi, 0.0], + 1.0, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [[0.025, 0.25], [0.475, 0.25]], + [jnp.pi, 0.0], + 0.5, + [(0, 5, 0.25), (1, 5, 0.25)], + ), + ], +) +def test_search_and_target_view_searchers( + key: chex.PRNGKey, + searcher_positions: List[List[float]], + searcher_headings: List[float], + env_size: float, + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test agent+target view model generates expected array with different + configurations of agents only. 
+ """ + + n_agents = len(searcher_headings) + searcher_positions = jnp.array(searcher_positions) + searcher_headings = jnp.array(searcher_headings) + searcher_speed = jnp.zeros(searcher_headings.shape) + + state = State( + searchers=AgentState( + pos=searcher_positions, heading=searcher_headings, speed=searcher_speed + ), + targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool)), + key=key, + ) + + observe_fn = observations.AgentAndTargetObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + assert obs.shape == (n_agents, 2, observe_fn.num_vision) + + expected = jnp.full((n_agents, 2, observe_fn.num_vision), -1.0) + + for i, idx, val in view_updates: + expected = expected.at[i, 0, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) + + +@pytest.mark.parametrize( + "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates", + [ + # Target out of view range + ([0.8, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, []), + # Target in view and found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, [(5, 0.25)]), + # Target in view but not found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], False, 1.0, []), + # Observed around wrapped edge + ( + [0.025, 0.5], + jnp.pi, + [0.975, 0.5], + True, + 1.0, + [(5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [0.025, 0.25], + jnp.pi, + [0.475, 0.25], + True, + 0.5, + [(5, 0.25)], + ), + ], +) +def test_search_and_target_view_targets( + key: chex.PRNGKey, + env: SearchAndRescue, + searcher_position: List[float], + searcher_heading: float, + target_position: List[float], + target_found: bool, + env_size: float, + view_updates: List[Tuple[int, float]], +) -> None: + """ + Test agent+target view model generates expected array with different + configurations of targets only. + """ + + searcher_position = jnp.array([searcher_position]) + searcher_heading = jnp.array([searcher_heading]) + searcher_speed = jnp.zeros((1,)) + target_position = jnp.array([target_position]) + target_found = jnp.array([target_found]) + + state = State( + searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), + targets=TargetState( + pos=target_position, + found=target_found, + ), + key=key, + ) + + observe_fn = observations.AgentAndTargetObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + assert obs.shape == (1, 2, observe_fn.num_vision) + + expected = jnp.full((1, 2, observe_fn.num_vision), -1.0) + + for idx, val in view_updates: + expected = expected.at[0, 1, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py b/jumanji/environments/swarms/search_and_rescue/reward.py new file mode 100644 index 000000000..1217b86f7 --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -0,0 +1,63 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc + +import chex +import jax.numpy as jnp + + +class RewardFn(abc.ABC): + """Abstract class for `SearchAndRescue` rewards.""" + + @abc.abstractmethod + def __call__(self, found_targets: chex.Array) -> chex.Array: + """The reward function used in the `SearchAndRescue` environment. + + Args: + found_targets: Array of boolean flags indicating + + Returns: + Individual reward for each agent. + """ + + +class SharedRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets + + Targets detected by multiple agents share rewards. Agents + can receive rewards for detecting multiple targets. + """ + + def __call__(self, found_targets: chex.Array) -> chex.Array: + rewards = found_targets.astype(float) + norms = jnp.sum(rewards, axis=0)[jnp.newaxis] + rewards = jnp.where(norms > 0, rewards / norms, rewards) + rewards = jnp.sum(rewards, axis=1) + return rewards + + +class IndividualRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets + + Each agent that detects a target receives a +1 reward + even if a target is detected by multiple agents. + """ + + def __call__(self, found_targets: chex.Array) -> chex.Array: + rewards = found_targets.astype(float) + rewards = jnp.sum(rewards, axis=1) + return rewards diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py new file mode 100644 index 000000000..303b48b5c --- /dev/null +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -0,0 +1,33 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
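The `RewardFn` interface above makes reward aggregation pluggable alongside the built-in shared and individual variants. As a hedged illustration of extending it (not part of the patch), a custom implementation could soften the sharing penalty; the class name and square-root weighting below are hypothetical.

```python
# Hypothetical custom reward function: a sketch of the RewardFn interface,
# not an implementation included in this patch.
import chex
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.reward import RewardFn


class SqrtSharedRewardFn(RewardFn):
    """Share rewards between detecting agents, penalising sharing less harshly."""

    def __call__(self, found_targets: chex.Array) -> chex.Array:
        rewards = found_targets.astype(float)
        # Number of agents detecting each target, shape (1, n_targets)
        norms = jnp.sum(rewards, axis=0)[jnp.newaxis]
        # Divide by the square root of the detection count rather than the count
        rewards = jnp.where(norms > 0, rewards / jnp.sqrt(norms), rewards)
        return jnp.sum(rewards, axis=1)
```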
+ +import jax.numpy as jnp + +from jumanji.environments.swarms.search_and_rescue import reward + + +def test_rewards_from_found_targets() -> None: + targets_found = jnp.array([[False, True, True], [False, False, True]], dtype=bool) + + shared_rewards = reward.SharedRewardFn()(targets_found) + + assert shared_rewards.shape == (2,) + assert shared_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_rewards, jnp.array([1.5, 0.5])) + + individual_rewards = reward.IndividualRewardFn()(targets_found) + + assert individual_rewards.shape == (2,) + assert individual_rewards.dtype == jnp.float32 + assert jnp.allclose(individual_rewards, jnp.array([2.0, 1.0])) diff --git a/jumanji/environments/swarms/search_and_rescue/test_utils.py b/jumanji/environments/swarms/search_and_rescue/test_utils.py index dfe81ae38..0f7328c43 100644 --- a/jumanji/environments/swarms/search_and_rescue/test_utils.py +++ b/jumanji/environments/swarms/search_and_rescue/test_utils.py @@ -23,7 +23,9 @@ from jumanji.environments.swarms.common.types import AgentState from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.types import TargetState -from jumanji.environments.swarms.search_and_rescue.utils import searcher_detect_targets +from jumanji.environments.swarms.search_and_rescue.utils import ( + searcher_detect_targets, +) def test_random_walk_dynamics(key: chex.PRNGKey) -> None: diff --git a/jumanji/training/configs/env/search_and_rescue.yaml b/jumanji/training/configs/env/search_and_rescue.yaml new file mode 100644 index 000000000..11c5b1b16 --- /dev/null +++ b/jumanji/training/configs/env/search_and_rescue.yaml @@ -0,0 +1,24 @@ +name: search_and_rescue +registered_version: SearchAndRescue-v0 + +network: + layers: [128, 128] + +training: + num_epochs: 50 + num_learner_steps_per_epoch: 400 + n_steps: 20 + total_batch_size: 128 + +evaluation: + eval_total_batch_size: 2000 + greedy_eval_total_batch_size: 2000 + +a2c: + normalize_advantage: False + discount_factor: 0.997 + bootstrapping_factor: 0.95 + l_pg: 1.0 + l_td: 1.0 + l_en: 0.01 + learning_rate: 3e-4 diff --git a/jumanji/training/networks/__init__.py b/jumanji/training/networks/__init__.py index 587941bd1..999ec42c8 100644 --- a/jumanji/training/networks/__init__.py +++ b/jumanji/training/networks/__init__.py @@ -80,6 +80,10 @@ make_actor_critic_networks_rubiks_cube, ) from jumanji.training.networks.rubiks_cube.random import make_random_policy_rubiks_cube +from jumanji.training.networks.search_and_rescue.actor_critic import ( + make_actor_critic_search_and_rescue, +) +from jumanji.training.networks.search_and_rescue.random import make_random_policy_search_and_rescue from jumanji.training.networks.sliding_tile_puzzle.actor_critic import ( make_actor_critic_networks_sliding_tile_puzzle, ) diff --git a/jumanji/training/networks/distribution.py b/jumanji/training/networks/distribution.py index 03262136d..cce681314 100644 --- a/jumanji/training/networks/distribution.py +++ b/jumanji/training/networks/distribution.py @@ -21,6 +21,7 @@ import chex import jax import jax.numpy as jnp +from distrax import Normal class Distribution(abc.ABC): @@ -85,3 +86,27 @@ def kl_divergence( # type: ignore[override] probs = jax.nn.softmax(self.logits) log_probs_other = jax.nn.log_softmax(other.logits) return jnp.sum(jnp.where(probs == 0, 0.0, probs * (log_probs - log_probs_other)), axis=-1) + + +class NormalDistribution(Distribution): + """Normal distribution (with log standard deviations).""" + + def 
__init__(self, means: chex.Array, log_stds: chex.Array): + self.dist = Normal(loc=means, scale=jnp.exp(log_stds)) + + def mode(self) -> chex.Array: + return self.dist.mode() + + def log_prob(self, x: chex.Array) -> chex.Array: + return self.dist.log_prob(x) + + def entropy(self) -> chex.Array: + return self.dist.entropy() + + def kl_divergence( # type: ignore[override] + self, other: NormalDistribution + ) -> chex.Array: + return self.dist.kl_divergence(other) + + def sample(self, seed: chex.PRNGKey) -> chex.Array: + return self.dist.sample(seed=seed) diff --git a/jumanji/training/networks/parametric_distribution.py b/jumanji/training/networks/parametric_distribution.py index 325c17e9b..e153128a0 100644 --- a/jumanji/training/networks/parametric_distribution.py +++ b/jumanji/training/networks/parametric_distribution.py @@ -15,17 +15,22 @@ """Adapted from Brax.""" import abc -from typing import Any +from typing import Any, Tuple import chex import jax.numpy as jnp import numpy as np -from jumanji.training.networks.distribution import CategoricalDistribution, Distribution +from jumanji.training.networks.distribution import ( + CategoricalDistribution, + Distribution, + NormalDistribution, +) from jumanji.training.networks.postprocessor import ( FactorisedActionSpaceReshapeBijector, IdentityBijector, Postprocessor, + TanhBijector, ) @@ -166,10 +171,64 @@ def __init__(self, action_spec_num_values: chex.ArrayNumpy): Args: action_spec_num_values: the dimensions of each of the factors in the action space""" num_actions = int(np.prod(action_spec_num_values)) - posprocessor = FactorisedActionSpaceReshapeBijector( + postprocessor = FactorisedActionSpaceReshapeBijector( action_spec_num_values=action_spec_num_values ) - super().__init__(param_size=num_actions, postprocessor=posprocessor, event_ndims=0) + super().__init__(param_size=num_actions, postprocessor=postprocessor, event_ndims=0) def create_dist(self, parameters: chex.Array) -> CategoricalDistribution: return CategoricalDistribution(logits=parameters) + + +class ContinuousActionSpaceNormalTanhDistribution(ParametricDistribution): + """Normal distribution for continuous action spaces""" + + def __init__(self, n_actions: int, threshold: float = 0.999): + super().__init__( + param_size=n_actions, + postprocessor=TanhBijector(), + event_ndims=1, + ) + self._inverse_threshold = self._postprocessor.inverse(threshold) + self._log_epsilon = jnp.log(1.0 - threshold) + + def create_dist(self, parameters: Tuple[chex.Array, chex.Array]) -> NormalDistribution: + return NormalDistribution(means=parameters[0], log_stds=parameters[1]) + + def log_prob( + self, parameters: Tuple[chex.Array, chex.Array], raw_actions: chex.Array + ) -> chex.Array: + """Compute the log probability of raw actions when transformed""" + dist = self.create_dist(parameters) + + log_prob_left = dist.dist.log_cdf(-self._inverse_threshold) - self._log_epsilon + log_prob_right = ( + dist.dist.log_survival_function(self._inverse_threshold) - self._log_epsilon + ) + + clipped_actions = jnp.clip(raw_actions, -self._inverse_threshold, self._inverse_threshold) + raw_log_probs = dist.log_prob(clipped_actions) + raw_log_probs -= self._postprocessor.forward_log_det_jacobian(clipped_actions) + + log_probs = jnp.where( + raw_actions <= -self._inverse_threshold, + log_prob_left, + jnp.where( + raw_actions >= self._inverse_threshold, + log_prob_right, + raw_log_probs, + ), + ) + # Sum over non-batch axes + log_probs = jnp.sum(log_probs, axis=tuple(range(1, log_probs.ndim))) + + return log_probs + + 
def entropy(self, parameters: Tuple[chex.Array, chex.Array], seed: chex.PRNGKey) -> chex.Array: + """Return the entropy of the given distribution.""" + dist = self.create_dist(parameters) + entropy = dist.entropy() + entropy += self._postprocessor.forward_log_det_jacobian(dist.sample(seed=seed)) + # Sum over non-batch axes + entropy = jnp.sum(entropy, axis=tuple(range(1, entropy.ndim))) + return entropy diff --git a/jumanji/training/networks/postprocessor.py b/jumanji/training/networks/postprocessor.py index e97fe3a19..20ba59421 100644 --- a/jumanji/training/networks/postprocessor.py +++ b/jumanji/training/networks/postprocessor.py @@ -17,6 +17,7 @@ import abc import chex +import distrax import jax.numpy as jnp @@ -47,6 +48,23 @@ def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: return jnp.zeros_like(x, x.dtype) +class TanhBijector(Postprocessor): + """Tanh Bijector for continuous actions.""" + + def __init__(self) -> None: + super().__init__() + self._tanh = distrax.Tanh() + + def forward(self, x: chex.Array) -> chex.Array: + return self._tanh.forward(x) + + def inverse(self, y: chex.Array) -> chex.Array: + return self._tanh.inverse(y) + + def forward_log_det_jacobian(self, x: chex.Array) -> chex.Array: + return self._tanh.forward_log_det_jacobian(x) + + class FactorisedActionSpaceReshapeBijector(Postprocessor): """Identity bijector that reshapes (flattens and unflattens) a sequential action.""" diff --git a/jumanji/training/networks/search_and_rescue/__init__.py b/jumanji/training/networks/search_and_rescue/__init__.py new file mode 100644 index 000000000..21db9ec1c --- /dev/null +++ b/jumanji/training/networks/search_and_rescue/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/jumanji/training/networks/search_and_rescue/actor_critic.py b/jumanji/training/networks/search_and_rescue/actor_critic.py new file mode 100644 index 000000000..092cd3508 --- /dev/null +++ b/jumanji/training/networks/search_and_rescue/actor_critic.py @@ -0,0 +1,104 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
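The `TanhBijector` and the Normal-Tanh parametric distribution above implement a squashed-Gaussian policy: raw actions are sampled from a normal distribution, squashed through `tanh`, and their log-probabilities are corrected by the bijector's forward log-det-Jacobian. A standalone sketch of that correction written directly against distrax (assumed available via the new training requirement), not part of the patch:

```python
# Change-of-variables correction used by the squashed-Gaussian policy:
# log p(action) = log p(raw_action) - log |d tanh(raw_action) / d raw_action|
import distrax
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
dist = distrax.Normal(loc=jnp.zeros(2), scale=jnp.ones(2))
tanh = distrax.Tanh()

raw_action = dist.sample(seed=key)   # unbounded sample
action = tanh.forward(raw_action)    # squashed into (-1, 1)

log_prob = dist.log_prob(raw_action) - tanh.forward_log_det_jacobian(raw_action)
```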
+from math import prod +from typing import Sequence, Tuple, Union + +import chex +import haiku as hk +import jax.numpy as jnp + +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue.types import Observation +from jumanji.training.networks.actor_critic import ( + ActorCriticNetworks, + FeedForwardNetwork, +) +from jumanji.training.networks.parametric_distribution import ( + ContinuousActionSpaceNormalTanhDistribution, +) + + +def make_actor_critic_search_and_rescue( + search_and_rescue: SearchAndRescue, + layers: Sequence[int], +) -> ActorCriticNetworks: + """ + Initialise networks for the search-and-rescue network. + + Note: This network is intended to accept environment observations + with agent views of shape [n-agents, n-view], but then + Returns a flattened array of actions for each agent (these + are reshaped by the wrapped environment). + + Args: + search_and_rescue: `SearchAndRescue` environment. + layers: List of hidden layer dimensions. + + Returns: + Continuous action space MLP action and critic networks. + """ + n_actions = prod(search_and_rescue.action_spec.shape) + parametric_action_distribution = ContinuousActionSpaceNormalTanhDistribution(n_actions) + policy_network = make_actor_network(layers=layers, n_actions=n_actions) + value_network = make_critic_network(layers=layers) + + return ActorCriticNetworks( + policy_network=policy_network, + value_network=value_network, + parametric_action_distribution=parametric_action_distribution, + ) + + +def make_critic_network(layers: Sequence[int]) -> FeedForwardNetwork: + # Shape names: + # B: batch size + # N: number of agents + # O: observation size + + def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]: + views = observation.searcher_views # (B, N, O) + batch_size = views.shape[0] + views = views.reshape(batch_size, -1) # (B, N * O) + value = hk.nets.MLP([*layers, 1])(views) # (B,) + return jnp.squeeze(value, axis=-1) + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) + + +def make_actor_network(layers: Sequence[int], n_actions: int) -> FeedForwardNetwork: + # Shape names: + # B: batch size + # N: number of agents + # O: observation size + # A: Number of actions + + def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]: + views = observation.searcher_views # (B, N, O) + batch_size = views.shape[0] + n_agents = views.shape[1] + views = views.reshape((batch_size, -1)) # (B, N * 0) + means = hk.nets.MLP([*layers, n_agents * n_actions])(views) # (B, N * A) + means = means.reshape(batch_size, n_agents, n_actions) # (B, N, A) + + log_stds = hk.get_parameter( + "log_stds", shape=(n_agents * n_actions,), init=hk.initializers.Constant(0.1) + ) # (N * A,) + log_stds = jnp.broadcast_to(log_stds, (batch_size, n_agents * n_actions)) # (B, N * A) + log_stds = log_stds.reshape(batch_size, n_agents, n_actions) # (B, N, A) + + return means, log_stds + + init, apply = hk.without_apply_rng(hk.transform(network_fn)) + return FeedForwardNetwork(init=init, apply=apply) diff --git a/jumanji/training/networks/search_and_rescue/random.py b/jumanji/training/networks/search_and_rescue/random.py new file mode 100644 index 000000000..8f5741c0b --- /dev/null +++ b/jumanji/training/networks/search_and_rescue/random.py @@ -0,0 +1,51 @@ +# Copyright 2022 InstaDeep Ltd. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import chex +import jax.random + +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue.types import Observation +from jumanji.training.networks.protocols import RandomPolicy + + +class SearchAndRescueRandomPolicy(RandomPolicy): + def __init__(self, n_actions: int): + self.n_actions = n_actions + + def __call__( + self, + observation: Observation, + key: chex.PRNGKey, + ) -> chex.Array: + """A random policy given an environment-specific observation. + + Args: + observation: environment observation. + key: random key for action selection. + + Returns: + action + """ + shape = ( + observation.searcher_views.shape[0], + observation.searcher_views.shape[1], + self.n_actions, + ) + return jax.random.uniform(key, shape, minval=-1.0, maxval=1.0) + + +def make_random_policy_search_and_rescue(search_and_rescue: SearchAndRescue) -> RandomPolicy: + """Make random policy for Search & Rescue.""" + return SearchAndRescueRandomPolicy(search_and_rescue.action_spec.shape[-1]) diff --git a/jumanji/training/setup_train.py b/jumanji/training/setup_train.py index 7daa0201d..b5fbdca1c 100644 --- a/jumanji/training/setup_train.py +++ b/jumanji/training/setup_train.py @@ -41,6 +41,7 @@ PacMan, RobotWarehouse, RubiksCube, + SearchAndRescue, SlidingTilePuzzle, Snake, Sokoban, @@ -90,7 +91,7 @@ def setup_logger(cfg: DictConfig) -> Logger: def _make_raw_env(cfg: DictConfig) -> Environment: env = jumanji.make(cfg.env.registered_version) - if cfg.env.name in {"lbf", "connector"}: + if cfg.env.name in {"lbf", "connector", "search_and_rescue"}: # Convert a multi-agent environment to a single-agent environment env = MultiToSingleWrapper(env) return env @@ -206,6 +207,11 @@ def _setup_random_policy(cfg: DictConfig, env: Environment) -> RandomPolicy: elif cfg.env.name == "lbf": assert isinstance(env.unwrapped, LevelBasedForaging) random_policy = networks.make_random_policy_lbf() + elif cfg.env.name == "search_and_rescue": + assert isinstance(env.unwrapped, SearchAndRescue) + random_policy = networks.make_random_policy_search_and_rescue( + search_and_rescue=env.unwrapped + ) else: raise ValueError(f"Environment name not found. Got {cfg.env.name}.") return random_policy @@ -421,6 +427,12 @@ def _setup_actor_critic_neworks(cfg: DictConfig, env: Environment) -> ActorCriti transformer_key_size=cfg.env.network.transformer_key_size, transformer_mlp_units=cfg.env.network.transformer_mlp_units, ) + elif cfg.env.name == "search_and_rescue": + assert isinstance(env.unwrapped, SearchAndRescue) + actor_critic_networks = networks.make_actor_critic_search_and_rescue( + search_and_rescue=env.unwrapped, + layers=cfg.env.network.layers, + ) else: raise ValueError(f"Environment name not found. 
Got {cfg.env.name}.") return actor_critic_networks diff --git a/pyproject.toml b/pyproject.toml index ea4fabe5c..f637dca31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ module = [ "requests.*", "pkg_resources.*", "PIL.*", + "distrax.*", ] ignore_missing_imports = true diff --git a/requirements/requirements-train.txt b/requirements/requirements-train.txt index 890d0f983..55c0032e0 100644 --- a/requirements/requirements-train.txt +++ b/requirements/requirements-train.txt @@ -1,3 +1,4 @@ +distrax>=0.1.5 dm-haiku hydra-core==1.3 neptune-client==0.16.15 From 5021e2070250373b4289f9e61a97a63e817776e5 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Mon, 9 Dec 2024 11:10:41 +0000 Subject: [PATCH 11/19] feat: view all targets (#9) * Add observation including all targets * Consistent test module names * Use CNN embedding --- .../common/{test_common.py => common_test.py} | 0 .../{test_generator.py => generator_test.py} | 0 .../swarms/search_and_rescue/observations.py | 138 +++++++++++++++++- .../search_and_rescue/observations_test.py | 87 +++++++++++ .../{test_utils.py => utils_test.py} | 0 .../search_and_rescue/actor_critic.py | 48 +++--- 6 files changed, 253 insertions(+), 20 deletions(-) rename jumanji/environments/swarms/common/{test_common.py => common_test.py} (100%) rename jumanji/environments/swarms/search_and_rescue/{test_generator.py => generator_test.py} (100%) rename jumanji/environments/swarms/search_and_rescue/{test_utils.py => utils_test.py} (100%) diff --git a/jumanji/environments/swarms/common/test_common.py b/jumanji/environments/swarms/common/common_test.py similarity index 100% rename from jumanji/environments/swarms/common/test_common.py rename to jumanji/environments/swarms/common/common_test.py diff --git a/jumanji/environments/swarms/search_and_rescue/test_generator.py b/jumanji/environments/swarms/search_and_rescue/generator_test.py similarity index 100% rename from jumanji/environments/swarms/search_and_rescue/test_generator.py rename to jumanji/environments/swarms/search_and_rescue/generator_test.py diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index a23498ed5..f3f917e59 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -124,7 +124,7 @@ def __call__(self, state: State) -> chex.Array: return searcher_views[:, jnp.newaxis] -def target_view( +def found_target_view( _key: chex.PRNGKey, params: Tuple[float, float], searcher: AgentState, @@ -235,7 +235,7 @@ def __call__(self, state: State) -> chex.Array: env_size=self.env_size, ) target_views = spatial( - target_view, + found_target_view, reduction=view_reduction, default=-jnp.ones((self.num_vision,)), include_self=False, @@ -253,3 +253,137 @@ def __call__(self, state: State) -> chex.Array: env_size=self.env_size, ) return jnp.hstack([searcher_views[:, jnp.newaxis], target_views[:, jnp.newaxis]]) + + +def all_target_view( + _key: chex.PRNGKey, + params: Tuple[float, float], + searcher: AgentState, + target: TargetState, + *, + n_view: int, + i_range: float, + env_size: float, +) -> chex.Array: + """ + Return view of a target, dependent on target status. + + This function is intended to be mapped over agents-targets + by Esquilax. + + Args: + _key: Dummy random key (required by Esquilax). + params: View angle and target visual radius. 
+ searcher: Searcher agent state + target: Target state + n_view: Number of value sin view array. + i_range: Vision range + env_size: Environment size + + Returns: + Segmented agent view of target. + """ + view_angle, agent_radius = params + rays = jnp.linspace( + -view_angle * jnp.pi, + view_angle * jnp.pi, + n_view, + endpoint=True, + ) + d, left, right = angular_width( + searcher.pos, + target.pos, + searcher.heading, + i_range, + agent_radius, + env_size, + ) + ray_checks = jnp.logical_and(left < rays, rays < right) + checks_a = jnp.logical_and(target.found, ray_checks) + checks_b = jnp.logical_and(~target.found, ray_checks) + obs = [jnp.where(checks_a, d, -1.0), jnp.where(checks_b, d, -1.0)] + obs = jnp.vstack(obs) + return obs + + +class AgentAndAllTargetObservationFn(ObservationFn): + def __init__( + self, + num_vision: int, + vision_range: float, + view_angle: float, + agent_radius: float, + env_size: float, + ) -> None: + """ + Vision model that contains other agents, and found targets. + + Searchers and targets are visualised as individual channels. + Targets are only included if they have been located already. + + Args: + num_vision: Size of vision array. + vision_range: Vision range. + view_angle: Agent view angle (as a fraction of pi). + agent_radius: Agent/target visual radius. + env_size: Environment size. + """ + self.vision_range = vision_range + self.view_angle = view_angle + self.agent_radius = agent_radius + self.env_size = env_size + super().__init__( + (3, num_vision), + num_vision, + vision_range, + view_angle, + agent_radius, + env_size, + ) + + def __call__(self, state: State) -> chex.Array: + """ + Generate agent view/observation from state + + Args: + state: Current simulation state + + Returns: + Array of individual agent views + """ + searcher_views = spatial( + view, + reduction=view_reduction, + default=-jnp.ones((self.num_vision,)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + state.key, + (self.view_angle, self.agent_radius), + state.searchers, + state.searchers, + pos=state.searchers.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + target_views = spatial( + all_target_view, + reduction=view_reduction, + default=-jnp.ones((2, self.num_vision)), + include_self=False, + i_range=self.vision_range, + dims=self.env_size, + )( + state.key, + (self.view_angle, self.agent_radius), + state.searchers, + state.targets, + pos=state.searchers.pos, + pos_b=state.targets.pos, + n_view=self.num_vision, + i_range=self.vision_range, + env_size=self.env_size, + ) + return jnp.hstack([searcher_views[:, jnp.newaxis], target_views]) diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py index 32e74512d..3b02b3e42 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations_test.py +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -263,3 +263,90 @@ def test_search_and_target_view_targets( expected = expected.at[0, 1, idx].set(val) assert jnp.all(jnp.isclose(obs, expected)) + + +@pytest.mark.parametrize( + "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates", + [ + # Target out of view range + ([0.8, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, []), + # Target in view and found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, [(1, 5, 0.25)]), + # Target in view but not found + ([0.25, 0.5], jnp.pi, [0.2, 0.5], False, 1.0, [(2, 5, 0.25)]), + # 
Observed around wrapped edge found + ( + [0.025, 0.5], + jnp.pi, + [0.975, 0.5], + True, + 1.0, + [(1, 5, 0.25)], + ), + # Observed around wrapped edge not found + ( + [0.025, 0.5], + jnp.pi, + [0.975, 0.5], + False, + 1.0, + [(2, 5, 0.25)], + ), + # Observed around wrapped edge of smaller env + ( + [0.025, 0.25], + jnp.pi, + [0.475, 0.25], + True, + 0.5, + [(1, 5, 0.25)], + ), + ], +) +def test_search_and_all_target_view_targets( + key: chex.PRNGKey, + env: SearchAndRescue, + searcher_position: List[float], + searcher_heading: float, + target_position: List[float], + target_found: bool, + env_size: float, + view_updates: List[Tuple[int, int, float]], +) -> None: + """ + Test agent+target view model generates expected array with different + configurations of targets only. + """ + + searcher_position = jnp.array([searcher_position]) + searcher_heading = jnp.array([searcher_heading]) + searcher_speed = jnp.zeros((1,)) + target_position = jnp.array([target_position]) + target_found = jnp.array([target_found]) + + state = State( + searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), + targets=TargetState( + pos=target_position, + found=target_found, + ), + key=key, + ) + + observe_fn = observations.AgentAndAllTargetObservationFn( + num_vision=11, + vision_range=VISION_RANGE, + view_angle=VIEW_ANGLE, + agent_radius=0.01, + env_size=env_size, + ) + + obs = observe_fn(state) + assert obs.shape == (1, 3, observe_fn.num_vision) + + expected = jnp.full((1, 3, observe_fn.num_vision), -1.0) + + for i, idx, val in view_updates: + expected = expected.at[0, i, idx].set(val) + + assert jnp.all(jnp.isclose(obs, expected)) diff --git a/jumanji/environments/swarms/search_and_rescue/test_utils.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py similarity index 100% rename from jumanji/environments/swarms/search_and_rescue/test_utils.py rename to jumanji/environments/swarms/search_and_rescue/utils_test.py diff --git a/jumanji/training/networks/search_and_rescue/actor_critic.py b/jumanji/training/networks/search_and_rescue/actor_critic.py index 092cd3508..93dd0f40b 100644 --- a/jumanji/training/networks/search_and_rescue/actor_critic.py +++ b/jumanji/training/networks/search_and_rescue/actor_critic.py @@ -16,6 +16,7 @@ import chex import haiku as hk +import jax import jax.numpy as jnp from jumanji.environments.swarms.search_and_rescue import SearchAndRescue @@ -50,7 +51,9 @@ def make_actor_critic_search_and_rescue( """ n_actions = prod(search_and_rescue.action_spec.shape) parametric_action_distribution = ContinuousActionSpaceNormalTanhDistribution(n_actions) - policy_network = make_actor_network(layers=layers, n_actions=n_actions) + policy_network = make_actor_network( + layers=layers, n_agents=search_and_rescue.generator.num_searchers, n_actions=n_actions + ) value_network = make_critic_network(layers=layers) return ActorCriticNetworks( @@ -60,43 +63,52 @@ def make_actor_critic_search_and_rescue( ) +def embedding(x: chex.Array) -> chex.Array: + n_channels = x.shape[-2] + x = hk.Conv1D(2 * n_channels, 3, data_format="NCW")(x) + x = jax.nn.relu(x) + x = hk.MaxPool(2, 2, "SAME", channel_axis=-2)(x) + return x + + def make_critic_network(layers: Sequence[int]) -> FeedForwardNetwork: # Shape names: # B: batch size # N: number of agents + # C: Observation channels # O: observation size def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]: - views = observation.searcher_views # (B, N, O) - batch_size = views.shape[0] - views 
= views.reshape(batch_size, -1) # (B, N * O) - value = hk.nets.MLP([*layers, 1])(views) # (B,) + x = observation.searcher_views # (B, N, C, O) + x = hk.vmap(embedding, split_rng=False)(x) + x = hk.Flatten()(x) + value = hk.nets.MLP([*layers, 1])(x) # (B,) return jnp.squeeze(value, axis=-1) init, apply = hk.without_apply_rng(hk.transform(network_fn)) return FeedForwardNetwork(init=init, apply=apply) -def make_actor_network(layers: Sequence[int], n_actions: int) -> FeedForwardNetwork: +def make_actor_network(layers: Sequence[int], n_agents: int, n_actions: int) -> FeedForwardNetwork: # Shape names: # B: batch size # N: number of agents + # C: Observation channels # O: observation size # A: Number of actions + def log_std_params(x: chex.Array) -> chex.Array: + return hk.get_parameter("log_stds", shape=x.shape, init=hk.initializers.Constant(0.1)) + def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]: - views = observation.searcher_views # (B, N, O) - batch_size = views.shape[0] - n_agents = views.shape[1] - views = views.reshape((batch_size, -1)) # (B, N * 0) - means = hk.nets.MLP([*layers, n_agents * n_actions])(views) # (B, N * A) - means = means.reshape(batch_size, n_agents, n_actions) # (B, N, A) - - log_stds = hk.get_parameter( - "log_stds", shape=(n_agents * n_actions,), init=hk.initializers.Constant(0.1) - ) # (N * A,) - log_stds = jnp.broadcast_to(log_stds, (batch_size, n_agents * n_actions)) # (B, N * A) - log_stds = log_stds.reshape(batch_size, n_agents, n_actions) # (B, N, A) + x = observation.searcher_views # (B, N, C, O) + x = hk.vmap(embedding, split_rng=False)(x) + x = hk.Flatten()(x) + + means = hk.nets.MLP([*layers, n_agents * n_actions])(x) # (B, N * A) + means = hk.Reshape(output_shape=(n_agents, n_actions))(means) # (B, N, A) + + log_stds = hk.vmap(log_std_params, split_rng=False)(means) # (B, N, A) return means, log_stds From 9e8ac5c8ba1f498ac6fe1303d53761ca65be1514 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Wed, 11 Dec 2024 00:08:49 +0000 Subject: [PATCH 12/19] feat: Scaled rewards and target velocities (#10) * Use channels view parameters * Rename parameters * Include step-number in observation * Add velocity field to targets * Add time scaled reward function --- .../swarms/search_and_rescue/conftest.py | 2 +- .../swarms/search_and_rescue/dynamics.py | 19 ++++---- .../swarms/search_and_rescue/env.py | 43 +++++++++++-------- .../swarms/search_and_rescue/env_test.py | 16 ++++--- .../swarms/search_and_rescue/generator.py | 5 ++- .../swarms/search_and_rescue/observations.py | 12 +++--- .../search_and_rescue/observations_test.py | 10 ++++- .../swarms/search_and_rescue/reward.py | 24 +++++++++-- .../swarms/search_and_rescue/reward_test.py | 16 ++++++- .../swarms/search_and_rescue/types.py | 6 ++- .../swarms/search_and_rescue/utils_test.py | 13 ++++-- 11 files changed, 115 insertions(+), 51 deletions(-) diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py index 70cb1b907..6b63645aa 100644 --- a/jumanji/environments/swarms/search_and_rescue/conftest.py +++ b/jumanji/environments/swarms/search_and_rescue/conftest.py @@ -28,7 +28,7 @@ def env() -> SearchAndRescue: searcher_min_speed=0.01, searcher_max_speed=0.05, searcher_view_angle=0.5, - max_steps=25, + time_limit=10, ) diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py 
index d9abe5b19..63353dcb7 100644 --- a/jumanji/environments/swarms/search_and_rescue/dynamics.py +++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py @@ -17,18 +17,20 @@ import chex import jax +from jumanji.environments.swarms.search_and_rescue.types import TargetState + class TargetDynamics(abc.ABC): @abc.abstractmethod - def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array: + def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: """Interface for target position update function. Args: key: random key. - target_pos: Current target positions. + targets: Current target states. Returns: - Updated target positions. + Updated target states. """ @@ -46,16 +48,17 @@ def __init__(self, step_size: float): """ self.step_size = step_size - def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array: + def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: """Update target positions. Args: key: random key. - target_pos: Current target positions. + targets: Current target states. Returns: - Updated target positions. + Updated target states. """ - d_pos = jax.random.uniform(key, target_pos.shape) + d_pos = jax.random.uniform(key, targets.pos.shape) d_pos = self.step_size * 2.0 * (d_pos - 0.5) - return target_pos + d_pos + pos = (targets.pos + d_pos) % env_size + return TargetState(pos=pos, vel=targets.vel, found=targets.found) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 267975854..609e8ba68 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -113,7 +113,7 @@ def __init__( searcher_min_speed: float = 0.01, searcher_max_speed: float = 0.02, searcher_view_angle: float = 0.75, - max_steps: int = 400, + time_limit: int = 400, viewer: Optional[Viewer[State]] = None, target_dynamics: Optional[TargetDynamics] = None, generator: Optional[Generator] = None, @@ -136,7 +136,7 @@ def __init__( The view cone of an agent goes from +- of the view angle relative to its heading, e.g. 0.5 would mean searchers have a 90° view angle in total. - max_steps: Maximum number of environment steps allowed for search. + time_limit: Maximum number of environment steps allowed for search. viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`. 
target_dynamics: target_dynamics: Target object dynamics model, implemented as a @@ -156,7 +156,7 @@ def __init__( max_speed=searcher_max_speed, view_angle=searcher_view_angle, ) - self.max_steps = max_steps + self.time_limit = time_limit self._target_dynamics = target_dynamics or RandomWalk(0.001) self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2) self._viewer = viewer or SearchAndRescueViewer() @@ -180,7 +180,7 @@ def __repr__(self) -> str: f" - target contact range: {self.target_contact_range}", f" - num vision: {self._observation.num_vision}", f" - agent radius: {self._observation.agent_radius}", - f" - max steps: {self.max_steps}," + f" - time limit: {self.time_limit}," f" - env size: {self.generator.env_size}" f" - target dynamics: {self._target_dynamics.__class__.__name__}", f" - generator: {self.generator.__class__.__name__}", @@ -223,11 +223,12 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser searchers = update_state( key, self.generator.env_size, self.searcher_params, state.searchers, actions ) - # Ensure target positions are wrapped - target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size + + targets = self._target_dynamics(target_key, state.targets, self.generator.env_size) + # Searchers return an array of flags of any targets they are in range of, # and that have not already been located, result shape here is (n-searcher, n-targets) - n_targets = target_pos.shape[0] + n_targets = targets.pos.shape[0] targets_found = spatial( utils.searcher_detect_targets, reduction=jnp.logical_or, @@ -238,14 +239,14 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser key, self.searcher_params.view_angle, searchers, - (jnp.arange(n_targets), state.targets), + (jnp.arange(n_targets), targets), pos=searchers.pos, - pos_b=target_pos, + pos_b=targets.pos, env_size=self.generator.env_size, n_targets=n_targets, ) - rewards = self._reward_fn(targets_found) + rewards = self._reward_fn(targets_found, state.step, self.time_limit) targets_found = jnp.any(targets_found, axis=0) # Targets need to remain found if they already have been @@ -253,14 +254,14 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser state = State( searchers=searchers, - targets=TargetState(pos=target_pos, found=targets_found), + targets=TargetState(pos=targets.pos, vel=targets.vel, found=targets_found), key=key, step=state.step + 1, ) observation = self._state_to_observation(state) observation = jax.lax.stop_gradient(observation) timestep = jax.lax.cond( - jnp.logical_or(state.step >= self.max_steps, jnp.all(targets_found)), + jnp.logical_or(state.step >= self.time_limit, jnp.all(targets_found)), termination, transition, rewards, @@ -273,9 +274,13 @@ def _state_to_observation(self, state: State) -> Observation: return Observation( searcher_views=searcher_views, targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, - time_remaining=1.0 - state.step / (self.max_steps + 1), + step=state.step, ) + @cached_property + def num_agents(self) -> int: + return self.generator.num_searchers + @cached_property def observation_spec(self) -> specs.Spec[Observation]: """Returns the observation spec. 
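With this change the `step` method delegates target movement to the dynamics model, passing the full `TargetState` and the environment size. As an illustration only, a custom dynamics model written against this interface could look like the sketch below; the `ConstantDrift` class is hypothetical and not part of the patch.

```python
# Hypothetical example: a custom TargetDynamics that drifts targets by their
# stored velocity and wraps positions at the environment boundaries.
import chex

from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics
from jumanji.environments.swarms.search_and_rescue.types import TargetState


class ConstantDrift(TargetDynamics):
    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
        # key is unused for this deterministic drift
        # positions are wrapped to remain inside the square environment
        pos = (targets.pos + targets.vel) % env_size
        return TargetState(pos=pos, vel=targets.vel, found=targets.found)
```

An instance of such a class could then be passed to the environment through the `target_dynamics` argument in place of the default `RandomWalk`.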
@@ -287,7 +292,11 @@ def observation_spec(self) -> specs.Spec[Observation]: observation_spec: Search-and-rescue observation spec """ searcher_views = specs.BoundedArray( - shape=(self.generator.num_searchers, *self._observation.view_shape), + shape=( + self.generator.num_searchers, + self._observation.num_channels, + self._observation.num_vision, + ), minimum=-1.0, maximum=1.0, dtype=float, @@ -298,10 +307,10 @@ def observation_spec(self) -> specs.Spec[Observation]: "ObservationSpec", searcher_views=searcher_views, targets_remaining=specs.BoundedArray( - shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=float + shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=jnp.float32 ), - time_remaining=specs.BoundedArray( - shape=(), minimum=0.0, maximum=1.0, name="time_remaining", dtype=float + step=specs.BoundedArray( + shape=(), minimum=0, maximum=self.time_limit, name="step", dtype=jnp.int32 ), ) diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 998045c9b..4f0b051d6 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -57,7 +57,8 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert isinstance(timestep.observation, Observation) assert timestep.observation.searcher_views.shape == ( env.generator.num_searchers, - *env._observation.view_shape, + env._observation.num_channels, + env._observation.num_vision, ) assert timestep.step_type == StepType.FIRST @@ -69,8 +70,9 @@ def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> N check states (i.e. positions, heading, speeds) all fall inside expected ranges. """ - n_steps = 22 + n_steps = env.time_limit env.generator.env_size = env_size + env.time_limit = 22 def step( carry: Tuple[chex.PRNGKey, State], _: None @@ -108,7 +110,7 @@ def step( def test_env_does_not_smoke(env: SearchAndRescue) -> None: """Test that we can run an episode without any errors.""" - env.max_steps = 10 + env.time_limit = 10 def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: return jax.random.uniform( @@ -132,7 +134,9 @@ def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None: searchers=AgentState( pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([jnp.pi]), speed=jnp.array([0.0]) ), - targets=TargetState(pos=jnp.array([[0.54, 0.5]]), found=jnp.array([False])), + targets=TargetState( + pos=jnp.array([[0.54, 0.5]]), vel=jnp.zeros((1, 2)), found=jnp.array([False]) + ), key=key, ) state, timestep = env.step(state, jnp.zeros((1, 2))) @@ -188,7 +192,9 @@ def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([0.5 * jnp.pi]), speed=jnp.array([0.0]) ), targets=TargetState( - pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), found=jnp.array([False, False]) + pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), + vel=jnp.zeros((2, 2)), + found=jnp.array([False, False]), ), key=key, ) diff --git a/jumanji/environments/swarms/search_and_rescue/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py index 3a3c85251..e0d627db5 100644 --- a/jumanji/environments/swarms/search_and_rescue/generator.py +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -83,11 +83,14 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State: target_pos = jax.random.uniform( target_key, (self.num_targets, 2), 
minval=0.0, maxval=self.env_size ) + target_vel = jnp.zeros((self.num_targets, 2)) state = State( searchers=searcher_state, targets=TargetState( - pos=target_pos, found=jnp.full((self.num_targets,), False, dtype=bool) + pos=target_pos, + vel=target_vel, + found=jnp.full((self.num_targets,), False, dtype=bool), ), key=key, ) diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index f3f917e59..9779c2b33 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -27,7 +27,7 @@ class ObservationFn(abc.ABC): def __init__( self, - view_shape: Tuple[int, ...], + num_channels: int, num_vision: int, vision_range: float, view_angle: float, @@ -38,14 +38,14 @@ def __init__( Base class for observation function mapping state to individual agent views. Args: - view_shape: Individual agent view shape. + num_channels: Number of channels in agent view. num_vision: Size of vision array. vision_range: Vision range. view_angle: Agent view angle (as a fraction of pi). agent_radius: Agent/target visual radius. env_size: Environment size. """ - self.view_shape = view_shape + self.num_channels = num_channels self.num_vision = num_vision self.vision_range = vision_range self.view_angle = view_angle @@ -85,7 +85,7 @@ def __init__( env_size: Environment size. """ super().__init__( - (1, num_vision), + 1, num_vision, vision_range, view_angle, @@ -199,7 +199,7 @@ def __init__( self.agent_radius = agent_radius self.env_size = env_size super().__init__( - (2, num_vision), + 2, num_vision, vision_range, view_angle, @@ -333,7 +333,7 @@ def __init__( self.agent_radius = agent_radius self.env_size = env_size super().__init__( - (3, num_vision), + 3, num_vision, vision_range, view_angle, diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py index 3b02b3e42..17b092fa4 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations_test.py +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -84,7 +84,9 @@ def test_searcher_view( searchers=AgentState( pos=searcher_positions, heading=searcher_headings, speed=searcher_speed ), - targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)), + targets=TargetState( + pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool) + ), key=key, ) @@ -164,7 +166,9 @@ def test_search_and_target_view_searchers( searchers=AgentState( pos=searcher_positions, heading=searcher_headings, speed=searcher_speed ), - targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool)), + targets=TargetState( + pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool) + ), key=key, ) @@ -241,6 +245,7 @@ def test_search_and_target_view_targets( searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), targets=TargetState( pos=target_position, + vel=jnp.zeros_like(target_position), found=target_found, ), key=key, @@ -328,6 +333,7 @@ def test_search_and_all_target_view_targets( searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed), targets=TargetState( pos=target_position, + vel=jnp.zeros_like(target_position), found=target_found, ), key=key, diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py 
b/jumanji/environments/swarms/search_and_rescue/reward.py index 1217b86f7..720adc3fa 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward.py +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -22,7 +22,7 @@ class RewardFn(abc.ABC): """Abstract class for `SearchAndRescue` rewards.""" @abc.abstractmethod - def __call__(self, found_targets: chex.Array) -> chex.Array: + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: """The reward function used in the `SearchAndRescue` environment. Args: @@ -41,7 +41,7 @@ class SharedRewardFn(RewardFn): can receive rewards for detecting multiple targets. """ - def __call__(self, found_targets: chex.Array) -> chex.Array: + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) norms = jnp.sum(rewards, axis=0)[jnp.newaxis] rewards = jnp.where(norms > 0, rewards / norms, rewards) @@ -57,7 +57,25 @@ class IndividualRewardFn(RewardFn): even if a target is detected by multiple agents. """ - def __call__(self, found_targets: chex.Array) -> chex.Array: + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) rewards = jnp.sum(rewards, axis=1) return rewards + + +class SharedScaledRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets + + Targets detected by multiple agents share rewards. Agents + can receive rewards for detecting multiple targets. + Rewards are scaled by the current time step. + """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + norms = jnp.sum(rewards, axis=0)[jnp.newaxis] + rewards = jnp.where(norms > 0, rewards / norms, rewards) + rewards = jnp.sum(rewards, axis=1) + scale = (time_limit - step) / time_limit + return scale * rewards diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py index 303b48b5c..e43590871 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward_test.py +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -20,14 +20,26 @@ def test_rewards_from_found_targets() -> None: targets_found = jnp.array([[False, True, True], [False, False, True]], dtype=bool) - shared_rewards = reward.SharedRewardFn()(targets_found) + shared_rewards = reward.SharedRewardFn()(targets_found, 0, 10) assert shared_rewards.shape == (2,) assert shared_rewards.dtype == jnp.float32 assert jnp.allclose(shared_rewards, jnp.array([1.5, 0.5])) - individual_rewards = reward.IndividualRewardFn()(targets_found) + individual_rewards = reward.IndividualRewardFn()(targets_found, 0, 10) assert individual_rewards.shape == (2,) assert individual_rewards.dtype == jnp.float32 assert jnp.allclose(individual_rewards, jnp.array([2.0, 1.0])) + + shared_scaled_rewards = reward.SharedScaledRewardFn()(targets_found, 0, 10) + + assert shared_scaled_rewards.shape == (2,) + assert shared_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_scaled_rewards, jnp.array([1.5, 0.5])) + + shared_scaled_rewards = reward.SharedScaledRewardFn()(targets_found, 10, 10) + + assert shared_scaled_rewards.shape == (2,) + assert shared_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(shared_scaled_rewards, jnp.array([0.0, 0.0])) diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py index 
cb924bd60..28c9600bb 100644 --- a/jumanji/environments/swarms/search_and_rescue/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -29,11 +29,13 @@ class TargetState: The state for the rescue targets. pos: 2d position of the target agents + velocity: 2d velocity of the target agents found: Boolean flag indicating if the target has been located by a searcher. """ pos: chex.Array # (num_targets, 2) + vel: chex.Array # (num_targets, 2) found: chex.Array # (num_targets,) @@ -75,5 +77,5 @@ class Observation(NamedTuple): """ searcher_views: chex.Array # (num_searchers, num_vision) - targets_remaining: chex.Array # () - time_remaining: chex.Array # () + targets_remaining: chex.Numeric # () + step: chex.Numeric # () diff --git a/jumanji/environments/swarms/search_and_rescue/utils_test.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py index 0f7328c43..018e895a7 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils_test.py +++ b/jumanji/environments/swarms/search_and_rescue/utils_test.py @@ -30,14 +30,18 @@ def test_random_walk_dynamics(key: chex.PRNGKey) -> None: n_targets = 50 - s0 = jnp.full((n_targets, 2), 0.5) + pos_0 = jnp.full((n_targets, 2), 0.5) + + s0 = TargetState( + pos=pos_0, vel=jnp.zeros((n_targets, 2)), found=jnp.zeros((n_targets,), dtype=bool) + ) dynamics = RandomWalk(0.1) assert isinstance(dynamics, TargetDynamics) - s1 = dynamics(key, s0) + s1 = dynamics(key, s0, 1.0) - assert s1.shape == (n_targets, 2) - assert jnp.all(jnp.abs(s0 - s1) < 0.1) + assert s1.pos.shape == (n_targets, 2) + assert jnp.all(jnp.abs(s0.pos - s1.pos) < 0.1) @pytest.mark.parametrize( @@ -66,6 +70,7 @@ def test_target_found( ) -> None: target = TargetState( pos=jnp.zeros((2,)), + vel=jnp.zeros((2,)), found=target_state, ) From 5c509c741ad063b0aae77f18edcf288d50a0182e Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:48:38 +0000 Subject: [PATCH 13/19] Pass shape information to timesteps (#11) --- jumanji/environments/swarms/search_and_rescue/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 609e8ba68..b1c9d65e0 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from functools import cached_property +from functools import cached_property, partial from typing import Optional, Sequence, Tuple import chex @@ -200,7 +200,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: timestep: TimeStep with individual search agent views. 
""" state = self.generator(key, self.searcher_params) - timestep = restart(observation=self._state_to_observation(state)) + timestep = restart(observation=self._state_to_observation(state), shape=(self.num_agents,)) return state, timestep def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Observation]]: @@ -262,8 +262,8 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser observation = jax.lax.stop_gradient(observation) timestep = jax.lax.cond( jnp.logical_or(state.step >= self.time_limit, jnp.all(targets_found)), - termination, - transition, + partial(termination, shape=(self.num_agents,)), + partial(transition, shape=(self.num_agents,)), rewards, observation, ) From 8acf24257ebccc4448321f1f3700c19983263234 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Wed, 11 Dec 2024 23:08:46 +0000 Subject: [PATCH 14/19] test: extend tests and docs (#12) * Update docstrings * Update tests * Update environment readme --- docs/environments/search_and_rescue.md | 36 +++++++++------ jumanji/environments/swarms/common/updates.py | 16 +++---- jumanji/environments/swarms/common/viewer.py | 2 +- .../swarms/search_and_rescue/conftest.py | 46 ++++++++++++++++++- .../swarms/search_and_rescue/dynamics.py | 20 ++++---- .../swarms/search_and_rescue/env.py | 44 ++++++++++-------- .../swarms/search_and_rescue/env_test.py | 18 +++++--- .../swarms/search_and_rescue/generator.py | 2 +- .../swarms/search_and_rescue/observations.py | 20 ++++++-- .../search_and_rescue/observations_test.py | 5 +- .../swarms/search_and_rescue/reward.py | 3 +- .../swarms/search_and_rescue/reward_test.py | 24 +++++++--- .../swarms/search_and_rescue/utils.py | 2 +- .../swarms/search_and_rescue/utils_test.py | 2 + 14 files changed, 162 insertions(+), 78 deletions(-) diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md index 8ee4ce783..d7ba33a89 100644 --- a/docs/environments/search_and_rescue.md +++ b/docs/environments/search_and_rescue.md @@ -2,29 +2,34 @@ [//]: # (TODO: Add animated plot) -Multi-agent environment, modelling a group of agents searching the environment +Multi-agent environment, modelling a group of agents searching a 2d environment for multiple targets. Agents are individually rewarded for finding a target that has not previously been detected. -Each agent visualises a local region around it, creating a simple segmented view -of locations of other agents in the vicinity. The environment is updated in the -following sequence: +Each agent visualises a local region around itself, represented as a simple segmented +view of locations of other agents and targets in the vicinity. The environment +is updated in the following sequence: - The velocity of searching agents are updated, and consequently their positions. - The positions of targets are updated. -- Agents are rewarded for being within a fixed range of targets, and the target - being within its view cone. - Targets within detection range and an agents view cone are marked as found. +- Agents are rewarded for locating previously unfound targets. - Local views of the environment are generated for each search agent. The agents are allotted a fixed number of steps to locate the targets. The search -space is a uniform space with unit dimensions, and wrapped at the boundaries. +space is a uniform square space, wrapped at the boundaries. 
+ +Many aspects of the environment can be customised: + +- Agent observations can include targets as well as other searcher agents. +- Rewards can be shared between agents, or assigned individually to each agent. +- Target dynamics can be customised to model various search scenarios.   ## Observations  -- `searcher_views`: jax array (float) of shape `(num_searchers, num_vision)`. Each agent -  generates an independent observation, an array of values representing the distance -  along a ray from the agent to the nearest neighbour, with each cell representing a +- `searcher_views`: jax array (float) of shape `(num_searchers, channels, num_vision)`. +  Each agent generates an independent observation, an array of values representing the distance +  along a ray from the agent to the nearest neighbour or target, with each cell representing a    ray angle (with `num_vision` rays evenly distributed over the agents field of vision).    For example if an agent sees another agent straight ahead and `num_vision = 5` then    the observation array could be     ```    [-1.0, -1.0, 0.5, -1.0, -1.0]    ```     where `-1.0` indicates there is no agents along that ray, and `0.5` is the normalised -  distance to the other agent. +  distance to the other agent. Channels in the segmented view are used to differentiate +  between different agents/targets and can be customised. By default, the view has three +  channels representing other agents, found targets, and unfound targets. - `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets remaining    to be detected (i.e. 1.0 when no targets have been found). -- `time_remaining`: float in the range `[0, 1]`. The normalised number of steps remaining -  to locate the targets (i.e. 0.0 at the end of the episode). +- `step`: int in the range `[0, time_limit]`. The current simulation step.  ## Actions  @@ -64,4 +70,6 @@ Once applied, agent speeds are clipped to velocities within a fixed range of spe  ## Rewards  Jax array (float) of `(num_searchers,)`. Rewards are generated for each agent individually. -Agents are rewarded 1.0 for locating a target that has not already been detected. 
Returns the min value if they are both positive, but the max value if one or both of the values is -1.0. @@ -137,7 +137,7 @@ def angular_width( env_size: float, ) -> Tuple[chex.Array, chex.Array, chex.Array]: """ - Get the normalised distance, and left and right angles to another agent. + Get the normalised distance, and angles to edges of another agent. Args: viewing_pos: Co-ordinates of the viewing agent @@ -175,10 +175,10 @@ def view( Simple view model where the agents view angle is subdivided into an array of values representing the distance from - the agent along a rays from the agent, with rays evenly distributed. + the agent along a rays from the agent, with rays evenly distributed across the agents field of view. The limit of vision is set at 1.0. The default value if no object is within range is -1.0. - Currently, this model assumes the viewed objects are circular. + Currently, this model assumes the viewed agent/objects are circular. Args: _key: Dummy JAX random key, required by esquilax API, but diff --git a/jumanji/environments/swarms/common/viewer.py b/jumanji/environments/swarms/common/viewer.py index 16bb197e6..4fc15c88b 100644 --- a/jumanji/environments/swarms/common/viewer.py +++ b/jumanji/environments/swarms/common/viewer.py @@ -48,7 +48,7 @@ def draw_agents(ax: Axes, agent_states: AgentState, color: str) -> Quiver: def format_plot( fig: Figure, ax: Axes, env_dims: Tuple[float, float], border: float = 0.01 ) -> Tuple[Figure, Axes]: - """Format a flock/swarm plot, remove ticks and bound to the unit interval + """Format a flock/swarm plot, remove ticks and bound to the environment dimensions. Args: fig: Matplotlib figure. diff --git a/jumanji/environments/swarms/search_and_rescue/conftest.py b/jumanji/environments/swarms/search_and_rescue/conftest.py index 6b63645aa..8158bacac 100644 --- a/jumanji/environments/swarms/search_and_rescue/conftest.py +++ b/jumanji/environments/swarms/search_and_rescue/conftest.py @@ -16,7 +16,7 @@ import jax.random import pytest -from jumanji.environments.swarms.search_and_rescue import SearchAndRescue +from jumanji.environments.swarms.search_and_rescue import SearchAndRescue, observations @pytest.fixture @@ -32,6 +32,50 @@ def env() -> SearchAndRescue: ) +class FixtureRequest: + """Just used for typing""" + + param: observations.ObservationFn + + +@pytest.fixture( + params=[ + observations.AgentObservationFn( + num_vision=32, + vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, + ), + observations.AgentAndTargetObservationFn( + num_vision=32, + vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, + ), + observations.AgentAndAllTargetObservationFn( + num_vision=32, + vision_range=0.1, + view_angle=0.5, + agent_radius=0.01, + env_size=1.0, + ), + ] +) +def multi_obs_env(request: FixtureRequest) -> SearchAndRescue: + return SearchAndRescue( + target_contact_range=0.05, + searcher_max_rotate=0.2, + searcher_max_accelerate=0.01, + searcher_min_speed=0.01, + searcher_max_speed=0.05, + searcher_view_angle=0.5, + time_limit=10, + observation=request.param, + ) + + @pytest.fixture def key() -> chex.PRNGKey: return jax.random.PRNGKey(101) diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py index 63353dcb7..e3d9b47ef 100644 --- a/jumanji/environments/swarms/search_and_rescue/dynamics.py +++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py @@ -23,11 +23,15 @@ class TargetDynamics(abc.ABC): @abc.abstractmethod def 
__call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: - """Interface for target position update function. + """Interface for target state update function. + + NOTE: Target positions should be bound to environment + area (generally wrapped around at the boundaries). Args: - key: random key. + key: Random key. targets: Current target states. + env_size: Environment size. Returns: Updated target states. @@ -37,23 +41,23 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> class RandomWalk(TargetDynamics): def __init__(self, step_size: float): """ - Random walk target dynamics. + Simple random walk target dynamics. - Target positions are updated with random - steps, sampled uniformly from the range - [-step-size, step-size]. + Target positions are updated with random steps, sampled uniformly + from the range `[-step-size, step-size]`. Args: - step_size: Maximum random step-size + step_size: Maximum random step-size in each axis. """ self.step_size = step_size def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: - """Update target positions. + """Update target state. Args: key: random key. targets: Current target states. + env_size: Environment size. Returns: Updated target states. diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index b1c9d65e0..c56c08374 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -29,7 +29,7 @@ from jumanji.environments.swarms.search_and_rescue.dynamics import RandomWalk, TargetDynamics from jumanji.environments.swarms.search_and_rescue.generator import Generator, RandomGenerator from jumanji.environments.swarms.search_and_rescue.observations import ( - AgentAndTargetObservationFn, + AgentAndAllTargetObservationFn, ObservationFn, ) from jumanji.environments.swarms.search_and_rescue.reward import RewardFn, SharedRewardFn @@ -46,8 +46,9 @@ class SearchAndRescue(Environment): for a set of targets on a 2d environment. Agents are rewarded (individually) for coming within a fixed range of a target that has not already been detected. Agents visualise their local environment - (i.e. the location of other agents) via a simple segmented view model. - The environment consists of a uniform space with wrapped boundaries. + (i.e. the location of other agents and targets) via a simple segmented + view model. The environment area is a uniform square space with wrapped + boundaries. An episode will terminate if all targets have been located by the team of searching agents. @@ -58,12 +59,12 @@ class SearchAndRescue(Environment): channels can be used to differentiate between agents and targets. Each entry in the view indicates the distant to another agent/target along a ray from the agent, and is -1.0 if nothing is in range along the ray. - The view model can be customised using an `ObservationFn` implementation. + The view model can be customised using an `ObservationFn` implementation, e.g. + the view can include all agents and targets, or just other agents. targets_remaining: (float) Number of targets remaining to be found from the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates all the targets are still to be found). - time_remaining: (float) Steps remaining to find agents, scaled to the - range [0,1] (i.e. the value is 0 when time runs out). + step: (int) current simulation step. 
- action: jax array (float) of shape (num_searchers, 2) Array of individual agent actions. Each agents actions rotate and @@ -80,13 +81,14 @@ class SearchAndRescue(Environment): - state: `State` - searchers: `AgentState` - - pos: jax array (float) of shape (num_searchers, 2) in the range [0, 1]. + - pos: jax array (float) of shape (num_searchers, 2) in the range [0, env_size]. - heading: jax array (float) of shape (num_searcher,) in the range [0, 2pi]. - speed: jax array (float) of shape (num_searchers,) in the range [min_speed, max_speed]. - targets: `TargetState` - - pos: jax array (float) of shape (num_targets, 2) in the range [0, 1]. + - pos: jax array (float) of shape (num_targets, 2) in the range [0, env_size]. + - vel: jax array (float) of shape (num_targets, 2). - found: jax array (bool) of shape (num_targets,) flag indicating if target has been located by an agent. - key: jax array (uint32) of shape (2,) @@ -127,8 +129,8 @@ def __init__( searcher_max_rotate: Maximum rotation searcher agents can turn within a step. Should be a value from [0,1] representing a fraction of pi radians. - searcher_max_accelerate: Maximum acceleration/deceleration - a searcher agent can apply within a step. + searcher_max_accelerate: Magnitude of the maximum + acceleration/deceleration a searcher agent can apply within a step. searcher_min_speed: Minimum speed a searcher agent can move at. searcher_max_speed: Maximum speed a searcher agent can move at. searcher_view_angle: Searcher agent local view angle. Should be @@ -145,8 +147,11 @@ def __init__( with 20 targets and 10 searchers. reward_fn: Reward aggregation function. Defaults to `SharedRewardFn` where agents share rewards if they locate a target simultaneously. + observation: Agent observation view generation function. Defaults to + `AgentAndAllTargetObservationFn` where all targets (found and unfound) + and other ogents are included in the generated view. """ - # self.searcher_vision_range = searcher_vision_range + self.target_contact_range = target_contact_range self.searcher_params = AgentParams( @@ -161,7 +166,7 @@ def __init__( self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2) self._viewer = viewer or SearchAndRescueViewer() self._reward_fn = reward_fn or SharedRewardFn() - self._observation = observation or AgentAndTargetObservationFn( + self._observation = observation or AgentAndAllTargetObservationFn( num_vision=64, vision_range=0.1, view_angle=searcher_view_angle, @@ -190,7 +195,7 @@ def __repr__(self) -> str: ) def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]: - """Initialise searcher positions and velocities, and target positions. + """Initialise searcher and target initial states. Args: key: Random key used to reset the environment. @@ -217,7 +222,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser state: Updated searcher and target positions and velocities. timestep: Transition timestep with individual agent local observations. 
""" - # Note: only one new key is needed for the targets, as all other + # Note: only one new key is needed for the target updates, as all other # keys are just dummy values required by Esquilax key, target_key = jax.random.split(state.key, num=2) searchers = update_state( @@ -228,22 +233,21 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser # Searchers return an array of flags of any targets they are in range of, # and that have not already been located, result shape here is (n-searcher, n-targets) - n_targets = targets.pos.shape[0] targets_found = spatial( utils.searcher_detect_targets, reduction=jnp.logical_or, - default=jnp.zeros((n_targets,), dtype=bool), + default=jnp.zeros((self.generator.num_targets,), dtype=bool), i_range=self.target_contact_range, dims=self.generator.env_size, )( key, self.searcher_params.view_angle, searchers, - (jnp.arange(n_targets), targets), + (jnp.arange(self.generator.num_targets), targets), pos=searchers.pos, pos_b=targets.pos, env_size=self.generator.env_size, - n_targets=n_targets, + n_targets=self.generator.num_targets, ) rewards = self._reward_fn(targets_found, state.step, self.time_limit) @@ -352,14 +356,14 @@ def render(self, state: State) -> None: """Render a frame of the environment for a given state using matplotlib. Args: - state: State object containing the current dynamics of the environment. + state: State object containing the current state of the environment. """ self._viewer.render(state) def animate( self, states: Sequence[State], - interval: int = 200, + interval: int = 100, save_path: Optional[str] = None, ) -> FuncAnimation: """Create an animation from a sequence of environment states. diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 4f0b051d6..0a6cdcf58 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -44,10 +44,11 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert isinstance(state.searchers, AgentState) assert state.searchers.pos.shape == (env.generator.num_searchers, 2) assert state.searchers.speed.shape == (env.generator.num_searchers,) - assert state.searchers.speed.shape == (env.generator.num_searchers,) + assert state.searchers.heading.shape == (env.generator.num_searchers,) assert isinstance(state.targets, TargetState) assert state.targets.pos.shape == (env.generator.num_targets, 2) + assert state.targets.vel.shape == (env.generator.num_targets, 2) assert state.targets.found.shape == (env.generator.num_targets,) assert jnp.array_equal( state.targets.found, jnp.full((env.generator.num_targets,), False, dtype=bool) @@ -61,6 +62,7 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: env._observation.num_vision, ) assert timestep.step_type == StepType.FIRST + assert timestep.reward.shape == (env.generator.num_searchers,) @pytest.mark.parametrize("env_size", [1.0, 0.2]) @@ -108,26 +110,27 @@ def step( assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= env_size)) -def test_env_does_not_smoke(env: SearchAndRescue) -> None: +def test_env_does_not_smoke(multi_obs_env: SearchAndRescue) -> None: """Test that we can run an episode without any errors.""" - env.time_limit = 10 + multi_obs_env.time_limit = 10 def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array: return jax.random.uniform( - action_key, (env.generator.num_searchers, 2), 
minval=-1.0, maxval=1.0 +            action_key, (multi_obs_env.generator.num_searchers, 2), minval=-1.0, maxval=1.0         )      check_env_does_not_smoke(multi_obs_env, select_action=select_action)  -def test_env_specs_do_not_smoke(env: SearchAndRescue) -> None: +def test_env_specs_do_not_smoke(multi_obs_env: SearchAndRescue) -> None:     """Test that we can access specs without any errors.""" -    check_env_specs_does_not_smoke(env) +    check_env_specs_does_not_smoke(multi_obs_env)   def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None:     # Keep targets in one location     env._target_dynamics = RandomWalk(step_size=0.0) +    env.generator.num_targets = 1      # Agent facing wrong direction should not see target     state = State( @@ -185,6 +188,7 @@ def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None         max_speed=0.05,         view_angle=0.25,     ) +    env.generator.num_targets = 2      # Agent facing wrong direction should not see target     state = State( diff --git a/jumanji/environments/swarms/search_and_rescue/generator.py b/jumanji/environments/swarms/search_and_rescue/generator.py index e0d627db5..2425475f6 100644 --- a/jumanji/environments/swarms/search_and_rescue/generator.py +++ b/jumanji/environments/swarms/search_and_rescue/generator.py @@ -41,7 +41,7 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State:          Args:             key: random key. -            searcher_params: Searcher `AgentParams`. +            searcher_params: Searcher agent `AgentParams`.          Returns:             Initial agent `State`. diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index 9779c2b33..4abd9c6d6 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -37,6 +37,11 @@ def __init__(         """         Base class for observation function mapping state to individual agent views.  +        Maps states to an array of individual local agent views of +        the environment, with shape (n-agents, n-channels, n-vision). +        Channels can be used to differentiate between agent types or +        statuses. +         Args:             num_channels: Number of channels in agent view.             num_vision: Size of vision array. @@ -215,7 +220,9 @@ def __call__(self, state: State) -> chex.Array:             state: Current simulation state          Returns: -            Array of individual agent views +            Array of individual agent views of shape +            (n-agents, 2, n-vision). Other agents are shown +            in channel 0, and located targets 1.
""" searcher_views = spatial( view, diff --git a/jumanji/environments/swarms/search_and_rescue/observations_test.py b/jumanji/environments/swarms/search_and_rescue/observations_test.py index 17b092fa4..a95dcd271 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations_test.py +++ b/jumanji/environments/swarms/search_and_rescue/observations_test.py @@ -19,7 +19,7 @@ import pytest from jumanji.environments.swarms.common.types import AgentState -from jumanji.environments.swarms.search_and_rescue import SearchAndRescue, observations +from jumanji.environments.swarms.search_and_rescue import observations from jumanji.environments.swarms.search_and_rescue.types import State, TargetState VISION_RANGE = 0.2 @@ -65,7 +65,6 @@ ) def test_searcher_view( key: chex.PRNGKey, - # env: SearchAndRescue, searcher_positions: List[List[float]], searcher_headings: List[float], env_size: float, @@ -222,7 +221,6 @@ def test_search_and_target_view_searchers( ) def test_search_and_target_view_targets( key: chex.PRNGKey, - env: SearchAndRescue, searcher_position: List[float], searcher_heading: float, target_position: List[float], @@ -310,7 +308,6 @@ def test_search_and_target_view_targets( ) def test_search_and_all_target_view_targets( key: chex.PRNGKey, - env: SearchAndRescue, searcher_position: List[float], searcher_heading: float, target_position: List[float], diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py b/jumanji/environments/swarms/search_and_rescue/reward.py index 720adc3fa..40c38b9d4 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward.py +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -69,7 +69,8 @@ class SharedScaledRewardFn(RewardFn): Targets detected by multiple agents share rewards. Agents can receive rewards for detecting multiple targets. - Rewards are scaled by the current time step. + Rewards are linearly scaled by the current time step such that + rewards are 0 at the final step. """ def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py index e43590871..d351fc259 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward_test.py +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -11,34 +11,44 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import chex import jax.numpy as jnp +import pytest from jumanji.environments.swarms.search_and_rescue import reward -def test_rewards_from_found_targets() -> None: - targets_found = jnp.array([[False, True, True], [False, False, True]], dtype=bool) +@pytest.fixture +def target_states() -> chex.Array: + return jnp.array([[False, True, True], [False, False, True]], dtype=bool) + - shared_rewards = reward.SharedRewardFn()(targets_found, 0, 10) +def test_shared_rewards(target_states: chex.Array) -> None: + shared_rewards = reward.SharedRewardFn()(target_states, 0, 10) assert shared_rewards.shape == (2,) assert shared_rewards.dtype == jnp.float32 assert jnp.allclose(shared_rewards, jnp.array([1.5, 0.5])) - individual_rewards = reward.IndividualRewardFn()(targets_found, 0, 10) + +def test_individual_rewards(target_states: chex.Array) -> None: + individual_rewards = reward.IndividualRewardFn()(target_states, 0, 10) assert individual_rewards.shape == (2,) assert individual_rewards.dtype == jnp.float32 assert jnp.allclose(individual_rewards, jnp.array([2.0, 1.0])) - shared_scaled_rewards = reward.SharedScaledRewardFn()(targets_found, 0, 10) + +def test_scaled_rewards(target_states: chex.Array) -> None: + reward_fn = reward.SharedScaledRewardFn() + + shared_scaled_rewards = reward_fn(target_states, 0, 10) assert shared_scaled_rewards.shape == (2,) assert shared_scaled_rewards.dtype == jnp.float32 assert jnp.allclose(shared_scaled_rewards, jnp.array([1.5, 0.5])) - shared_scaled_rewards = reward.SharedScaledRewardFn()(targets_found, 10, 10) + shared_scaled_rewards = reward_fn(target_states, 10, 10) assert shared_scaled_rewards.shape == (2,) assert shared_scaled_rewards.dtype == jnp.float32 diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py index c6f581bdd..e45f82343 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils.py +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -62,7 +62,7 @@ def searcher_detect_targets( Return array of flags indicating if a target has been located Sets the flag at the target index if the target is within the - searchers view cone, and had not already been detected. + searchers view cone, and has not already been detected. Args: _key: Dummy random key (required by Esquilax). 
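For reference, a small sketch (not part of the test suite) of how the new `SharedScaledRewardFn` behaves with the same found-target flags used in these tests; the values follow from the shared-reward normalisation and the linear `(time_limit - step) / time_limit` scaling.

```python
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue import reward

# Rows are searchers, columns are targets
targets_found = jnp.array([[False, True, True], [False, False, True]], dtype=bool)
reward_fn = reward.SharedScaledRewardFn()

# At step 0 the scale is 1.0, so rewards equal the unscaled shared rewards
print(reward_fn(targets_found, 0, 10))   # [1.5, 0.5]

# At the final step the scale is 0.0, so all rewards vanish
print(reward_fn(targets_found, 10, 10))  # [0.0, 0.0]
```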
diff --git a/jumanji/environments/swarms/search_and_rescue/utils_test.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py index 018e895a7..be909074e 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils_test.py +++ b/jumanji/environments/swarms/search_and_rescue/utils_test.py @@ -40,7 +40,9 @@ def test_random_walk_dynamics(key: chex.PRNGKey) -> None: assert isinstance(dynamics, TargetDynamics) s1 = dynamics(key, s0, 1.0) + assert isinstance(s1, TargetState) assert s1.pos.shape == (n_targets, 2) + assert jnp.array_equal(s0.found, s1.found) assert jnp.all(jnp.abs(s0.pos - s1.pos) < 0.1) From 1792aa6f1d3d93197c9b0dad19db15e96c706979 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Thu, 12 Dec 2024 09:52:32 +0000 Subject: [PATCH 15/19] fix: unpin jax requirement --- requirements/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 7e0d7a09e..34a8c3253 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -3,7 +3,7 @@ dm-env>=1.5 esquilax>=1.0.3 gymnasium>=1.0 huggingface-hub -jax>=0.2.26,<0.4.36 +jax>=0.2.26 matplotlib~=3.7.4 numpy>=1.19.5 Pillow>=9.0.0 From 1e66e786f6fea9b6ab6f33d0baf1ee94eead01cf Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Thu, 12 Dec 2024 11:00:15 +0000 Subject: [PATCH 16/19] Include agent positions in observation (#13) --- .../swarms/search_and_rescue/env.py | 10 +++++++++- .../swarms/search_and_rescue/env_test.py | 17 ++++++++++------- .../swarms/search_and_rescue/types.py | 1 + 3 files changed, 20 insertions(+), 8 deletions(-) diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index c56c08374..6b10fb897 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -279,6 +279,7 @@ def _state_to_observation(self, state: State) -> Observation: searcher_views=searcher_views, targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets, step=state.step, + positions=state.searchers.pos / self.generator.env_size, ) @cached_property @@ -297,7 +298,7 @@ def observation_spec(self) -> specs.Spec[Observation]: """ searcher_views = specs.BoundedArray( shape=( - self.generator.num_searchers, + self.num_agents, self._observation.num_channels, self._observation.num_vision, ), @@ -316,6 +317,13 @@ def observation_spec(self) -> specs.Spec[Observation]: step=specs.BoundedArray( shape=(), minimum=0, maximum=self.time_limit, name="step", dtype=jnp.int32 ), + positions=specs.BoundedArray( + shape=(self.num_agents, 2), + minimum=0.0, + maximum=1.0, + name="positions", + dtype=jnp.float32, + ), ) @cached_property diff --git a/jumanji/environments/swarms/search_and_rescue/env_test.py b/jumanji/environments/swarms/search_and_rescue/env_test.py index 0a6cdcf58..743546194 100644 --- a/jumanji/environments/swarms/search_and_rescue/env_test.py +++ b/jumanji/environments/swarms/search_and_rescue/env_test.py @@ -65,7 +65,7 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None: assert timestep.reward.shape == (env.generator.num_searchers,) -@pytest.mark.parametrize("env_size", [1.0, 0.2]) +@pytest.mark.parametrize("env_size", [1.0, 0.2, 10.0]) def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> None: """ Run several steps of the environment 
with random actions and @@ -81,9 +81,7 @@ def step( ) -> Tuple[Tuple[chex.PRNGKey, State], Tuple[State, TimeStep[Observation]]]: k, state = carry k, k_search = jax.random.split(k) - actions = jax.random.uniform( - k_search, (env.generator.num_searchers, 2), minval=-1.0, maxval=1.0 - ) + actions = jax.random.uniform(k_search, (env.num_agents, 2), minval=-1.0, maxval=1.0) new_state, timestep = env.step(state, actions) return (k, new_state), (state, timestep) @@ -94,14 +92,14 @@ def step( assert isinstance(state_history, State) - assert state_history.searchers.pos.shape == (n_steps, env.generator.num_searchers, 2) + assert state_history.searchers.pos.shape == (n_steps, env.num_agents, 2) assert jnp.all((0.0 <= state_history.searchers.pos) & (state_history.searchers.pos <= env_size)) - assert state_history.searchers.speed.shape == (n_steps, env.generator.num_searchers) + assert state_history.searchers.speed.shape == (n_steps, env.num_agents) assert jnp.all( (env.searcher_params.min_speed <= state_history.searchers.speed) & (state_history.searchers.speed <= env.searcher_params.max_speed) ) - assert state_history.searchers.speed.shape == (n_steps, env.generator.num_searchers) + assert state_history.searchers.speed.shape == (n_steps, env.num_agents) assert jnp.all( (0.0 <= state_history.searchers.heading) & (state_history.searchers.heading <= 2.0 * jnp.pi) ) @@ -109,6 +107,11 @@ def step( assert state_history.targets.pos.shape == (n_steps, env.generator.num_targets, 2) assert jnp.all((0.0 <= state_history.targets.pos) & (state_history.targets.pos <= env_size)) + assert timesteps.observation.positions.shape == (n_steps, env.num_agents, 2) + assert jnp.all( + (0.0 <= timesteps.observation.positions) & (timesteps.observation.positions <= 1.0) + ) + def test_env_does_not_smoke(multi_obs_env: SearchAndRescue) -> None: """Test that we can run an episode without any errors.""" diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py index 28c9600bb..919c6c55f 100644 --- a/jumanji/environments/swarms/search_and_rescue/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -79,3 +79,4 @@ class Observation(NamedTuple): searcher_views: chex.Array # (num_searchers, num_vision) targets_remaining: chex.Numeric # () step: chex.Numeric # () + positions: chex.Array # (num_searchers, 2) From 407ff795beffa592be54c4dd2c922b0978a76197 Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Fri, 27 Dec 2024 20:58:05 +0000 Subject: [PATCH 17/19] Upgrade Esquilax and remove unused random keys (#14) --- .../environments/swarms/common/common_test.py | 13 ++---- jumanji/environments/swarms/common/updates.py | 11 +---- .../swarms/search_and_rescue/env.py | 10 ++--- .../swarms/search_and_rescue/observations.py | 45 ++++++++----------- .../swarms/search_and_rescue/utils.py | 2 - .../swarms/search_and_rescue/utils_test.py | 1 - requirements/requirements.txt | 2 +- 7 files changed, 29 insertions(+), 55 deletions(-) diff --git a/jumanji/environments/swarms/common/common_test.py b/jumanji/environments/swarms/common/common_test.py index a04449df5..fa7559fb4 100644 --- a/jumanji/environments/swarms/common/common_test.py +++ b/jumanji/environments/swarms/common/common_test.py @@ -14,7 +14,6 @@ from typing import List, Tuple -import jax import jax.numpy as jnp import matplotlib import matplotlib.pyplot as plt @@ -56,8 +55,6 @@ def test_velocity_update( actions: List[float], expected: Tuple[float, 
float], ) -> None: - key = jax.random.PRNGKey(101) - state = types.AgentState( pos=jnp.zeros((1, 2)), heading=jnp.array([heading]), @@ -65,7 +62,7 @@ def test_velocity_update( ) actions = jnp.array([actions]) - new_heading, new_speed = updates.update_velocity(key, params, (actions, state)) + new_heading, new_speed = updates.update_velocity(params, (actions, state)) assert jnp.isclose(new_heading[0], expected[0]) assert jnp.isclose(new_speed[0], expected[1]) @@ -117,8 +114,6 @@ def test_state_update( expected_speed: float, env_size: float, ) -> None: - key = jax.random.PRNGKey(101) - state = types.AgentState( pos=jnp.array([pos]), heading=jnp.array([heading]), @@ -126,7 +121,7 @@ def test_state_update( ) actions = jnp.array([actions]) - new_state = updates.update_state(key, env_size, params, state, actions) + new_state = updates.update_state(env_size, params, state, actions) assert isinstance(new_state, types.AgentState) assert jnp.allclose(new_state.pos, jnp.array([expected_pos])) @@ -137,7 +132,7 @@ def test_state_update( def test_view_reduction() -> None: view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5]) view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2]) - result = updates.view_reduction(view_a, view_b) + result = updates.view_reduction_fn(view_a, view_b) assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2])) @@ -170,7 +165,7 @@ def test_view(pos: List[float], view_angle: float, env_size: float, expected: Li ) obs = updates.view( - None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size + (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size ) assert jnp.allclose(obs, jnp.array(expected)) diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index 2db85f2f1..b43716fea 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -24,7 +24,6 @@ @esquilax.transforms.amap def update_velocity( - _key: chex.PRNGKey, params: types.AgentParams, x: Tuple[chex.Array, types.AgentState], ) -> Tuple[chex.Numeric, chex.Numeric]: @@ -32,7 +31,6 @@ def update_velocity( Get the updated agent heading and speeds from actions Args: - _key: Dummy JAX random key. params: Agent parameters. x: Agent rotation and acceleration actions. @@ -72,7 +70,6 @@ def move(pos: chex.Array, heading: chex.Array, speed: chex.Array, env_size: floa def update_state( - key: chex.PRNGKey, env_size: float, params: types.AgentParams, state: types.AgentState, @@ -82,7 +79,6 @@ def update_state( Update the state of a group of agents from a sample of actions Args: - key: Dummy JAX random key. env_size: Size of the environment. params: Agent parameters. state: Current agent states. @@ -93,7 +89,7 @@ def update_state( actions and updating positions. """ actions = jnp.clip(actions, min=-1.0, max=1.0) - headings, speeds = update_velocity(key, params, (actions, state)) + headings, speeds = update_velocity(params, (actions, state)) positions = jax.vmap(move, in_axes=(0, 0, 0, None))(state.pos, headings, speeds, env_size) return types.AgentState( @@ -103,7 +99,7 @@ def update_state( ) -def view_reduction(view_a: chex.Array, view_b: chex.Array) -> chex.Array: +def view_reduction_fn(view_a: chex.Array, view_b: chex.Array) -> chex.Array: """ Binary view reduction function for use in Esquilax spatial transformation. 
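A small illustrative sketch of how `view_reduction_fn` composes several per-ray views (the import style is assumed): `-1.0` survives only where every view misses, otherwise the nearest positive distance along each ray is kept.

```python
from functools import reduce

import jax.numpy as jnp

from jumanji.environments.swarms.common import updates

# Three per-ray distance arrays, -1.0 meaning nothing seen along that ray
views = [
    jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5]),
    jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2]),
    jnp.array([-1.0, -1.0, -1.0, 0.1, -1.0]),
]

# Folding with view_reduction_fn keeps the smallest positive distance per ray,
# while -1.0 remains only if every view misses along that ray.
combined = reduce(updates.view_reduction_fn, views)
print(combined)  # [-1.0, 0.2, 0.2, 0.1, 0.2]
```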
@@ -161,7 +157,6 @@ def angular_width( def view( - _key: chex.PRNGKey, params: Tuple[float, float], viewing_agent: types.AgentState, viewed_agent: types.AgentState, @@ -181,8 +176,6 @@ def view( Currently, this model assumes the viewed agent/objects are circular. Args: - _key: Dummy JAX random key, required by esquilax API, but - not used during the interaction. params: Tuple containing agent view angle and view-radius. viewing_agent: Viewing agent state. viewed_agent: State of agent being viewed. diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 6b10fb897..021ecfd41 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -16,6 +16,7 @@ from typing import Optional, Sequence, Tuple import chex +import esquilax import jax import jax.numpy as jnp from esquilax.transforms import spatial @@ -222,25 +223,20 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser state: Updated searcher and target positions and velocities. timestep: Transition timestep with individual agent local observations. """ - # Note: only one new key is needed for the target updates, as all other - # keys are just dummy values required by Esquilax key, target_key = jax.random.split(state.key, num=2) searchers = update_state( - key, self.generator.env_size, self.searcher_params, state.searchers, actions + self.generator.env_size, self.searcher_params, state.searchers, actions ) - targets = self._target_dynamics(target_key, state.targets, self.generator.env_size) # Searchers return an array of flags of any targets they are in range of, # and that have not already been located, result shape here is (n-searcher, n-targets) targets_found = spatial( utils.searcher_detect_targets, - reduction=jnp.logical_or, - default=jnp.zeros((self.generator.num_targets,), dtype=bool), + reduction=esquilax.reductions.logical_or((self.generator.num_targets,)), i_range=self.target_contact_range, dims=self.generator.env_size, )( - key, self.searcher_params.view_angle, searchers, (jnp.arange(self.generator.num_targets), targets), diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index 4abd9c6d6..6d66c9f84 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -16,14 +16,21 @@ from typing import Tuple import chex +import esquilax import jax.numpy as jnp -from esquilax.transforms import spatial from jumanji.environments.swarms.common.types import AgentState -from jumanji.environments.swarms.common.updates import angular_width, view, view_reduction +from jumanji.environments.swarms.common.updates import angular_width, view, view_reduction_fn from jumanji.environments.swarms.search_and_rescue.types import State, TargetState +def view_reduction(view_shape: Tuple[int, ...]) -> esquilax.reductions.Reduction: + return esquilax.reductions.Reduction( + fn=view_reduction_fn, + id=-jnp.ones(view_shape), + ) + + class ObservationFn(abc.ABC): def __init__( self, @@ -109,15 +116,13 @@ def __call__(self, state: State) -> chex.Array: Array of individual agent views of shape (n-agents, 1, n-vision). 
""" - searcher_views = spatial( + searcher_views = esquilax.transforms.spatial( view, - reduction=view_reduction, - default=-jnp.ones((self.num_vision,)), + reduction=view_reduction((self.num_vision,)), include_self=False, i_range=self.vision_range, dims=self.env_size, )( - state.key, (self.view_angle, self.agent_radius), state.searchers, state.searchers, @@ -130,7 +135,6 @@ def __call__(self, state: State) -> chex.Array: def found_target_view( - _key: chex.PRNGKey, params: Tuple[float, float], searcher: AgentState, target: TargetState, @@ -146,7 +150,6 @@ def found_target_view( by Esquilax. Args: - _key: Dummy random key (required by Esquilax). params: View angle and target visual radius. searcher: Searcher agent state target: Target state @@ -224,15 +227,13 @@ def __call__(self, state: State) -> chex.Array: (n-agents, 2, n-vision). Other agents are shown in channel 0, and located targets 1. """ - searcher_views = spatial( + searcher_views = esquilax.transforms.spatial( view, - reduction=view_reduction, - default=-jnp.ones((self.num_vision,)), + reduction=view_reduction((self.num_vision,)), include_self=False, i_range=self.vision_range, dims=self.env_size, )( - state.key, (self.view_angle, self.agent_radius), state.searchers, state.searchers, @@ -241,15 +242,13 @@ def __call__(self, state: State) -> chex.Array: i_range=self.vision_range, env_size=self.env_size, ) - target_views = spatial( + target_views = esquilax.transforms.spatial( found_target_view, - reduction=view_reduction, - default=-jnp.ones((self.num_vision,)), + reduction=view_reduction((self.num_vision,)), include_self=False, i_range=self.vision_range, dims=self.env_size, )( - state.key, (self.view_angle, self.agent_radius), state.searchers, state.targets, @@ -263,7 +262,6 @@ def __call__(self, state: State) -> chex.Array: def all_target_view( - _key: chex.PRNGKey, params: Tuple[float, float], searcher: AgentState, target: TargetState, @@ -279,7 +277,6 @@ def all_target_view( by Esquilax. Args: - _key: Dummy random key (required by Esquilax). params: View angle and target visual radius. searcher: Searcher agent state target: Target state @@ -361,15 +358,13 @@ def __call__(self, state: State) -> chex.Array: in channel 0, located targets 1, and un-located targets at index 2. 
""" - searcher_views = spatial( + searcher_views = esquilax.transforms.spatial( view, - reduction=view_reduction, - default=-jnp.ones((self.num_vision,)), + reduction=view_reduction((self.num_vision,)), include_self=False, i_range=self.vision_range, dims=self.env_size, )( - state.key, (self.view_angle, self.agent_radius), state.searchers, state.searchers, @@ -378,15 +373,13 @@ def __call__(self, state: State) -> chex.Array: i_range=self.vision_range, env_size=self.env_size, ) - target_views = spatial( + target_views = esquilax.transforms.spatial( all_target_view, - reduction=view_reduction, - default=-jnp.ones((2, self.num_vision)), + reduction=view_reduction((2, self.num_vision)), include_self=False, i_range=self.vision_range, dims=self.env_size, )( - state.key, (self.view_angle, self.agent_radius), state.searchers, state.targets, diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py index e45f82343..5cf577048 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils.py +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -50,7 +50,6 @@ def _check_target_in_view( def searcher_detect_targets( - _key: chex.PRNGKey, searcher_view_angle: float, searcher: AgentState, target: Tuple[chex.Array, TargetState], @@ -65,7 +64,6 @@ def searcher_detect_targets( searchers view cone, and has not already been detected. Args: - _key: Dummy random key (required by Esquilax). searcher_view_angle: View angle of searching agents representing a fraction of pi from the agents heading. searcher: State of the searching agent (i.e. the agent diff --git a/jumanji/environments/swarms/search_and_rescue/utils_test.py b/jumanji/environments/swarms/search_and_rescue/utils_test.py index be909074e..af4f11a37 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils_test.py +++ b/jumanji/environments/swarms/search_and_rescue/utils_test.py @@ -83,7 +83,6 @@ def test_target_found( ) found = jax.jit(partial(searcher_detect_targets, env_size=env_size, n_targets=1))( - None, view_angle, searcher, (jnp.arange(1), target), diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 34a8c3253..b6bb622bd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -1,6 +1,6 @@ chex>=0.1.3 dm-env>=1.5 -esquilax>=1.0.3 +esquilax>=2.0.0 gymnasium>=1.0 huggingface-hub jax>=0.2.26 From 04fe710a4d88f1adc87ded49c636ddd64976662a Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Sun, 12 Jan 2025 23:02:16 +0000 Subject: [PATCH 18/19] docs: Review docstrings and docs (#15) * Review docstrings * Add scaled individual reward function * Doc tweaks * Fix upper-case pi symbols --- docs/environments/search_and_rescue.md | 21 +++++--- jumanji/environments/swarms/common/updates.py | 4 +- .../swarms/search_and_rescue/dynamics.py | 7 +-- .../swarms/search_and_rescue/env.py | 32 +++++++----- .../swarms/search_and_rescue/observations.py | 34 ++++++------- .../swarms/search_and_rescue/reward.py | 51 ++++++++++++++----- .../swarms/search_and_rescue/reward_test.py | 18 ++++++- .../swarms/search_and_rescue/types.py | 11 ++-- .../swarms/search_and_rescue/utils.py | 12 ++--- .../swarms/search_and_rescue/viewer.py | 13 +++-- 10 files changed, 130 insertions(+), 73 deletions(-) diff --git a/docs/environments/search_and_rescue.md b/docs/environments/search_and_rescue.md index d7ba33a89..a707a8553 100644 --- a/docs/environments/search_and_rescue.md +++ 
b/docs/environments/search_and_rescue.md @@ -12,9 +12,9 @@ is updated in the following sequence: - The velocity of searching agents are updated, and consequently their positions. - The positions of targets are updated. -- Targets within detection range and an agents view cone are marked as found. +- Targets within detection range, and within an agents view cone are marked as found. - Agents are rewarded for locating previously unfound targets. -- Local views of the environment are generated for each search agent. +- Local views of the environment are generated for each searching agent. The agents are allotted a fixed number of steps to locate the targets. The search space is a uniform square space, wrapped at the boundaries. @@ -22,7 +22,8 @@ space is a uniform square space, wrapped at the boundaries. Many aspects of the environment can be customised: - Agent observations can include targets as well as other searcher agents. -- Rewards can be shared by agents, or can be treated completely individually for individual agents. +- Rewards can be shared by agents, or can be treated completely individually for individual + agents and can be scaled by time-step. - Target dynamics can be customised to model various search scenarios. ## Observations @@ -38,13 +39,14 @@ Many aspects of the environment can be customised: [-1.0, -1.0, 0.5, -1.0, -1.0] ``` - where `-1.0` indicates there is no agents along that ray, and `0.5` is the normalised + where `-1.0` indicates there are no agents along that ray, and `0.5` is the normalised distance to the other agent. Channels in the segmented view are used to differentiate between different agents/targets and can be customised. By default, the view has three - channels representing other agents, found targets, and unfound targets. + channels representing other agents, found targets, and unlocated targets. - `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets remaining to be detected (i.e. 1.0 when no targets have been found). -- `Step`: int in the range `[0, time_limit]`. The current simulation step. +- `step`: int in the range `[0, time_limit]`. The current simulation step. +- `positions`: jax array (float) of shape `(num_searchers, 2)`. Agent coordinates. ## Actions @@ -65,11 +67,14 @@ and speed speed = speed + max_acceleration * action[1] ``` -Once applied, agent speeds are clipped to velocities within a fixed range of speeds. +Once applied, agent speeds are clipped to velocities within a fixed range of speeds given +by the `min_speed` and `max_speed` parameters. ## Rewards Jax array (float) of `(num_searchers,)`. Rewards are generated for each agent individually. + Agents are rewarded +1 for locating a target that has not already been detected. It is possible for multiple agents to detect a target inside a step, as such rewards can either be shared -by the locating agents, or each agent can get the full reward. +by the locating agents, or each individual agent can get the full reward. Rewards provided can +also be scaled by simulation step to encourage agents to develop efficient search patterns. diff --git a/jumanji/environments/swarms/common/updates.py b/jumanji/environments/swarms/common/updates.py index b43716fea..25f99e1e5 100644 --- a/jumanji/environments/swarms/common/updates.py +++ b/jumanji/environments/swarms/common/updates.py @@ -63,7 +63,7 @@ def move(pos: chex.Array, heading: chex.Array, speed: chex.Array, env_size: floa env_size: Size of the environment. Returns: - jax array (float32): Updated agent position. 
+ jax array (float32): Updated agent positions. """ d_pos = jnp.array([speed * jnp.cos(heading), speed * jnp.sin(heading)]) return (pos + d_pos) % env_size @@ -170,7 +170,7 @@ def view( Simple view model where the agents view angle is subdivided into an array of values representing the distance from - the agent along a rays from the agent, with rays evenly distributed + the agent along rays from the agent, with rays evenly distributed across the agents field of view. The limit of vision is set at 1.0. The default value if no object is within range is -1.0. Currently, this model assumes the viewed agent/objects are circular. diff --git a/jumanji/environments/swarms/search_and_rescue/dynamics.py b/jumanji/environments/swarms/search_and_rescue/dynamics.py index e3d9b47ef..1a6897769 100644 --- a/jumanji/environments/swarms/search_and_rescue/dynamics.py +++ b/jumanji/environments/swarms/search_and_rescue/dynamics.py @@ -25,8 +25,9 @@ class TargetDynamics(abc.ABC): def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState: """Interface for target state update function. - NOTE: Target positions should be bound to environment - area (generally wrapped around at the boundaries). + NOTE: Target positions should be inside the bounds + of the environment. Out-of-bound co-ordinates can + lead to unexpected behaviour. Args: key: Random key. @@ -55,7 +56,7 @@ def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> """Update target state. Args: - key: random key. + key: Random key. targets: Current target states. env_size: Environment size. diff --git a/jumanji/environments/swarms/search_and_rescue/env.py b/jumanji/environments/swarms/search_and_rescue/env.py index 021ecfd41..ad7c7789a 100644 --- a/jumanji/environments/swarms/search_and_rescue/env.py +++ b/jumanji/environments/swarms/search_and_rescue/env.py @@ -61,17 +61,19 @@ class SearchAndRescue(Environment): Each entry in the view indicates the distant to another agent/target along a ray from the agent, and is -1.0 if nothing is in range along the ray. The view model can be customised using an `ObservationFn` implementation, e.g. - the view can include all agents and targets, or just other agents. + the view can include agents and all targets, agents and found targets,or + just other agents. targets_remaining: (float) Number of targets remaining to be found from the total scaled to the range [0, 1] (i.e. a value of 1.0 indicates all the targets are still to be found). step: (int) current simulation step. + positions: jax array (float) of shape (num_searchers, 2) search agent positions. - action: jax array (float) of shape (num_searchers, 2) Array of individual agent actions. Each agents actions rotate and accelerate/decelerate the agent as [rotation, acceleration] on the range [-1, 1]. These values are then scaled to update agent velocities within - given parameters. + given parameters (i.e. a value of -+1 is the maximum acceleration/rotation). - reward: jax array (float) of shape (num_searchers,) Arrays of individual agent rewards. A reward of +1 is granted when an agent @@ -84,7 +86,7 @@ class SearchAndRescue(Environment): - searchers: `AgentState` - pos: jax array (float) of shape (num_searchers, 2) in the range [0, env_size]. - heading: jax array (float) of shape (num_searcher,) in - the range [0, 2pi]. + the range [0, 2π]. - speed: jax array (float) of shape (num_searchers,) in the range [min_speed, max_speed]. 
- targets: `TargetState` @@ -115,7 +117,7 @@ def __init__( searcher_max_accelerate: float = 0.005, searcher_min_speed: float = 0.01, searcher_max_speed: float = 0.02, - searcher_view_angle: float = 0.75, + searcher_view_angle: float = 0.5, time_limit: int = 400, viewer: Optional[Viewer[State]] = None, target_dynamics: Optional[TargetDynamics] = None, @@ -129,19 +131,18 @@ def __init__( target_contact_range: Range at which a searchers will 'find' a target. searcher_max_rotate: Maximum rotation searcher agents can turn within a step. Should be a value from [0,1] - representing a fraction of pi radians. + representing a fraction of π-radians. searcher_max_accelerate: Magnitude of the maximum acceleration/deceleration a searcher agent can apply within a step. searcher_min_speed: Minimum speed a searcher agent can move at. searcher_max_speed: Maximum speed a searcher agent can move at. searcher_view_angle: Searcher agent local view angle. Should be - a value from [0,1] representing a fraction of pi radians. + a value from [0,1] representing a fraction of π-radians. The view cone of an agent goes from +- of the view angle relative to its heading, e.g. 0.5 would mean searchers have a 90° view angle in total. time_limit: Maximum number of environment steps allowed for search. viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`. - target_dynamics: target_dynamics: Target object dynamics model, implemented as a `TargetDynamics` interface. Defaults to `RandomWalk`. generator: Initial state `Generator` instance. Defaults to `RandomGenerator` @@ -150,7 +151,7 @@ def __init__( agents share rewards if they locate a target simultaneously. observation: Agent observation view generation function. Defaults to `AgentAndAllTargetObservationFn` where all targets (found and unfound) - and other ogents are included in the generated view. + and other searching agents are included in the generated view. """ self.target_contact_range = target_contact_range @@ -164,14 +165,14 @@ def __init__( ) self.time_limit = time_limit self._target_dynamics = target_dynamics or RandomWalk(0.001) - self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2) + self.generator = generator or RandomGenerator(num_targets=50, num_searchers=2) self._viewer = viewer or SearchAndRescueViewer() self._reward_fn = reward_fn or SharedRewardFn() self._observation = observation or AgentAndAllTargetObservationFn( num_vision=64, - vision_range=0.1, + vision_range=0.25, view_angle=searcher_view_angle, - agent_radius=0.01, + agent_radius=0.02, env_size=self.generator.env_size, ) super().__init__() @@ -182,7 +183,12 @@ def __repr__(self) -> str: "Search & rescue multi-agent environment:", f" - num searchers: {self.generator.num_searchers}", f" - num targets: {self.generator.num_targets}", + f" - max searcher rotation: {self.searcher_params.max_rotate}", + f" - max searcher acceleration: {self.searcher_params.max_accelerate}", + f" - searcher min speed: {self.searcher_params.min_speed}", + f" - searcher max speed: {self.searcher_params.max_speed}", f" - search vision range: {self._observation.vision_range}", + f" - search view angle: {self._observation.view_angle}", f" - target contact range: {self.target_contact_range}", f" - num vision: {self._observation.num_vision}", f" - agent radius: {self._observation.agent_radius}", @@ -217,7 +223,7 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser Args: state: Environment state. - actions: Arrays of searcher steering actions. 
+ actions: 2d array of searcher steering actions. Returns: state: Updated searcher and target positions and velocities. @@ -360,7 +366,7 @@ def render(self, state: State) -> None: """Render a frame of the environment for a given state using matplotlib. Args: - state: State object containing the current state of the environment. + state: State object. """ self._viewer.render(state) diff --git a/jumanji/environments/swarms/search_and_rescue/observations.py b/jumanji/environments/swarms/search_and_rescue/observations.py index 6d66c9f84..6a7e587c0 100644 --- a/jumanji/environments/swarms/search_and_rescue/observations.py +++ b/jumanji/environments/swarms/search_and_rescue/observations.py @@ -53,7 +53,7 @@ def __init__( num_channels: Number of channels in agent view. num_vision: Size of vision array. vision_range: Vision range. - view_angle: Agent view angle (as a fraction of pi). + view_angle: Agent view angle (as a fraction of π). agent_radius: Agent/target visual radius. env_size: Environment size. """ @@ -92,7 +92,7 @@ def __init__( Args: num_vision: Size of vision array. vision_range: Vision range. - view_angle: Agent view angle (as a fraction of pi). + view_angle: Agent view angle (as a fraction of π). agent_radius: Agent/target visual radius. env_size: Environment size. """ @@ -146,14 +146,14 @@ def found_target_view( """ Return view of a target, dependent on target status. - This function is intended to be mapped over agents-targets - by Esquilax. + This function is intended to be mapped over agents-target + pairs by Esquilax. Args: params: View angle and target visual radius. searcher: Searcher agent state target: Target state - n_view: Number of value sin view array. + n_view: Number of values in view array. i_range: Vision range env_size: Environment size @@ -190,7 +190,7 @@ def __init__( env_size: float, ) -> None: """ - Vision model that contains other agents, and found targets. + Vision model that contains other agents and found targets. Searchers and targets are visualised as individual channels. Targets are only included if they have been located already. @@ -198,7 +198,7 @@ def __init__( Args: num_vision: Size of vision array. vision_range: Vision range. - view_angle: Agent view angle (as a fraction of pi). + view_angle: Agent view angle (as a fraction of π). agent_radius: Agent/target visual radius. env_size: Environment size. """ @@ -225,7 +225,7 @@ def __call__(self, state: State) -> chex.Array: Returns: Array of individual agent views of shape (n-agents, 2, n-vision). Other agents are shown - in channel 0, and located targets 1. + in channel 0, and located targets in channel 1. """ searcher_views = esquilax.transforms.spatial( view, @@ -273,16 +273,16 @@ def all_target_view( """ Return view of a target, dependent on target status. - This function is intended to be mapped over agents-targets - by Esquilax. + This function is intended to be mapped over agents target + pairs by Esquilax. Args: params: View angle and target visual radius. - searcher: Searcher agent state - target: Target state + searcher: Searcher agent state. + target: Target state. n_view: Number of value sin view array. - i_range: Vision range - env_size: Environment size + i_range: Vision range. + env_size: Environment size. Returns: Segmented agent view of target. @@ -328,7 +328,7 @@ def __init__( Args: num_vision: Size of vision array. vision_range: Vision range. - view_angle: Agent view angle (as a fraction of pi). + view_angle: Agent view angle (as a fraction of π). 
agent_radius: Agent/target visual radius. env_size: Environment size. """ @@ -355,8 +355,8 @@ def __call__(self, state: State) -> chex.Array: Returns: Array of individual agent views of shape (n-agents, 3, n-vision). Other agents are shown - in channel 0, located targets 1, and un-located - targets at index 2. + in channel 0, located targets channel 1, and un-located + targets in channel 2. """ searcher_views = esquilax.transforms.spatial( view, diff --git a/jumanji/environments/swarms/search_and_rescue/reward.py b/jumanji/environments/swarms/search_and_rescue/reward.py index 40c38b9d4..756fec884 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward.py +++ b/jumanji/environments/swarms/search_and_rescue/reward.py @@ -26,13 +26,25 @@ def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> che """The reward function used in the `SearchAndRescue` environment. Args: - found_targets: Array of boolean flags indicating + found_targets: Array of boolean flags indicating if an + agent has found a target. Returns: Individual reward for each agent. """ +def _normalise_rewards(rewards: chex.Array) -> chex.Array: + norms = jnp.sum(rewards, axis=0)[jnp.newaxis] + rewards = jnp.where(norms > 0, rewards / norms, rewards) + return rewards + + +def _scale_rewards(rewards: chex.Array, step: int, time_limit: int) -> chex.Array: + scale = (time_limit - step) / time_limit + return scale * rewards + + class SharedRewardFn(RewardFn): """ Calculate per agent rewards from detected targets @@ -43,12 +55,29 @@ class SharedRewardFn(RewardFn): def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) - norms = jnp.sum(rewards, axis=0)[jnp.newaxis] - rewards = jnp.where(norms > 0, rewards / norms, rewards) + rewards = _normalise_rewards(rewards) rewards = jnp.sum(rewards, axis=1) return rewards +class SharedScaledRewardFn(RewardFn): + """ + Calculate per agent rewards from detected targets and scale by timestep + + Targets detected by multiple agents share rewards. Agents + can receive rewards for detecting multiple targets. + Rewards are linearly scaled by the current time step such that + rewards received are 0 at the final step. + """ + + def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: + rewards = found_targets.astype(float) + rewards = _normalise_rewards(rewards) + rewards = jnp.sum(rewards, axis=1) + rewards = _scale_rewards(rewards, step, time_limit) + return rewards + + class IndividualRewardFn(RewardFn): """ Calculate per agent rewards from detected targets @@ -63,20 +92,18 @@ def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> che return rewards -class SharedScaledRewardFn(RewardFn): +class IndividualScaledRewardFn(RewardFn): """ - Calculate per agent rewards from detected targets + Calculate per agent rewards from detected targets and scale by timestep - Targets detected by multiple agents share rewards. Agents - can receive rewards for detecting multiple targets. + Each agent that detects a target receives a +1 reward + even if a target is detected by multiple agents. Rewards are linearly scaled by the current time step such that - rewards are 0 at the final step. + rewards received are 0 at the final step. 
""" def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array: rewards = found_targets.astype(float) - norms = jnp.sum(rewards, axis=0)[jnp.newaxis] - rewards = jnp.where(norms > 0, rewards / norms, rewards) rewards = jnp.sum(rewards, axis=1) - scale = (time_limit - step) / time_limit - return scale * rewards + rewards = _scale_rewards(rewards, step, time_limit) + return rewards diff --git a/jumanji/environments/swarms/search_and_rescue/reward_test.py b/jumanji/environments/swarms/search_and_rescue/reward_test.py index d351fc259..bd49f69e8 100644 --- a/jumanji/environments/swarms/search_and_rescue/reward_test.py +++ b/jumanji/environments/swarms/search_and_rescue/reward_test.py @@ -39,7 +39,7 @@ def test_individual_rewards(target_states: chex.Array) -> None: assert jnp.allclose(individual_rewards, jnp.array([2.0, 1.0])) -def test_scaled_rewards(target_states: chex.Array) -> None: +def test_shared_scaled_rewards(target_states: chex.Array) -> None: reward_fn = reward.SharedScaledRewardFn() shared_scaled_rewards = reward_fn(target_states, 0, 10) @@ -53,3 +53,19 @@ def test_scaled_rewards(target_states: chex.Array) -> None: assert shared_scaled_rewards.shape == (2,) assert shared_scaled_rewards.dtype == jnp.float32 assert jnp.allclose(shared_scaled_rewards, jnp.array([0.0, 0.0])) + + +def test_individual_scaled_rewards(target_states: chex.Array) -> None: + reward_fn = reward.IndividualScaledRewardFn() + + individual_scaled_rewards = reward_fn(target_states, 0, 10) + + assert individual_scaled_rewards.shape == (2,) + assert individual_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(individual_scaled_rewards, jnp.array([2.0, 1.0])) + + individual_scaled_rewards = reward_fn(target_states, 10, 10) + + assert individual_scaled_rewards.shape == (2,) + assert individual_scaled_rewards.dtype == jnp.float32 + assert jnp.allclose(individual_scaled_rewards, jnp.array([0.0, 0.0])) diff --git a/jumanji/environments/swarms/search_and_rescue/types.py b/jumanji/environments/swarms/search_and_rescue/types.py index 919c6c55f..3fba76438 100644 --- a/jumanji/environments/swarms/search_and_rescue/types.py +++ b/jumanji/environments/swarms/search_and_rescue/types.py @@ -26,10 +26,10 @@ @dataclass class TargetState: """ - The state for the rescue targets. + The state of the rescue targets. - pos: 2d position of the target agents - velocity: 2d velocity of the target agents + pos: 2d position of the target agents. + velocity: 2d velocity of the target agents. found: Boolean flag indicating if the target has been located by a searcher. """ @@ -57,7 +57,7 @@ class State: class Observation(NamedTuple): """ Individual observations for searching agents and information - on number of remaining time and targets to be found. + on number of remaining steps and ratio of targets to be found. Each agent generates an independent observation, an array of values representing the distance along a ray from the agent to @@ -65,6 +65,9 @@ class Observation(NamedTuple): (with `num_vision` rays evenly distributed over the agents field of vision). + The co-ordinates of each agent are also included in the + observation for debug and use in global observations. 
+ For example if an agent sees another agent straight ahead and `num_vision = 5` then the observation array could be diff --git a/jumanji/environments/swarms/search_and_rescue/utils.py b/jumanji/environments/swarms/search_and_rescue/utils.py index 5cf577048..ed52b33c5 100644 --- a/jumanji/environments/swarms/search_and_rescue/utils.py +++ b/jumanji/environments/swarms/search_and_rescue/utils.py @@ -33,11 +33,11 @@ def _check_target_in_view( Check if a target is inside the view-cone of a searcher. Args: - searcher_pos: Searcher position - target_pos: Target position - searcher_heading: Searcher heading angle - searcher_view_angle: Searcher view angle - env_size: Size of the environment + searcher_pos: Searcher position. + target_pos: Target position. + searcher_heading: Searcher heading angle. + searcher_view_angle: Searcher view angle. + env_size: Size of the environment. Returns: bool: Flag indicating if a target is within view. @@ -65,7 +65,7 @@ def searcher_detect_targets( Args: searcher_view_angle: View angle of searching agents - representing a fraction of pi from the agents heading. + representing a fraction of π from the agents heading. searcher: State of the searching agent (i.e. the agent position and heading) target: Index and State of the target (i.e. its position and diff --git a/jumanji/environments/swarms/search_and_rescue/viewer.py b/jumanji/environments/swarms/search_and_rescue/viewer.py index 9a48103f7..81655820f 100644 --- a/jumanji/environments/swarms/search_and_rescue/viewer.py +++ b/jumanji/environments/swarms/search_and_rescue/viewer.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import jax.numpy as jnp import matplotlib.animation import matplotlib.pyplot as plt import numpy as np +from matplotlib.artist import Artist from matplotlib.layout_engine import TightLayoutEngine import jumanji @@ -40,8 +41,8 @@ def __init__( """Viewer for the `SearchAndRescue` environment. Args: - figure_name: the window name to be used when initialising the window. - figure_size: tuple (height, width) of the matplotlib figure window. + figure_name: The window name to be used when initialising the window. + figure_size: Tuple (height, width) of the matplotlib figure window. searcher_color: Color of searcher agent markers (arrows). target_found_color: Color of target markers when they have been found. target_lost_color: Color of target markers when they are still to be found. @@ -59,7 +60,6 @@ def render(self, state: State) -> None: Args: state: State object containing the current dynamics of the environment. 
- """ self._clear_display() fig, ax = self._get_fig_ax() @@ -90,8 +90,7 @@ def animate( states[0].targets.pos[:, 0], states[0].targets.pos[:, 1], marker="o" ) - def make_frame(state: State) -> Any: - # Rather than redraw just update the quivers and scatter properties + def make_frame(state: State) -> Tuple[Artist, Artist]: searcher_quiver.set_offsets(state.searchers.pos) searcher_quiver.set_UVC( jnp.cos(state.searchers.heading), jnp.sin(state.searchers.heading) @@ -99,7 +98,7 @@ def make_frame(state: State) -> Any: target_colors = self.target_colors[state.targets.found.astype(jnp.int32)] target_scatter.set_offsets(state.targets.pos) target_scatter.set_color(target_colors) - return ((searcher_quiver, target_scatter),) + return searcher_quiver, target_scatter matplotlib.rc("animation", html="jshtml") self._animation = matplotlib.animation.FuncAnimation( From ac3f811e023108e9c3a538fc72773dcd32f5910c Mon Sep 17 00:00:00 2001 From: zombie-einstein <13398815+zombie-einstein@users.noreply.github.com> Date: Sun, 12 Jan 2025 23:19:06 +0000 Subject: [PATCH 19/19] fix: Remove enum annotations --- jumanji/environments/routing/snake/types.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jumanji/environments/routing/snake/types.py b/jumanji/environments/routing/snake/types.py index 7617240e8..58b9b4d7b 100644 --- a/jumanji/environments/routing/snake/types.py +++ b/jumanji/environments/routing/snake/types.py @@ -40,10 +40,10 @@ def __add__(self, other: Position) -> Position: # type: ignore[override] class Actions(IntEnum): - UP: int = 0 - RIGHT: int = 1 - DOWN: int = 2 - LEFT: int = 3 + UP = 0 + RIGHT = 1 + DOWN = 2 + LEFT = 3 @dataclass
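For context on the final patch: the current typing spec for enums says members should be declared without type annotations, so recent type checkers flag definitions like `UP: int = 0`; the annotation is also redundant, since `IntEnum` members are already `int` instances. A standalone illustration mirroring the corrected `Actions` enum (not the Jumanji class itself):

```
from enum import IntEnum


class Actions(IntEnum):
    # Plain assignments define enum members; adding an annotation such as
    # `UP: int = 0` is reported by recent type checkers and adds no information.
    UP = 0
    RIGHT = 1
    DOWN = 2
    LEFT = 3


assert Actions.UP == 0 and isinstance(Actions.UP, int)
assert [a.value for a in Actions] == [0, 1, 2, 3]
```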