
Commit

Merge branch 'instadeepai:main' into main
zombie-einstein authored Nov 15, 2024
2 parents 34beab6 + 7979c36 commit f5fa659
Showing 6 changed files with 65 additions and 30 deletions.
33 changes: 27 additions & 6 deletions jumanji/environments/routing/connector/env.py
@@ -70,9 +70,10 @@ class Connector(Environment[State, specs.MultiDiscreteArray, Observation]):
- can take the values [0,1,2,3,4] which correspond to [No Op, Up, Right, Down, Left].
- each value in the array corresponds to an agent's action.
- reward: jax array (float) of shape ():
- dense: reward is 1 for each successful connection on that step. Additionally,
each pair of points that have not connected receives a penalty reward of -0.03.
- reward: jax array (float) of shape (num_agents,):
- dense: for each agent the reward is 1 for each successful connection on that step.
Additionally, each pair of points that have not connected receives a
penalty reward of -0.03.
- episode termination:
- all agents either can't move (no available actions) or have connected to their target.
@@ -142,7 +143,7 @@ def reset(self, key: chex.PRNGKey) -> Tuple[State, TimeStep[Observation]]:
step_count=state.step_count,
)
extras = self._get_extras(state)
timestep = restart(observation=observation, extras=extras)
timestep = restart(observation=observation, extras=extras, shape=(self.num_agents,))
return state, timestep

def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
@@ -171,19 +172,23 @@ def step(self, state: State, action: chex.Array) -> Tuple[State, TimeStep[Observation]]:
grid=grid, action_mask=action_mask, step_count=new_state.step_count
)

done = jnp.all(jax.vmap(connected_or_blocked)(agents, action_mask))
done = jax.vmap(connected_or_blocked)(agents, action_mask)
discount = (1 - done).astype(float)
extras = self._get_extras(new_state)
timestep = jax.lax.cond(
done | (new_state.step_count >= self.time_limit),
jnp.all(done) | (new_state.step_count >= self.time_limit),
lambda: termination(
reward=reward,
observation=observation,
extras=extras,
shape=(self.num_agents,),
),
lambda: transition(
reward=reward,
observation=observation,
extras=extras,
discount=discount,
shape=(self.num_agents,),
),
)

@@ -362,3 +367,19 @@ def action_spec(self) -> specs.MultiDiscreteArray:
dtype=jnp.int32,
name="action",
)

@cached_property
def reward_spec(self) -> specs.Array:
"""Returns: a reward per agent."""
return specs.Array(shape=(self.num_agents,), dtype=float, name="reward")

@cached_property
def discount_spec(self) -> specs.BoundedArray:
"""Returns: discount per agent."""
return specs.BoundedArray(
shape=(self.num_agents,),
dtype=float,
minimum=0.0,
maximum=1.0,
name="discount",
)
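
Taken together, the env.py changes make the Connector timestep per-agent: rewards and discounts become vectors of shape (num_agents,), the episode still terminates only when every agent is connected or blocked, and the new reward_spec and discount_spec advertise that shape. A minimal usage sketch of what a caller could now expect; the registered name "Connector-v2" is an assumption, not part of this diff:

import jax
import jax.numpy as jnp
import jumanji

# "Connector-v2" is an assumed registered name; use whichever Connector version is registered.
env = jumanji.make("Connector-v2")
num_agents = env.reward_spec.shape[0]  # per-agent reward spec added in this commit

state, timestep = env.reset(jax.random.PRNGKey(0))
assert timestep.reward.shape == (num_agents,)    # zeros at reset
assert timestep.discount.shape == (num_agents,)  # ones at reset

# Step every agent with a no-op (action 0); reward and discount stay per-agent.
state, timestep = env.step(state, jnp.zeros(num_agents, dtype=jnp.int32))
assert timestep.reward.shape == (num_agents,)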
12 changes: 6 additions & 6 deletions jumanji/environments/routing/connector/env_test.py
@@ -57,8 +57,8 @@ def test_connector__reset(connector: Connector, key: jax.random.PRNGKey) -> None:
assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))

assert timestep.discount == 1.0
assert timestep.reward == 0.0
assert jnp.allclose(timestep.discount, jnp.ones((connector.num_agents,)))
assert jnp.allclose(timestep.reward, jnp.zeros((connector.num_agents,)))
assert timestep.step_type == StepType.FIRST


@@ -94,7 +94,7 @@ def test_connector__step_connected(
chex.assert_trees_all_equal(real_state2, state2)

assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))
reward = connector._reward_fn(real_state1, action2, real_state2)
assert jnp.array_equal(timestep.reward, reward)

@@ -146,7 +146,7 @@ def test_connector__step_blocked(

assert jnp.array_equal(state.grid, expected_grid)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))

assert all(is_head_on_grid(state.agents, state.grid))
assert all(is_target_on_grid(state.agents, state.grid))
@@ -165,12 +165,12 @@ def test_connector__step_horizon(connector: Connector, state: State) -> None:
state, timestep = step_fn(state, actions)

assert timestep.step_type != StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(1))
assert jnp.array_equal(timestep.discount, jnp.ones(connector.num_agents))

# step 5
state, timestep = step_fn(state, actions)
assert timestep.step_type == StepType.LAST
assert jnp.array_equal(timestep.discount, jnp.asarray(0))
assert jnp.array_equal(timestep.discount, jnp.zeros(connector.num_agents))


def test_connector__step_agents_collision(
2 changes: 1 addition & 1 deletion jumanji/environments/routing/connector/reward.py
@@ -71,4 +71,4 @@ def __call__(
~state.agents.connected & next_state.agents.connected, float
)
timestep_rewards = self.timestep_reward * jnp.asarray(~state.agents.connected, float)
return jnp.sum(connected_rewards + timestep_rewards)
return connected_rewards + timestep_rewards
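
The dense reward now returns the per-agent vector rather than its sum. A small worked sketch of the same formula, assuming the 1.0 connection reward and -0.03 step penalty described in the env.py docstring:

import jax.numpy as jnp

connected_reward = 1.0   # reward when an agent completes its connection
timestep_reward = -0.03  # penalty for each pair that is still unconnected

# Connection flags for three agents before and after one step.
prev_connected = jnp.array([False, False, True])
next_connected = jnp.array([True, False, True])

connected_rewards = connected_reward * jnp.asarray(~prev_connected & next_connected, float)
timestep_rewards = timestep_reward * jnp.asarray(~prev_connected, float)

# Previously summed to a scalar; now each agent keeps its own reward.
print(connected_rewards + timestep_rewards)  # [ 0.97 -0.03  0.  ]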
23 changes: 12 additions & 11 deletions jumanji/environments/routing/connector/reward_test.py
@@ -35,26 +35,27 @@ def test_dense_reward(

# Reward of moving between the same states should be 0.
reward = dense_reward_fn(state, jnp.array([0, 0, 0]), state)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.asarray(timestep_reward * 3))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.array([timestep_reward] * 3))

# Reward for no agents finished to 2 agents finished.
reward = dense_reward_fn(state, action1, state1)
chex.assert_rank(reward, 0)
expected_reward = connected_reward * 2 + timestep_reward * 3
assert jnp.isclose(reward, expected_reward)
chex.assert_rank(reward, 1)
expected_reward = jnp.array([connected_reward, 0, connected_reward]) + timestep_reward
assert jnp.allclose(reward, expected_reward)

# Reward for some agents finished to all agents finished.
reward = dense_reward_fn(state1, action2, state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.array(connected_reward + timestep_reward))
chex.assert_rank(reward, 1)
expected_reward = jnp.array([0, connected_reward + timestep_reward, 0])
assert jnp.allclose(reward, expected_reward)

# Reward for none finished to all finished
reward = dense_reward_fn(state, action1, state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.array((connected_reward + timestep_reward) * 3))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.array([connected_reward + timestep_reward] * 3))

# Reward of all finished to all finished.
reward = dense_reward_fn(state2, jnp.zeros(3), state2)
chex.assert_rank(reward, 0)
assert jnp.isclose(reward, jnp.zeros(1))
chex.assert_rank(reward, 1)
assert jnp.allclose(reward, jnp.zeros(1))
2 changes: 1 addition & 1 deletion jumanji/training/setup_train.py
@@ -90,7 +90,7 @@ def setup_logger(cfg: DictConfig) -> Logger:

def _make_raw_env(cfg: DictConfig) -> Environment:
env = jumanji.make(cfg.env.registered_version)
if cfg.env.name in {"lbf"}:
if cfg.env.name in {"lbf", "connector"}:
# Convert a multi-agent environment to a single-agent environment
env = MultiToSingleWrapper(env)
return env
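
Because Connector now emits per-agent rewards and discounts, it joins "lbf" in being wrapped by MultiToSingleWrapper for the single-agent training loop. A rough sketch of the idea only; the aggregation shown (summed rewards, maximum discount) is an assumption about what such a wrapper might do, not a description of its actual implementation:

import jax.numpy as jnp

# Per-agent quantities as the environment now produces them.
per_agent_reward = jnp.array([0.97, -0.03, 0.0])
per_agent_discount = jnp.array([1.0, 1.0, 0.0])

# One plausible single-agent view for the training loop.
scalar_reward = jnp.sum(per_agent_reward)      # team reward
scalar_discount = jnp.max(per_agent_discount)  # keep discounting while any agent is still active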
23 changes: 18 additions & 5 deletions jumanji/types.py
@@ -95,6 +95,7 @@ def restart(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.
@@ -107,15 +107,17 @@
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the rewards and discounts.
Defaults to `float`.
Returns:
TimeStep identified as a reset.
"""
extras = extras or {}
return TimeStep(
step_type=StepType.FIRST,
reward=jnp.zeros(shape, dtype=float),
discount=jnp.ones(shape, dtype=float),
reward=jnp.zeros(shape, dtype=dtype),
discount=jnp.ones(shape, dtype=dtype),
observation=observation,
extras=extras,
)
@@ -127,6 +130,7 @@ def transition(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.MID`.
@@ -141,11 +145,13 @@
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.
Returns:
TimeStep identified as a transition.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.MID,
@@ -161,6 +167,7 @@ def termination(
observation: Observation,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.
@@ -174,6 +181,8 @@
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.
Returns:
TimeStep identified as the termination of an episode.
@@ -182,7 +191,7 @@
return TimeStep(
step_type=StepType.LAST,
reward=reward,
discount=jnp.zeros(shape, dtype=float),
discount=jnp.zeros(shape, dtype=dtype),
observation=observation,
extras=extras,
)
@@ -194,6 +203,7 @@ def truncation(
discount: Optional[Array] = None,
extras: Optional[Dict] = None,
shape: Union[int, Sequence[int]] = (),
dtype: Union[jnp.dtype, type] = float,
) -> TimeStep:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`.
@@ -208,10 +218,13 @@
shape: optional parameter to specify the shape of the rewards and discounts.
Allows multi-agent environment compatibility. Defaults to () for
scalar reward and discount.
dtype: Optional parameter to specify the data type of the discounts. Defaults
to `float`.
Returns:
TimeStep identified as the truncation of an episode.
"""
discount = discount if discount is not None else jnp.ones(shape, dtype=float)
discount = discount if discount is not None else jnp.ones(shape, dtype=dtype)
extras = extras or {}
return TimeStep(
step_type=StepType.LAST,
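The TimeStep factory functions now accept an explicit dtype alongside shape, so multi-agent environments can control the precision of their rewards and discounts. A small sketch of the new arguments in use; the None observation is only a placeholder:

import jax.numpy as jnp
from jumanji.types import StepType, restart, termination

num_agents = 3

# Reset: per-agent zero rewards and unit discounts, built in float32.
first = restart(observation=None, shape=(num_agents,), dtype=jnp.float32)
assert first.step_type == StepType.FIRST
assert first.reward.shape == (num_agents,)
assert first.discount.dtype == jnp.float32

# Termination: the caller supplies per-agent rewards; discounts are zeros of the same dtype.
last = termination(
    reward=jnp.array([1.0, -0.03, 0.0], dtype=jnp.float32),
    observation=None,
    shape=(num_agents,),
    dtype=jnp.float32,
)
assert jnp.all(last.discount == 0.0)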
