feat: Scaled rewards and target velocities (#10)
* Use channels view parameters

* Rename parameters

* Include step-number in observation

* Add velocity field to targets

* Add time scaled reward function
zombie-einstein authored Dec 11, 2024
1 parent 13ffb84 commit 9e8ac5c
Showing 11 changed files with 115 additions and 51 deletions.
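Taken together, these changes rename the max_steps argument to time_limit, add a velocity field to TargetState, replace the normalised time_remaining observation with the raw step count, and introduce a reward scaled by the remaining time. A minimal usage sketch, assuming the import paths follow the file layout in this diff and that the environment exposes a reward_fn constructor argument (that argument is not shown in the hunks below):

import jax

from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue
from jumanji.environments.swarms.search_and_rescue.reward import SharedScaledRewardFn

# time_limit replaces the old max_steps argument; reward_fn is an assumption here,
# the diff only shows that the environment stores a _reward_fn internally.
env = SearchAndRescue(time_limit=200, reward_fn=SharedScaledRewardFn())

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

# Observations now carry the integer step count and per-agent views of shape
# (num_channels, num_vision) rather than a generic view_shape.
print(timestep.observation.step)
print(timestep.observation.searcher_views.shape)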
2 changes: 1 addition & 1 deletion jumanji/environments/swarms/search_and_rescue/conftest.py
@@ -28,7 +28,7 @@ def env() -> SearchAndRescue:
searcher_min_speed=0.01,
searcher_max_speed=0.05,
searcher_view_angle=0.5,
max_steps=25,
time_limit=10,
)


19 changes: 11 additions & 8 deletions jumanji/environments/swarms/search_and_rescue/dynamics.py
@@ -17,18 +17,20 @@
import chex
import jax

from jumanji.environments.swarms.search_and_rescue.types import TargetState


class TargetDynamics(abc.ABC):
@abc.abstractmethod
def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array:
def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
"""Interface for target position update function.
Args:
key: random key.
target_pos: Current target positions.
targets: Current target states.
Returns:
Updated target positions.
Updated target states.
"""


@@ -46,16 +48,17 @@ def __init__(self, step_size: float):
"""
self.step_size = step_size

def __call__(self, key: chex.PRNGKey, target_pos: chex.Array) -> chex.Array:
def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
"""Update target positions.
Args:
key: random key.
target_pos: Current target positions.
targets: Current target states.
Returns:
Updated target positions.
Updated target states.
"""
d_pos = jax.random.uniform(key, target_pos.shape)
d_pos = jax.random.uniform(key, targets.pos.shape)
d_pos = self.step_size * 2.0 * (d_pos - 0.5)
return target_pos + d_pos
pos = (targets.pos + d_pos) % env_size
return TargetState(pos=pos, vel=targets.vel, found=targets.found)
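Under the new signature a dynamics model receives and returns the full TargetState, so it can read and update the velocity field rather than only the positions. A hypothetical drifting-target model, sketched against the interface above (the class name and parameters are illustrative and not part of this commit):

import chex
import jax
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.dynamics import TargetDynamics
from jumanji.environments.swarms.search_and_rescue.types import TargetState


class RandomDrift(TargetDynamics):
    """Targets keep a velocity that is perturbed by random accelerations."""

    def __init__(self, acceleration: float, max_speed: float):
        self.acceleration = acceleration
        self.max_speed = max_speed

    def __call__(self, key: chex.PRNGKey, targets: TargetState, env_size: float) -> TargetState:
        # Random per-axis acceleration in [-acceleration, acceleration].
        d_vel = jax.random.uniform(key, targets.vel.shape, minval=-1.0, maxval=1.0)
        vel = jnp.clip(targets.vel + self.acceleration * d_vel, -self.max_speed, self.max_speed)
        # Positions wrap around the environment, matching RandomWalk above.
        pos = (targets.pos + vel) % env_size
        return TargetState(pos=pos, vel=vel, found=targets.found)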
43 changes: 26 additions & 17 deletions jumanji/environments/swarms/search_and_rescue/env.py
@@ -113,7 +113,7 @@ def __init__(
searcher_min_speed: float = 0.01,
searcher_max_speed: float = 0.02,
searcher_view_angle: float = 0.75,
max_steps: int = 400,
time_limit: int = 400,
viewer: Optional[Viewer[State]] = None,
target_dynamics: Optional[TargetDynamics] = None,
generator: Optional[Generator] = None,
@@ -136,7 +136,7 @@ def __init__(
The view cone of an agent goes from +- of the view angle
relative to its heading, e.g. 0.5 would mean searchers have a
90° view angle in total.
max_steps: Maximum number of environment steps allowed for search.
time_limit: Maximum number of environment steps allowed for search.
viewer: `Viewer` used for rendering. Defaults to `SearchAndRescueViewer`.
target_dynamics:
target_dynamics: Target object dynamics model, implemented as a
@@ -156,7 +156,7 @@ def __init__(
max_speed=searcher_max_speed,
view_angle=searcher_view_angle,
)
self.max_steps = max_steps
self.time_limit = time_limit
self._target_dynamics = target_dynamics or RandomWalk(0.001)
self.generator = generator or RandomGenerator(num_targets=100, num_searchers=2)
self._viewer = viewer or SearchAndRescueViewer()
@@ -180,7 +180,7 @@ def __repr__(self) -> str:
f" - target contact range: {self.target_contact_range}",
f" - num vision: {self._observation.num_vision}",
f" - agent radius: {self._observation.agent_radius}",
f" - max steps: {self.max_steps},"
f" - time limit: {self.time_limit},"
f" - env size: {self.generator.env_size}"
f" - target dynamics: {self._target_dynamics.__class__.__name__}",
f" - generator: {self.generator.__class__.__name__}",
@@ -223,11 +223,12 @@ def step(self, state: State, actions: chex.Array) -> Tuple[State, TimeStep[Obser
searchers = update_state(
key, self.generator.env_size, self.searcher_params, state.searchers, actions
)
# Ensure target positions are wrapped
target_pos = self._target_dynamics(target_key, state.targets.pos) % self.generator.env_size

targets = self._target_dynamics(target_key, state.targets, self.generator.env_size)

# Searchers return an array of flags of any targets they are in range of,
# and that have not already been located, result shape here is (n-searcher, n-targets)
n_targets = target_pos.shape[0]
n_targets = targets.pos.shape[0]
targets_found = spatial(
utils.searcher_detect_targets,
reduction=jnp.logical_or,
@@ -238,29 +239,29 @@ def step(
key,
self.searcher_params.view_angle,
searchers,
(jnp.arange(n_targets), state.targets),
(jnp.arange(n_targets), targets),
pos=searchers.pos,
pos_b=target_pos,
pos_b=targets.pos,
env_size=self.generator.env_size,
n_targets=n_targets,
)

rewards = self._reward_fn(targets_found)
rewards = self._reward_fn(targets_found, state.step, self.time_limit)

targets_found = jnp.any(targets_found, axis=0)
# Targets need to remain found if they already have been
targets_found = jnp.logical_or(targets_found, state.targets.found)

state = State(
searchers=searchers,
targets=TargetState(pos=target_pos, found=targets_found),
targets=TargetState(pos=targets.pos, vel=targets.vel, found=targets_found),
key=key,
step=state.step + 1,
)
observation = self._state_to_observation(state)
observation = jax.lax.stop_gradient(observation)
timestep = jax.lax.cond(
jnp.logical_or(state.step >= self.max_steps, jnp.all(targets_found)),
jnp.logical_or(state.step >= self.time_limit, jnp.all(targets_found)),
termination,
transition,
rewards,
@@ -273,9 +274,13 @@ def _state_to_observation(self, state: State) -> Observation:
return Observation(
searcher_views=searcher_views,
targets_remaining=1.0 - jnp.sum(state.targets.found) / self.generator.num_targets,
time_remaining=1.0 - state.step / (self.max_steps + 1),
step=state.step,
)

@cached_property
def num_agents(self) -> int:
return self.generator.num_searchers

@cached_property
def observation_spec(self) -> specs.Spec[Observation]:
"""Returns the observation spec.
@@ -287,7 +292,11 @@ def observation_spec(self) -> specs.Spec[Observation]:
observation_spec: Search-and-rescue observation spec
"""
searcher_views = specs.BoundedArray(
shape=(self.generator.num_searchers, *self._observation.view_shape),
shape=(
self.generator.num_searchers,
self._observation.num_channels,
self._observation.num_vision,
),
minimum=-1.0,
maximum=1.0,
dtype=float,
@@ -298,10 +307,10 @@
"ObservationSpec",
searcher_views=searcher_views,
targets_remaining=specs.BoundedArray(
shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=float
shape=(), minimum=0.0, maximum=1.0, name="targets_remaining", dtype=jnp.float32
),
time_remaining=specs.BoundedArray(
shape=(), minimum=0.0, maximum=1.0, name="time_remaining", dtype=float
step=specs.BoundedArray(
shape=(), minimum=0, maximum=self.time_limit, name="step", dtype=jnp.int32
),
)
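The time_remaining observation is replaced by the raw step count, bounded by time_limit in the spec. Consumers that relied on a normalised signal can reconstruct one from these two values; a sketch (ignoring the off-by-one in the old max_steps + 1 normalisation):

import jax.numpy as jnp


def time_remaining(step: jnp.ndarray, time_limit: int) -> jnp.ndarray:
    # Map the integer step count back to a value in [0, 1].
    return 1.0 - step / time_limit


# e.g. time_remaining(timestep.observation.step, env.time_limit)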

16 changes: 11 additions & 5 deletions jumanji/environments/swarms/search_and_rescue/env_test.py
@@ -57,7 +57,8 @@ def test_env_init(env: SearchAndRescue, key: chex.PRNGKey) -> None:
assert isinstance(timestep.observation, Observation)
assert timestep.observation.searcher_views.shape == (
env.generator.num_searchers,
*env._observation.view_shape,
env._observation.num_channels,
env._observation.num_vision,
)
assert timestep.step_type == StepType.FIRST

@@ -69,8 +70,9 @@ def test_env_step(env: SearchAndRescue, key: chex.PRNGKey, env_size: float) -> N
check states (i.e. positions, heading, speeds) all fall
inside expected ranges.
"""
n_steps = 22
n_steps = env.time_limit
env.generator.env_size = env_size
env.time_limit = 22

def step(
carry: Tuple[chex.PRNGKey, State], _: None
@@ -108,7 +110,7 @@ def step(

def test_env_does_not_smoke(env: SearchAndRescue) -> None:
"""Test that we can run an episode without any errors."""
env.max_steps = 10
env.time_limit = 10

def select_action(action_key: chex.PRNGKey, _state: Observation) -> chex.Array:
return jax.random.uniform(
@@ -132,7 +134,9 @@ def test_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None:
searchers=AgentState(
pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([jnp.pi]), speed=jnp.array([0.0])
),
targets=TargetState(pos=jnp.array([[0.54, 0.5]]), found=jnp.array([False])),
targets=TargetState(
pos=jnp.array([[0.54, 0.5]]), vel=jnp.zeros((1, 2)), found=jnp.array([False])
),
key=key,
)
state, timestep = env.step(state, jnp.zeros((1, 2)))
@@ -188,7 +192,9 @@ def test_multi_target_detection(env: SearchAndRescue, key: chex.PRNGKey) -> None
pos=jnp.array([[0.5, 0.5]]), heading=jnp.array([0.5 * jnp.pi]), speed=jnp.array([0.0])
),
targets=TargetState(
pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]), found=jnp.array([False, False])
pos=jnp.array([[0.54, 0.5], [0.46, 0.5]]),
vel=jnp.zeros((2, 2)),
found=jnp.array([False, False]),
),
key=key,
)
5 changes: 4 additions & 1 deletion jumanji/environments/swarms/search_and_rescue/generator.py
@@ -83,11 +83,14 @@ def __call__(self, key: chex.PRNGKey, searcher_params: AgentParams) -> State:
target_pos = jax.random.uniform(
target_key, (self.num_targets, 2), minval=0.0, maxval=self.env_size
)
target_vel = jnp.zeros((self.num_targets, 2))

state = State(
searchers=searcher_state,
targets=TargetState(
pos=target_pos, found=jnp.full((self.num_targets,), False, dtype=bool)
pos=target_pos,
vel=target_vel,
found=jnp.full((self.num_targets,), False, dtype=bool),
),
key=key,
)
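The default generator now also initialises the new velocity field, starting all targets at rest. A hypothetical generator variant that samples non-zero initial velocities instead (illustrative only; import paths and the State/TargetState constructors are assumed to match the file layout and the hunk above):

import chex
import jax
from jumanji.environments.swarms.search_and_rescue.generator import RandomGenerator
from jumanji.environments.swarms.search_and_rescue.types import State, TargetState


class MovingTargetsGenerator(RandomGenerator):
    """RandomGenerator variant that starts targets with small random velocities."""

    def __init__(self, max_target_speed: float, **kwargs):
        super().__init__(**kwargs)
        self.max_target_speed = max_target_speed

    def __call__(self, key: chex.PRNGKey, searcher_params) -> State:
        state = super().__call__(key, searcher_params)
        vel_key, _ = jax.random.split(key)
        target_vel = jax.random.uniform(
            vel_key,
            (self.num_targets, 2),
            minval=-self.max_target_speed,
            maxval=self.max_target_speed,
        )
        # Rebuild the target state with sampled velocities in place of zeros.
        targets = TargetState(pos=state.targets.pos, vel=target_vel, found=state.targets.found)
        return State(searchers=state.searchers, targets=targets, key=state.key)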
12 changes: 6 additions & 6 deletions jumanji/environments/swarms/search_and_rescue/observations.py
@@ -27,7 +27,7 @@
class ObservationFn(abc.ABC):
def __init__(
self,
view_shape: Tuple[int, ...],
num_channels: int,
num_vision: int,
vision_range: float,
view_angle: float,
@@ -38,14 +38,14 @@ def __init__(
Base class for observation function mapping state to individual agent views.
Args:
view_shape: Individual agent view shape.
num_channels: Number of channels in agent view.
num_vision: Size of vision array.
vision_range: Vision range.
view_angle: Agent view angle (as a fraction of pi).
agent_radius: Agent/target visual radius.
env_size: Environment size.
"""
self.view_shape = view_shape
self.num_channels = num_channels
self.num_vision = num_vision
self.vision_range = vision_range
self.view_angle = view_angle
@@ -85,7 +85,7 @@ def __init__(
env_size: Environment size.
"""
super().__init__(
(1, num_vision),
1,
num_vision,
vision_range,
view_angle,
@@ -199,7 +199,7 @@ def __init__(
self.agent_radius = agent_radius
self.env_size = env_size
super().__init__(
(2, num_vision),
2,
num_vision,
vision_range,
view_angle,
@@ -333,7 +333,7 @@ def __init__(
self.agent_radius = agent_radius
self.env_size = env_size
super().__init__(
(3, num_vision),
3,
num_vision,
vision_range,
view_angle,
@@ -84,7 +84,9 @@ def test_searcher_view(
searchers=AgentState(
pos=searcher_positions, heading=searcher_headings, speed=searcher_speed
),
targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)),
targets=TargetState(
pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1, 2), dtype=bool)
),
key=key,
)

@@ -164,7 +166,9 @@ def test_search_and_target_view_searchers(
searchers=AgentState(
pos=searcher_positions, heading=searcher_headings, speed=searcher_speed
),
targets=TargetState(pos=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool)),
targets=TargetState(
pos=jnp.zeros((1, 2)), vel=jnp.zeros((1, 2)), found=jnp.zeros((1,), dtype=bool)
),
key=key,
)

@@ -241,6 +245,7 @@ def test_search_and_target_view_targets(
searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed),
targets=TargetState(
pos=target_position,
vel=jnp.zeros_like(target_position),
found=target_found,
),
key=key,
@@ -328,6 +333,7 @@ def test_search_and_all_target_view_targets(
searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed),
targets=TargetState(
pos=target_position,
vel=jnp.zeros_like(target_position),
found=target_found,
),
key=key,
24 changes: 21 additions & 3 deletions jumanji/environments/swarms/search_and_rescue/reward.py
@@ -22,7 +22,7 @@ class RewardFn(abc.ABC):
"""Abstract class for `SearchAndRescue` rewards."""

@abc.abstractmethod
def __call__(self, found_targets: chex.Array) -> chex.Array:
def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array:
"""The reward function used in the `SearchAndRescue` environment.
Args:
@@ -41,7 +41,7 @@ class SharedRewardFn(RewardFn):
can receive rewards for detecting multiple targets.
"""

def __call__(self, found_targets: chex.Array) -> chex.Array:
def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array:
rewards = found_targets.astype(float)
norms = jnp.sum(rewards, axis=0)[jnp.newaxis]
rewards = jnp.where(norms > 0, rewards / norms, rewards)
@@ -57,7 +57,25 @@ class IndividualRewardFn(RewardFn):
even if a target is detected by multiple agents.
"""

def __call__(self, found_targets: chex.Array) -> chex.Array:
def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array:
rewards = found_targets.astype(float)
rewards = jnp.sum(rewards, axis=1)
return rewards


class SharedScaledRewardFn(RewardFn):
"""
Calculate per agent rewards from detected targets
Targets detected by multiple agents share rewards. Agents
can receive rewards for detecting multiple targets.
Rewards are scaled by the current time step.
"""

def __call__(self, found_targets: chex.Array, step: int, time_limit: int) -> chex.Array:
rewards = found_targets.astype(float)
norms = jnp.sum(rewards, axis=0)[jnp.newaxis]
rewards = jnp.where(norms > 0, rewards / norms, rewards)
rewards = jnp.sum(rewards, axis=1)
scale = (time_limit - step) / time_limit
return scale * rewards
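SharedScaledRewardFn behaves like SharedRewardFn but multiplies the shared reward by the fraction of time remaining, (time_limit - step) / time_limit, so detections early in an episode are worth more than late ones. A quick sketch of the scaling (the detection array below is made up for illustration):

import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue.reward import SharedScaledRewardFn

reward_fn = SharedScaledRewardFn()

# Two searchers, three targets: both detect target 0, searcher 1 also detects target 2.
found = jnp.array([[True, False, False],
                   [True, False, True]])

print(reward_fn(found, step=0, time_limit=100))    # [0.5, 1.5] - full value at the first step
print(reward_fn(found, step=50, time_limit=100))   # [0.25, 0.75] - half value at the midpoint
print(reward_fn(found, step=100, time_limit=100))  # [0.0, 0.0] - nothing at the time limit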