
Implement Search & Rescue Multi-Agent Environment #259

Open · wants to merge 24 commits into base: main

24 commits:
3597d9e
feat: Implement predator prey env (#1)
zombie-einstein Nov 4, 2024
c955320
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 4, 2024
6b34657
Merge branch 'main' into main
sash-a Nov 4, 2024
988339b
fix: PR fixes (#2)
zombie-einstein Nov 5, 2024
a0fe7a5
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 5, 2024
b4cce01
style: Run updated pre-commit
zombie-einstein Nov 6, 2024
cb6d88d
refactor: Consolidate predator prey type
zombie-einstein Nov 7, 2024
06de3a0
feat: Implement search and rescue (#3)
zombie-einstein Nov 11, 2024
34beab6
fix: PR fixes (#4)
zombie-einstein Nov 14, 2024
f5fa659
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 15, 2024
072db18
refactor: PR fixes (#5)
zombie-einstein Nov 19, 2024
162a74d
feat: Allow variable environment dimensions (#6)
zombie-einstein Nov 19, 2024
4996869
Merge branch 'main' into main
zombie-einstein Nov 22, 2024
6322f61
fix: Locate targets in single pass (#8)
zombie-einstein Nov 23, 2024
4ba7688
Merge branch 'instadeepai:main' into main
zombie-einstein Nov 28, 2024
9a654b9
feat: training and customisable observations (#7)
zombie-einstein Dec 7, 2024
5021e20
feat: view all targets (#9)
zombie-einstein Dec 9, 2024
c5c7b85
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
13ffb84
Merge branch 'instadeepai:main' into main
zombie-einstein Dec 9, 2024
9e8ac5c
feat: Scaled rewards and target velocities (#10)
zombie-einstein Dec 11, 2024
5c509c7
Pass shape information to timesteps (#11)
zombie-einstein Dec 11, 2024
8acf242
test: extend tests and docs (#12)
zombie-einstein Dec 11, 2024
1792aa6
fix: unpin jax requirement
zombie-einstein Dec 12, 2024
1e66e78
Include agent positions in observation (#13)
zombie-einstein Dec 12, 2024
11 changes: 11 additions & 0 deletions docs/api/environments/search_and_rescue.md
@@ -0,0 +1,11 @@
::: jumanji.environments.swarms.search_and_rescue.env.SearchAndRescue
selection:
members:
- __init__
- reset
- step
- observation_spec
- action_spec
- reward_spec
- render
- animate
67 changes: 67 additions & 0 deletions docs/environments/search_and_rescue.md
@@ -0,0 +1,67 @@
# 🚁 Search & Rescue

[//]: # (TODO: Add animated plot)

A multi-agent environment modelling a group of agents searching a continuous
space for multiple targets. Agents are individually rewarded for finding a target
that has not previously been detected.

Each agent observes a local region around itself, producing a simple segmented
view of the locations of other agents in its vicinity. The environment is updated
in the following sequence:

- The velocities of the searching agents are updated, and consequently their positions.
- The positions of the targets are updated.
- Agents are rewarded for being within a fixed range of a target that also lies
  within their view cone.
- Targets within detection range and inside an agent's view cone are marked as found.
- Local views of the environment are generated for each searching agent.

The agents are allotted a fixed number of steps to locate the targets. The search
space is a square of unit dimensions, wrapped at the boundaries.
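
The position update in the sequence above can be sketched as follows (a minimal,
single-agent illustration in plain Python; the actual environment vectorises this
over all agents with JAX):

```python
import math

def move(pos, heading, speed, env_size=1.0):
    """Step a single agent along its heading, wrapping at the boundaries."""
    x = (pos[0] + speed * math.cos(heading)) % env_size
    y = (pos[1] + speed * math.sin(heading)) % env_size
    return (x, y)
```

For example, an agent at `(0.9, 0.5)` heading along the x-axis with speed `0.2`
wraps around the boundary to approximately `(0.1, 0.5)`.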

## Observations

- `searcher_views`: jax array (float) of shape `(num_searchers, num_vision)`. Each agent
  generates an independent observation: an array of values representing the distance
  along a ray from the agent to the nearest neighbour, with each cell representing a
  ray angle (with `num_vision` rays evenly distributed over the agent's field of vision).
  For example, if an agent sees another agent straight ahead and `num_vision = 5`, then
  the observation array could be

  ```
  [-1.0, -1.0, 0.5, -1.0, -1.0]
  ```

  where `-1.0` indicates there is no agent along that ray, and `0.5` is the normalised
  distance to the other agent.
- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
remaining to be detected (i.e. 1.0 when no targets have been found).
- `time_remaining`: float in the range `[0, 1]`. The normalised number of steps remaining
to locate the targets (i.e. 0.0 at the end of the episode).
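
A simplified sketch of how one ray-cast observation could be produced (illustrative
only: it handles a single neighbour, ignores agent radius and boundary wrapping, and
the names `simple_view`, `n_view`, and `i_range` are hypothetical stand-ins for the
environment's actual parameters):

```python
import math

def simple_view(rel_pos, heading, view_angle, n_view, i_range):
    """Return n_view ray values: the normalised distance to a single
    neighbour at rel_pos, or -1.0 on rays where nothing is seen."""
    view = [-1.0] * n_view
    d = math.hypot(rel_pos[0], rel_pos[1])
    if d > i_range:
        return view  # neighbour out of vision range
    # Bearing of the neighbour relative to the agent's heading, in [-pi, pi]
    a = (math.atan2(rel_pos[1], rel_pos[0]) - heading + math.pi) % (2 * math.pi) - math.pi
    half = view_angle * math.pi  # view_angle is a fraction of pi either side
    if abs(a) > half:
        return view  # outside the view cone
    # Rays sweep the field of vision; assign the neighbour to the nearest ray
    rays = [half - 2.0 * half * i / (n_view - 1) for i in range(n_view)]
    idx = min(range(n_view), key=lambda i: abs(rays[i] - a))
    view[idx] = d / i_range  # normalised distance
    return view
```

With `view_angle = 0.5` and `n_view = 5`, a neighbour straight ahead at half the
vision range produces `[-1.0, -1.0, 0.5, -1.0, -1.0]`, matching the example above.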

## Actions

Jax array (float) of shape `(num_searchers, 2)` in the range `[-1, 1]`. Each row
represents an update of the corresponding agent's velocity in the next step. Searching
agents update their velocity each step by rotating and accelerating/decelerating,
where the values are `[rotation, acceleration]`. Values are clipped to the range
`[-1, 1]` and then scaled by the max rotation and acceleration parameters, i.e. the
new values each step are given by

```
heading = (heading + max_rotation * pi * action[0]) % (2 * pi)
```

and speed

```
speed = speed + max_acceleration * action[1]
```

where `max_rotation` is expressed as a fraction of π. Once applied, agent speeds are
clipped to lie within a fixed range `[min_speed, max_speed]`.
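
Putting the two formulas together, the per-step velocity update can be sketched as
follows (a single-agent plain-Python illustration; parameter names follow the
`AgentParams` fields, where `max_rotate` is a fraction of pi):

```python
import math

def update_velocity(heading, speed, action, max_rotate, max_accelerate,
                    min_speed, max_speed):
    """Apply a [rotation, acceleration] action to one agent's velocity."""
    rot = max(-1.0, min(1.0, action[0]))  # clip actions to [-1, 1]
    acc = max(-1.0, min(1.0, action[1]))
    # A maximal rotation action turns the agent by max_rotate * pi radians
    heading = (heading + max_rotate * math.pi * rot) % (2.0 * math.pi)
    # Accelerate, then clip speed to the allowed range
    speed = min(max(speed + max_accelerate * acc, min_speed), max_speed)
    return heading, speed
```

For example, with `max_rotate = 0.5` a maximal rotation action of `1.0` turns the
agent by `0.5 * pi` radians.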

## Rewards

Jax array (float) of shape `(num_searchers,)`. Rewards are generated for each agent
individually. An agent is rewarded 1.0 for locating a target that has not already
been detected.
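
A conceptual sketch of the reward logic (illustrative plain Python, not the
vectorised implementation; crediting the first searcher processed when two
searchers detect the same target in the same step is an assumption of this
sketch, as is the helper name `step_rewards`):

```python
def step_rewards(detections, already_found):
    """detections: for each searcher, the target indices it detects this step.
    already_found: set of target indices found on earlier steps (mutated)."""
    rewards = []
    for targets in detections:
        newly_found = [t for t in targets if t not in already_found]
        already_found.update(newly_found)
        rewards.append(float(len(newly_found)))  # 1.0 per newly found target
    return rewards
```

Only previously undetected targets generate reward: a searcher re-detecting an
already-found target receives nothing.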
7 changes: 7 additions & 0 deletions jumanji/__init__.py
@@ -142,3 +142,10 @@
# LevelBasedForaging with a random generator with 8 grid size,
# 2 agents and 2 food items and the maximum agent's level is 2.
register(id="LevelBasedForaging-v0", entry_point="jumanji.environments:LevelBasedForaging")

###
# Swarm Environments
###

# Search-and-Rescue environment
register(id="SearchAndRescue-v0", entry_point="jumanji.environments:SearchAndRescue")
Contributor Author:

Is it worth registering the environment with different vision models?

Collaborator:

Do you mean fully versus partially observable?

Collaborator:

Ah, I think you mean the different observation functions. I'd say we should aim to use the observation that only visualizes targets and searchers within their vision cones. If we see a good training curve with this then that should be the default.

In general I'd like to have a set of scenarios for most/all environments in jumanji (see #248). So it would be cool to think of a set (3-4) of easy/hard envs and we can register those. If they happen to have different observation models that's fine.

Contributor Author:

Yeah, I was picturing something along these lines: the easy one is where un-found targets are visible, and then harder versions use the version with hidden targets.

1 change: 1 addition & 0 deletions jumanji/environments/__init__.py
@@ -59,6 +59,7 @@
from jumanji.environments.routing.snake.env import Snake
from jumanji.environments.routing.sokoban.env import Sokoban
from jumanji.environments.routing.tsp.env import TSP
from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue


def is_colab() -> bool:
13 changes: 13 additions & 0 deletions jumanji/environments/swarms/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
13 changes: 13 additions & 0 deletions jumanji/environments/swarms/common/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
193 changes: 193 additions & 0 deletions jumanji/environments/swarms/common/common_test.py
@@ -0,0 +1,193 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import pytest

from jumanji.environments.swarms.common import types, updates, viewer


@pytest.fixture
def params() -> types.AgentParams:
return types.AgentParams(
max_rotate=0.5,
max_accelerate=0.01,
min_speed=0.01,
max_speed=0.05,
view_angle=0.5,
)


@pytest.mark.parametrize(
"heading, speed, actions, expected",
[
[0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)],
[0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)],
[jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)],
[jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)],
[1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)],
[0.0, 0.01, [0.0, 1.0], (0.0, 0.02)],
[0.0, 0.01, [0.0, -1.0], (0.0, 0.01)],
[0.0, 0.02, [0.0, -1.0], (0.0, 0.01)],
[0.0, 0.05, [0.0, -1.0], (0.0, 0.04)],
[0.0, 0.05, [0.0, 1.0], (0.0, 0.05)],
],
)
def test_velocity_update(
params: types.AgentParams,
heading: float,
speed: float,
actions: List[float],
expected: Tuple[float, float],
) -> None:
key = jax.random.PRNGKey(101)

state = types.AgentState(
pos=jnp.zeros((1, 2)),
heading=jnp.array([heading]),
speed=jnp.array([speed]),
)
actions = jnp.array([actions])

new_heading, new_speed = updates.update_velocity(key, params, (actions, state))

assert jnp.isclose(new_heading[0], expected[0])
assert jnp.isclose(new_speed[0], expected[1])


@pytest.mark.parametrize(
"pos, heading, speed, expected, env_size",
[
[[0.0, 0.5], 0.0, 0.1, [0.1, 0.5], 1.0],
[[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5], 1.0],
[[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1], 1.0],
[[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9], 1.0],
[[0.4, 0.2], 0.0, 0.2, [0.1, 0.2], 0.5],
[[0.1, 0.2], jnp.pi, 0.2, [0.4, 0.2], 0.5],
[[0.2, 0.4], 0.5 * jnp.pi, 0.2, [0.2, 0.1], 0.5],
[[0.2, 0.1], 1.5 * jnp.pi, 0.2, [0.2, 0.4], 0.5],
],
)
def test_move(
pos: List[float], heading: float, speed: float, expected: List[float], env_size: float
) -> None:
pos = jnp.array(pos)
new_pos = updates.move(pos, heading, speed, env_size)

assert jnp.allclose(new_pos, jnp.array(expected))


@pytest.mark.parametrize(
"pos, heading, speed, actions, expected_pos, expected_heading, expected_speed, env_size",
[
[[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01, 1.0],
[[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01, 1.0],
[[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01, 1.0],
[[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02, 1.0],
[[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01, 1.0],
[[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05, 1.0],
[[0.495, 0.25], 0.0, 0.01, [0.0, 0.0], [0.005, 0.25], 0.0, 0.01, 0.5],
[[0.25, 0.005], 1.5 * jnp.pi, 0.01, [0.0, 0.0], [0.25, 0.495], 1.5 * jnp.pi, 0.01, 0.5],
],
)
def test_state_update(
params: types.AgentParams,
pos: List[float],
heading: float,
speed: float,
actions: List[float],
expected_pos: List[float],
expected_heading: float,
expected_speed: float,
env_size: float,
) -> None:
key = jax.random.PRNGKey(101)

state = types.AgentState(
pos=jnp.array([pos]),
heading=jnp.array([heading]),
speed=jnp.array([speed]),
)
actions = jnp.array([actions])

new_state = updates.update_state(key, env_size, params, state, actions)

assert isinstance(new_state, types.AgentState)
assert jnp.allclose(new_state.pos, jnp.array([expected_pos]))
assert jnp.allclose(new_state.heading, jnp.array([expected_heading]))
assert jnp.allclose(new_state.speed, jnp.array([expected_speed]))


def test_view_reduction() -> None:
view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5])
view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2])
result = updates.view_reduction(view_a, view_b)
assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2]))


@pytest.mark.parametrize(
"pos, view_angle, env_size, expected",
[
[[0.05, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
[[0.0, 0.05], 0.5, 1.0, [0.5, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.95], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, 0.5]],
[[0.95, 0.0], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
[[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
[[0.01, 0.0], 0.5, 1.0, [-1.0, 0.1, 0.1, 0.1, -1.0]],
[[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]],
[[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]],
],
)
def test_view(pos: List[float], view_angle: float, env_size: float, expected: List[float]) -> None:
state_a = types.AgentState(
pos=jnp.zeros((2,)),
heading=0.0,
speed=0.0,
)

state_b = types.AgentState(
pos=jnp.array(pos),
heading=0.0,
speed=0.0,
)

obs = updates.view(
None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size
)
assert jnp.allclose(obs, jnp.array(expected))


def test_viewer_utils() -> None:
f, ax = plt.subplots()
f, ax = viewer.format_plot(f, ax, (1.0, 1.0))

assert isinstance(f, matplotlib.figure.Figure)
assert isinstance(ax, matplotlib.axes.Axes)

state = types.AgentState(
pos=jnp.zeros((3, 2)),
heading=jnp.zeros((3,)),
speed=jnp.zeros((3,)),
)

quiver = viewer.draw_agents(ax, state, "red")

assert isinstance(quiver, matplotlib.quiver.Quiver)
53 changes: 53 additions & 0 deletions jumanji/environments/swarms/common/types.py
@@ -0,0 +1,53 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dataclasses import dataclass
else:
from chex import dataclass

import chex


@dataclass(frozen=True)
class AgentParams:
"""
max_rotate: Max angle an agent can rotate during a step (a fraction of pi)
max_accelerate: Max change in speed during a step
min_speed: Minimum agent speed
max_speed: Maximum agent speed
view_angle: Agent view angle, as a fraction of pi either side of its heading
"""

max_rotate: float
max_accelerate: float
min_speed: float
max_speed: float
view_angle: float


@dataclass
class AgentState:
"""
State of multiple agents of a single type

pos: 2d position of the (centre of the) agents
heading: Heading of the agents (in radians)
speed: Speed of the agents
"""

pos: chex.Array # (num_agents, 2)
heading: chex.Array # (num_agents,)
speed: chex.Array # (num_agents,)