Implement Search & Rescue Multi-Agent Environment #259

Open

zombie-einstein wants to merge 24 commits into instadeepai:main from zombie-einstein:main
Commits
- `3597d9e` feat: Implement predator prey env (#1) (zombie-einstein)
- `c955320` Merge branch 'instadeepai:main' into main (zombie-einstein)
- `6b34657` Merge branch 'main' into main (sash-a)
- `988339b` fix: PR fixes (#2) (zombie-einstein)
- `a0fe7a5` Merge branch 'instadeepai:main' into main (zombie-einstein)
- `b4cce01` style: Run updated pre-commit (zombie-einstein)
- `cb6d88d` refactor: Consolidate predator prey type (zombie-einstein)
- `06de3a0` feat: Implement search and rescue (#3) (zombie-einstein)
- `34beab6` fix: PR fixes (#4) (zombie-einstein)
- `f5fa659` Merge branch 'instadeepai:main' into main (zombie-einstein)
- `072db18` refactor: PR fixes (#5) (zombie-einstein)
- `162a74d` feat: Allow variable environment dimensions (#6) (zombie-einstein)
- `4996869` Merge branch 'main' into main (zombie-einstein)
- `6322f61` fix: Locate targets in single pass (#8) (zombie-einstein)
- `4ba7688` Merge branch 'instadeepai:main' into main (zombie-einstein)
- `9a654b9` feat: training and customisable observations (#7) (zombie-einstein)
- `5021e20` feat: view all targets (#9) (zombie-einstein)
- `c5c7b85` Merge branch 'instadeepai:main' into main (zombie-einstein)
- `13ffb84` Merge branch 'instadeepai:main' into main (zombie-einstein)
- `9e8ac5c` feat: Scaled rewards and target velocities (#10) (zombie-einstein)
- `5c509c7` Pass shape information to timesteps (#11) (zombie-einstein)
- `8acf242` test: extend tests and docs (#12) (zombie-einstein)
- `1792aa6` fix: unpin jax requirement (zombie-einstein)
- `1e66e78` Include agent positions in observation (#13) (zombie-einstein)
@@ -0,0 +1,11 @@
::: jumanji.environments.swarms.search_and_rescue.env.SearchAndRescue
    selection:
        members:
            - __init__
            - reset
            - step
            - observation_spec
            - action_spec
            - reward_spec
            - render
            - animate
@@ -0,0 +1,75 @@
# 🚁 Search & Rescue

[//]: # (TODO: Add animated plot)

Multi-agent environment modelling a group of agents searching a 2D environment
for multiple targets. Agents are individually rewarded for finding a target
that has not previously been detected.
Each agent observes a local region around itself, represented as a simple segmented
view of the locations of other agents and targets in the vicinity. The environment
is updated in the following sequence:

- The velocities of the searching agents are updated, and consequently their positions.
- The positions of the targets are updated.
- Targets within detection range and inside an agent's view cone are marked as found.
- Agents are rewarded for locating previously unfound targets.
- Local views of the environment are generated for each searching agent.

The agents are allotted a fixed number of steps to locate the targets. The search
space is a uniform square, wrapped at the boundaries.
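Putting the sequence together, a minimal interaction loop might look like the following (a sketch only: it assumes the standard Jumanji `reset`/`step` API listed in the class members above and uses default constructor arguments; the number of searchers is read from the observation rather than assumed):

```python
import jax

from jumanji.environments.swarms.search_and_rescue.env import SearchAndRescue

env = SearchAndRescue()  # assumption: defaults are sufficient for a smoke test

key = jax.random.PRNGKey(0)
state, timestep = env.reset(key)

# searcher_views has shape (num_searchers, channels, num_vision)
num_searchers = timestep.observation.searcher_views.shape[0]

# Random [rotation, acceleration] actions in [-1, 1] for each searcher
key, action_key = jax.random.split(key)
actions = jax.random.uniform(action_key, (num_searchers, 2), minval=-1.0, maxval=1.0)
state, timestep = env.step(state, actions)
```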
Many aspects of the environment can be customised:

- Agent observations can include targets as well as other searcher agents.
- Rewards can be shared between agents, or assigned individually to each agent.
- Target dynamics can be customised to model various search scenarios.
## Observations

- `searcher_views`: jax array (float) of shape `(num_searchers, channels, num_vision)`.
  Each agent generates an independent observation, an array of values representing the distance
  along a ray from the agent to the nearest neighbour or target, with each cell representing a
  ray angle (with `num_vision` rays evenly distributed over the agent's field of vision).
  For example, if an agent sees another agent straight ahead and `num_vision = 5`, then
  the observation array could be

  ```
  [-1.0, -1.0, 0.5, -1.0, -1.0]
  ```

  where `-1.0` indicates there is no agent along that ray, and `0.5` is the normalised
  distance to the other agent. Channels in the segmented view are used to differentiate
  between different agents/targets and can be customised. By default, the view has three
  channels representing other agents, found targets, and unfound targets.
- `targets_remaining`: float in the range `[0, 1]`. The normalised number of targets
  remaining to be detected (i.e. `1.0` when no targets have been found).
- `step`: int in the range `[0, time_limit]`. The current simulation step.
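To make the encoding concrete, here is a small sketch of reading a single-channel view (using the example array above; masking with `jnp.inf` is an illustrative choice, not part of the environment):

```python
import jax.numpy as jnp

view = jnp.array([-1.0, -1.0, 0.5, -1.0, -1.0])  # example view from above

detected = view >= 0.0  # rays along which an agent/target was seen
# Nearest normalised detection distance over all rays (inf if nothing is seen)
nearest = jnp.min(jnp.where(detected, view, jnp.inf))
```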
## Actions

Jax array (float) of shape `(num_searchers, 2)` in the range `[-1, 1]`. Each entry in the
array represents an update to an agent's velocity in the next step. Searching agents
update their velocity each step by rotating and accelerating/decelerating, where the
values are `[rotation, acceleration]`. Values are clipped to the range `[-1, 1]`
and then scaled by the max rotation and acceleration parameters, i.e. the new values
each step are given by

```
heading = heading + max_rotation * action[0]
```

and

```
speed = speed + max_acceleration * action[1]
```

Once applied, agent speeds are clipped to a fixed range `[min_speed, max_speed]`.
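A sketch of this update rule as a whole (a reimplementation for illustration, not the environment's internal code; note that per the `AgentParams` docstring below, `max_rotate` is specified as a fraction of π, and the behaviour here matches the `test_velocity_update` cases in the tests below):

```python
import jax.numpy as jnp

def update_velocity_sketch(heading, speed, action, params):
    # Clip raw actions to [-1, 1] before scaling
    rotation = jnp.clip(action[0], -1.0, 1.0)
    acceleration = jnp.clip(action[1], -1.0, 1.0)

    # Rotate, wrapping the heading into [0, 2*pi)
    heading = (heading + params.max_rotate * jnp.pi * rotation) % (2.0 * jnp.pi)

    # Accelerate/decelerate, then clip speed to the allowed range
    speed = jnp.clip(
        speed + params.max_accelerate * acceleration,
        params.min_speed,
        params.max_speed,
    )
    return heading, speed
```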
## Rewards

Jax array (float) of shape `(num_searchers,)`. Rewards are generated for each agent individually.
Agents are rewarded +1 for locating a target that has not already been detected. It is possible
for multiple agents to detect a target within the same step; in this case the reward can either
be shared between the locating agents, or each agent can receive the full reward.
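For example, if `k` agents locate the same target in one step, a shared scheme would split the `+1` between them, while the individual scheme gives each the full reward. A hypothetical illustration of the two schemes (the boolean `found` matrix and both computations are assumptions for demonstration, not the environment's implementation):

```python
import jax.numpy as jnp

# found[i, j]: searcher i detected previously-unfound target j this step
found = jnp.array([[True, False], [True, False], [False, False]])

n_finders = jnp.sum(found, axis=0)  # simultaneous finders per target

# Individual scheme: every finder receives the full +1 per target found
individual_rewards = jnp.sum(found, axis=1)  # -> [1, 1, 0]

# Shared scheme: each target's +1 is split between its simultaneous finders
shared_rewards = jnp.sum(
    jnp.where(found, 1.0 / jnp.maximum(n_finders, 1), 0.0), axis=1
)  # -> [0.5, 0.5, 0.0]
```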
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,13 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,193 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Tuple

import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import pytest

from jumanji.environments.swarms.common import types, updates, viewer


@pytest.fixture
def params() -> types.AgentParams:
    return types.AgentParams(
        max_rotate=0.5,
        max_accelerate=0.01,
        min_speed=0.01,
        max_speed=0.05,
        view_angle=0.5,
    )


@pytest.mark.parametrize(
    "heading, speed, actions, expected",
    [
        [0.0, 0.01, [1.0, 0.0], (0.5 * jnp.pi, 0.01)],
        [0.0, 0.01, [-1.0, 0.0], (1.5 * jnp.pi, 0.01)],
        [jnp.pi, 0.01, [1.0, 0.0], (1.5 * jnp.pi, 0.01)],
        [jnp.pi, 0.01, [-1.0, 0.0], (0.5 * jnp.pi, 0.01)],
        [1.75 * jnp.pi, 0.01, [1.0, 0.0], (0.25 * jnp.pi, 0.01)],
        [0.0, 0.01, [0.0, 1.0], (0.0, 0.02)],
        [0.0, 0.01, [0.0, -1.0], (0.0, 0.01)],
        [0.0, 0.02, [0.0, -1.0], (0.0, 0.01)],
        [0.0, 0.05, [0.0, -1.0], (0.0, 0.04)],
        [0.0, 0.05, [0.0, 1.0], (0.0, 0.05)],
    ],
)
def test_velocity_update(
    params: types.AgentParams,
    heading: float,
    speed: float,
    actions: List[float],
    expected: Tuple[float, float],
) -> None:
    key = jax.random.PRNGKey(101)

    state = types.AgentState(
        pos=jnp.zeros((1, 2)),
        heading=jnp.array([heading]),
        speed=jnp.array([speed]),
    )
    actions = jnp.array([actions])

    new_heading, new_speed = updates.update_velocity(key, params, (actions, state))

    assert jnp.isclose(new_heading[0], expected[0])
    assert jnp.isclose(new_speed[0], expected[1])


@pytest.mark.parametrize(
    "pos, heading, speed, expected, env_size",
    [
        [[0.0, 0.5], 0.0, 0.1, [0.1, 0.5], 1.0],
        [[0.0, 0.5], jnp.pi, 0.1, [0.9, 0.5], 1.0],
        [[0.5, 0.0], 0.5 * jnp.pi, 0.1, [0.5, 0.1], 1.0],
        [[0.5, 0.0], 1.5 * jnp.pi, 0.1, [0.5, 0.9], 1.0],
        [[0.4, 0.2], 0.0, 0.2, [0.1, 0.2], 0.5],
        [[0.1, 0.2], jnp.pi, 0.2, [0.4, 0.2], 0.5],
        [[0.2, 0.4], 0.5 * jnp.pi, 0.2, [0.2, 0.1], 0.5],
        [[0.2, 0.1], 1.5 * jnp.pi, 0.2, [0.2, 0.4], 0.5],
    ],
)
def test_move(
    pos: List[float], heading: float, speed: float, expected: List[float], env_size: float
) -> None:
    pos = jnp.array(pos)
    new_pos = updates.move(pos, heading, speed, env_size)

    assert jnp.allclose(new_pos, jnp.array(expected))


@pytest.mark.parametrize(
    "pos, heading, speed, actions, expected_pos, expected_heading, expected_speed, env_size",
    [
        [[0.0, 0.5], 0.0, 0.01, [0.0, 0.0], [0.01, 0.5], 0.0, 0.01, 1.0],
        [[0.5, 0.0], 0.0, 0.01, [1.0, 0.0], [0.5, 0.01], 0.5 * jnp.pi, 0.01, 1.0],
        [[0.5, 0.0], 0.0, 0.01, [-1.0, 0.0], [0.5, 0.99], 1.5 * jnp.pi, 0.01, 1.0],
        [[0.0, 0.5], 0.0, 0.01, [0.0, 1.0], [0.02, 0.5], 0.0, 0.02, 1.0],
        [[0.0, 0.5], 0.0, 0.01, [0.0, -1.0], [0.01, 0.5], 0.0, 0.01, 1.0],
        [[0.0, 0.5], 0.0, 0.05, [0.0, 1.0], [0.05, 0.5], 0.0, 0.05, 1.0],
        [[0.495, 0.25], 0.0, 0.01, [0.0, 0.0], [0.005, 0.25], 0.0, 0.01, 0.5],
        [[0.25, 0.005], 1.5 * jnp.pi, 0.01, [0.0, 0.0], [0.25, 0.495], 1.5 * jnp.pi, 0.01, 0.5],
    ],
)
def test_state_update(
    params: types.AgentParams,
    pos: List[float],
    heading: float,
    speed: float,
    actions: List[float],
    expected_pos: List[float],
    expected_heading: float,
    expected_speed: float,
    env_size: float,
) -> None:
    key = jax.random.PRNGKey(101)

    state = types.AgentState(
        pos=jnp.array([pos]),
        heading=jnp.array([heading]),
        speed=jnp.array([speed]),
    )
    actions = jnp.array([actions])

    new_state = updates.update_state(key, env_size, params, state, actions)

    assert isinstance(new_state, types.AgentState)
    assert jnp.allclose(new_state.pos, jnp.array([expected_pos]))
    assert jnp.allclose(new_state.heading, jnp.array([expected_heading]))
    assert jnp.allclose(new_state.speed, jnp.array([expected_speed]))


def test_view_reduction() -> None:
    # Views are merged by keeping the closest detection along each ray,
    # where -1.0 indicates nothing was seen along that ray
    view_a = jnp.array([-1.0, -1.0, 0.2, 0.2, 0.5])
    view_b = jnp.array([-1.0, 0.2, -1.0, 0.5, 0.2])
    result = updates.view_reduction(view_a, view_b)
    assert jnp.allclose(result, jnp.array([-1.0, 0.2, 0.2, 0.2, 0.2]))


@pytest.mark.parametrize(
    "pos, view_angle, env_size, expected",
    [
        [[0.05, 0.0], 0.5, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
        [[0.0, 0.05], 0.5, 1.0, [0.5, -1.0, -1.0, -1.0, -1.0]],
        [[0.0, 0.95], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, 0.5]],
        [[0.95, 0.0], 0.5, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
        [[0.05, 0.0], 0.25, 1.0, [-1.0, -1.0, 0.5, -1.0, -1.0]],
        [[0.0, 0.05], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
        [[0.0, 0.95], 0.25, 1.0, [-1.0, -1.0, -1.0, -1.0, -1.0]],
        [[0.01, 0.0], 0.5, 1.0, [-1.0, 0.1, 0.1, 0.1, -1.0]],
        [[0.0, 0.45], 0.5, 1.0, [4.5, -1.0, -1.0, -1.0, -1.0]],
        [[0.0, 0.45], 0.5, 0.5, [-1.0, -1.0, -1.0, -1.0, 0.5]],
    ],
)
def test_view(pos: List[float], view_angle: float, env_size: float, expected: List[float]) -> None:
    state_a = types.AgentState(
        pos=jnp.zeros((2,)),
        heading=0.0,
        speed=0.0,
    )

    state_b = types.AgentState(
        pos=jnp.array(pos),
        heading=0.0,
        speed=0.0,
    )

    obs = updates.view(
        None, (view_angle, 0.02), state_a, state_b, n_view=5, i_range=0.1, env_size=env_size
    )
    assert jnp.allclose(obs, jnp.array(expected))


def test_viewer_utils() -> None:
    f, ax = plt.subplots()
    f, ax = viewer.format_plot(f, ax, (1.0, 1.0))

    assert isinstance(f, matplotlib.figure.Figure)
    assert isinstance(ax, matplotlib.axes.Axes)

    state = types.AgentState(
        pos=jnp.zeros((3, 2)),
        heading=jnp.zeros((3,)),
        speed=jnp.zeros((3,)),
    )

    quiver = viewer.draw_agents(ax, state, "red")

    assert isinstance(quiver, matplotlib.quiver.Quiver)
@@ -0,0 +1,53 @@
# Copyright 2022 InstaDeep Ltd. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from dataclasses import dataclass
else:
    from chex import dataclass

import chex


@dataclass(frozen=True)
class AgentParams:
    """
    max_rotate: Max angle an agent can rotate during a step (a fraction of pi)
    max_accelerate: Max change in speed during a step
    min_speed: Minimum agent speed
    max_speed: Maximum agent speed
    view_angle: Agent view angle, as a fraction of pi either side of its heading
    """

    max_rotate: float
    max_accelerate: float
    min_speed: float
    max_speed: float
    view_angle: float


@dataclass
class AgentState:
    """
    State of multiple agents of a single type

    pos: 2d position of the (centre of the) agents
    heading: Heading of the agents (in radians)
    speed: Speed of the agents
    """

    pos: chex.Array  # (num_agents, 2)
    heading: chex.Array  # (num_agents,)
    speed: chex.Array  # (num_agents,)
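For reference, a minimal sketch of constructing these types (values mirror the test fixture above and are purely illustrative):

```python
import jax.numpy as jnp

from jumanji.environments.swarms.common import types

params = types.AgentParams(
    max_rotate=0.5,       # up to 0.5 * pi rotation per step
    max_accelerate=0.01,
    min_speed=0.01,
    max_speed=0.05,
    view_angle=0.5,       # 0.5 * pi either side of the heading
)

# Three stationary agents at the origin, heading along the x-axis
state = types.AgentState(
    pos=jnp.zeros((3, 2)),
    heading=jnp.zeros((3,)),
    speed=jnp.zeros((3,)),
)
```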
Conversations
Is it worth registering the environment with different vision models?
Do you mean fully versus partially observable?
Ah, I think you mean the different observation functions. I'd say we should aim to use the observation that only visualizes targets and searchers within their vision cones. If we see a good training curve with this, then that should be the default.

In general I'd like to have a set of scenarios for most/all environments in jumanji (see #248). So it would be cool to think of a set (3-4) of easy/hard envs, and we can register those. If they happen to have different observation models, that's fine.
Yeah, I was picturing something along these lines: the easy version is the one where unfound targets are visible, and then harder versions hide the unfound targets.