Skip to content

Commit

Permalink
feat: view all targets (#9)
Browse files Browse the repository at this point in the history
* Add observation including all targets

* Consistent test module names

* Use CNN embedding
  • Loading branch information
zombie-einstein authored Dec 9, 2024
1 parent 9a654b9 commit 5021e20
Show file tree
Hide file tree
Showing 6 changed files with 253 additions and 20 deletions.
138 changes: 136 additions & 2 deletions jumanji/environments/swarms/search_and_rescue/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def __call__(self, state: State) -> chex.Array:
return searcher_views[:, jnp.newaxis]


def target_view(
def found_target_view(
_key: chex.PRNGKey,
params: Tuple[float, float],
searcher: AgentState,
Expand Down Expand Up @@ -235,7 +235,7 @@ def __call__(self, state: State) -> chex.Array:
env_size=self.env_size,
)
target_views = spatial(
target_view,
found_target_view,
reduction=view_reduction,
default=-jnp.ones((self.num_vision,)),
include_self=False,
Expand All @@ -253,3 +253,137 @@ def __call__(self, state: State) -> chex.Array:
env_size=self.env_size,
)
return jnp.hstack([searcher_views[:, jnp.newaxis], target_views[:, jnp.newaxis]])


def all_target_view(
    _key: chex.PRNGKey,
    params: Tuple[float, float],
    searcher: AgentState,
    target: TargetState,
    *,
    n_view: int,
    i_range: float,
    env_size: float,
) -> chex.Array:
    """
    Return a two-channel view of a single target, split by found status.

    This function is intended to be mapped over searcher-target pairs
    by Esquilax.

    Args:
        _key: Dummy random key (required by Esquilax).
        params: View angle and target visual radius.
        searcher: Searcher agent state.
        target: Target state.
        n_view: Number of values in view array.
        i_range: Vision range.
        env_size: Environment size.

    Returns:
        Segmented agent view of the target, shape ``(2, n_view)``:
        row 0 holds distances if the target has been found, row 1 if it
        has not; rays that miss the target are set to ``-1.0``.
    """
    view_angle, agent_radius = params
    # Ray angles spread symmetrically across the agent's field of view.
    ray_angles = jnp.linspace(
        -view_angle * jnp.pi,
        view_angle * jnp.pi,
        n_view,
        endpoint=True,
    )
    distance, lo, hi = angular_width(
        searcher.pos,
        target.pos,
        searcher.heading,
        i_range,
        agent_radius,
        env_size,
    )
    hit = jnp.logical_and(lo < ray_angles, ray_angles < hi)
    # Route hit distances into the found / not-found channel respectively.
    found_row = jnp.where(jnp.logical_and(target.found, hit), distance, -1.0)
    unfound_row = jnp.where(jnp.logical_and(~target.found, hit), distance, -1.0)
    return jnp.vstack([found_row, unfound_row])


class AgentAndAllTargetObservationFn(ObservationFn):
    def __init__(
        self,
        num_vision: int,
        vision_range: float,
        view_angle: float,
        agent_radius: float,
        env_size: float,
    ) -> None:
        """
        Vision model that contains other agents and all targets.

        Observations are three individual channels: other searchers,
        found targets, and unfound targets. Unlike the found-only
        observation model, every target is visible regardless of its
        found status; the status only selects which channel a target
        appears in (see ``all_target_view``).

        Args:
            num_vision: Size of vision array.
            vision_range: Vision range.
            view_angle: Agent view angle (as a fraction of pi).
            agent_radius: Agent/target visual radius.
            env_size: Environment size.
        """
        self.vision_range = vision_range
        self.view_angle = view_angle
        self.agent_radius = agent_radius
        self.env_size = env_size
        # Observation shape is (channels, vision-width) per agent.
        super().__init__(
            (3, num_vision),
            num_vision,
            vision_range,
            view_angle,
            agent_radius,
            env_size,
        )

    def __call__(self, state: State) -> chex.Array:
        """
        Generate agent view/observation from state

        Args:
            state: Current simulation state

        Returns:
            Array of individual agent views, shape
            ``(n-agents, 3, num_vision)``.
        """
        searcher_views = spatial(
            view,
            reduction=view_reduction,
            default=-jnp.ones((self.num_vision,)),
            include_self=False,
            i_range=self.vision_range,
            dims=self.env_size,
        )(
            state.key,
            (self.view_angle, self.agent_radius),
            state.searchers,
            state.searchers,
            pos=state.searchers.pos,
            n_view=self.num_vision,
            i_range=self.vision_range,
            env_size=self.env_size,
        )
        # all_target_view yields two channels per target (found/unfound),
        # so target_views has shape (n-agents, 2, num_vision).
        target_views = spatial(
            all_target_view,
            reduction=view_reduction,
            default=-jnp.ones((2, self.num_vision)),
            include_self=False,
            i_range=self.vision_range,
            dims=self.env_size,
        )(
            state.key,
            (self.view_angle, self.agent_radius),
            state.searchers,
            state.targets,
            pos=state.searchers.pos,
            pos_b=state.targets.pos,
            n_view=self.num_vision,
            i_range=self.vision_range,
            env_size=self.env_size,
        )
        # Stack the searcher channel with the two target channels.
        return jnp.hstack([searcher_views[:, jnp.newaxis], target_views])
87 changes: 87 additions & 0 deletions jumanji/environments/swarms/search_and_rescue/observations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,90 @@ def test_search_and_target_view_targets(
expected = expected.at[0, 1, idx].set(val)

assert jnp.all(jnp.isclose(obs, expected))


@pytest.mark.parametrize(
    "searcher_position, searcher_heading, target_position, target_found, env_size, view_updates",
    [
        # Target out of view range
        ([0.8, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, []),
        # Target in view and found
        ([0.25, 0.5], jnp.pi, [0.2, 0.5], True, 1.0, [(1, 5, 0.25)]),
        # Target in view but not found
        ([0.25, 0.5], jnp.pi, [0.2, 0.5], False, 1.0, [(2, 5, 0.25)]),
        # Observed around wrapped edge found
        (
            [0.025, 0.5],
            jnp.pi,
            [0.975, 0.5],
            True,
            1.0,
            [(1, 5, 0.25)],
        ),
        # Observed around wrapped edge not found
        (
            [0.025, 0.5],
            jnp.pi,
            [0.975, 0.5],
            False,
            1.0,
            [(2, 5, 0.25)],
        ),
        # Observed around wrapped edge of smaller env
        (
            [0.025, 0.25],
            jnp.pi,
            [0.475, 0.25],
            True,
            0.5,
            [(1, 5, 0.25)],
        ),
    ],
)
def test_search_and_all_target_view_targets(
    key: chex.PRNGKey,
    env: SearchAndRescue,
    searcher_position: List[float],
    searcher_heading: float,
    target_position: List[float],
    target_found: bool,
    env_size: float,
    view_updates: List[Tuple[int, int, float]],
) -> None:
    """
    Test the all-target observation model generates the expected
    three-channel array for a single searcher and a single target.

    `view_updates` lists `(channel, ray-index, distance)` entries
    expected to differ from the default -1.0 background: channel 1 is
    found targets, channel 2 is unfound targets (channel 0, other
    searchers, stays empty with one agent).
    """

    # Promote scalar fixture values to single-agent/target arrays.
    searcher_position = jnp.array([searcher_position])
    searcher_heading = jnp.array([searcher_heading])
    searcher_speed = jnp.zeros((1,))
    target_position = jnp.array([target_position])
    target_found = jnp.array([target_found])

    state = State(
        searchers=AgentState(pos=searcher_position, heading=searcher_heading, speed=searcher_speed),
        targets=TargetState(
            pos=target_position,
            found=target_found,
        ),
        key=key,
    )

    observe_fn = observations.AgentAndAllTargetObservationFn(
        num_vision=11,
        vision_range=VISION_RANGE,
        view_angle=VIEW_ANGLE,
        agent_radius=0.01,
        env_size=env_size,
    )

    obs = observe_fn(state)
    # One agent, three channels (searchers, found targets, unfound targets).
    assert obs.shape == (1, 3, observe_fn.num_vision)

    # Start from an empty (-1.0) view and apply the expected updates.
    expected = jnp.full((1, 3, observe_fn.num_vision), -1.0)

    for i, idx, val in view_updates:
        expected = expected.at[0, i, idx].set(val)

    assert jnp.all(jnp.isclose(obs, expected))
48 changes: 30 additions & 18 deletions jumanji/training/networks/search_and_rescue/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import chex
import haiku as hk
import jax
import jax.numpy as jnp

from jumanji.environments.swarms.search_and_rescue import SearchAndRescue
Expand Down Expand Up @@ -50,7 +51,9 @@ def make_actor_critic_search_and_rescue(
"""
n_actions = prod(search_and_rescue.action_spec.shape)
parametric_action_distribution = ContinuousActionSpaceNormalTanhDistribution(n_actions)
policy_network = make_actor_network(layers=layers, n_actions=n_actions)
policy_network = make_actor_network(
layers=layers, n_agents=search_and_rescue.generator.num_searchers, n_actions=n_actions
)
value_network = make_critic_network(layers=layers)

return ActorCriticNetworks(
Expand All @@ -60,43 +63,52 @@ def make_actor_critic_search_and_rescue(
)


def embedding(x: chex.Array) -> chex.Array:
    """Embed observation channels with a small convolutional stage.

    Applies a 1d convolution doubling the channel count, a ReLU, then
    stride-2 max pooling over the observation axis. Data is
    channels-first, i.e. ``(..., channels, width)``.
    """
    channels_in = x.shape[-2]
    conv = hk.Conv1D(2 * channels_in, 3, data_format="NCW")
    pool = hk.MaxPool(2, 2, "SAME", channel_axis=-2)
    return pool(jax.nn.relu(conv(x)))


def make_critic_network(layers: Sequence[int]) -> FeedForwardNetwork:
    """Build the critic network mapping observations to a scalar value.

    Args:
        layers: Hidden layer sizes of the value MLP.

    Returns:
        Transformed haiku network as a ``FeedForwardNetwork``.
    """
    # Shape names:
    # B: batch size
    # N: number of agents
    # C: observation channels
    # O: observation size

    def critic_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]:
        # Embed each agent's multi-channel view, then flatten across agents.
        views = observation.searcher_views  # (B, N, C, O)
        embedded = hk.vmap(embedding, split_rng=False)(views)
        flat = hk.Flatten()(embedded)
        value = hk.nets.MLP([*layers, 1])(flat)  # (B, 1)
        return jnp.squeeze(value, axis=-1)  # (B,)

    init, apply = hk.without_apply_rng(hk.transform(critic_fn))
    return FeedForwardNetwork(init=init, apply=apply)


def make_actor_network(layers: Sequence[int], n_actions: int) -> FeedForwardNetwork:
def make_actor_network(layers: Sequence[int], n_agents: int, n_actions: int) -> FeedForwardNetwork:
# Shape names:
# B: batch size
# N: number of agents
# C: Observation channels
# O: observation size
# A: Number of actions

def log_std_params(x: chex.Array) -> chex.Array:
return hk.get_parameter("log_stds", shape=x.shape, init=hk.initializers.Constant(0.1))

def network_fn(observation: Observation) -> Union[chex.Array, Tuple[chex.Array, chex.Array]]:
views = observation.searcher_views # (B, N, O)
batch_size = views.shape[0]
n_agents = views.shape[1]
views = views.reshape((batch_size, -1)) # (B, N * 0)
means = hk.nets.MLP([*layers, n_agents * n_actions])(views) # (B, N * A)
means = means.reshape(batch_size, n_agents, n_actions) # (B, N, A)

log_stds = hk.get_parameter(
"log_stds", shape=(n_agents * n_actions,), init=hk.initializers.Constant(0.1)
) # (N * A,)
log_stds = jnp.broadcast_to(log_stds, (batch_size, n_agents * n_actions)) # (B, N * A)
log_stds = log_stds.reshape(batch_size, n_agents, n_actions) # (B, N, A)
x = observation.searcher_views # (B, N, C, O)
x = hk.vmap(embedding, split_rng=False)(x)
x = hk.Flatten()(x)

means = hk.nets.MLP([*layers, n_agents * n_actions])(x) # (B, N * A)
means = hk.Reshape(output_shape=(n_agents, n_actions))(means) # (B, N, A)

log_stds = hk.vmap(log_std_params, split_rng=False)(means) # (B, N, A)

return means, log_stds

Expand Down

0 comments on commit 5021e20

Please sign in to comment.