
Commit

chore: cleaning based on review
WiemKhlifi committed Nov 28, 2024
1 parent ac56b73 commit 59c53f3
Showing 5 changed files with 24 additions and 26 deletions.
4 changes: 2 additions & 2 deletions mava/configs/env/connector.yaml
@@ -6,8 +6,8 @@ defaults:

env_name: Connector # Used for logging purposes.

-# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
-use_individual_rewards: False # If True, use the list of individual rewards.
+# Choose whether to aggregate individual rewards into a shared team reward.
+aggregate_rewards: False # If True, aggregate the individual rewards into a single team reward.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
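
For clarity on the renamed flag: as implemented by the `aggregate_rewards` helper in `mava/wrappers/jumanji.py` later in this diff, `True` sums the per-agent rewards into a single team reward and broadcasts it back to every agent, while `False` passes the per-agent rewards through unchanged. A minimal sketch of the `True` case (the reward values are illustrative):

    import jax.numpy as jnp

    individual = jnp.array([0.0, 1.0, 0.5])  # one reward per agent

    # Sum into a team reward, then broadcast it back to all agents.
    team = jnp.repeat(jnp.sum(individual), individual.shape[0])
    print(team)  # [1.5 1.5 1.5]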
4 changes: 2 additions & 2 deletions mava/configs/env/lbf.yaml
@@ -6,8 +6,8 @@ defaults:

env_name: LevelBasedForaging # Used for logging purposes.

-# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
-use_individual_rewards: False # If True, use the list of individual rewards.
+# Choose whether to aggregate individual rewards into a shared team reward.
+aggregate_rewards: False # If True, aggregate the individual rewards into a single team reward.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
4 changes: 2 additions & 2 deletions mava/configs/env/vector-connector.yaml
@@ -6,8 +6,8 @@ defaults:

env_name: VectorConnector # Used for logging purposes.

-# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
-use_individual_rewards: True # If True, use the list of individual rewards.
+# Choose whether to aggregate individual rewards into a shared team reward.
+aggregate_rewards: True # If True, aggregate the individual rewards into a single team reward.

# Defines the metric that will be used to evaluate the performance of the agent.
# This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
34 changes: 16 additions & 18 deletions mava/wrappers/jumanji.py
@@ -40,15 +40,8 @@
 from mava.types import Observation, ObservationGlobalState, State


-def aggregate_rewards(
-    reward: chex.Array, num_agents: int, use_individual_rewards: bool = False
-) -> chex.Array:
+def aggregate_rewards(reward: chex.Array, num_agents: int) -> chex.Array:
     """Aggregate individual rewards across agents."""
-    if use_individual_rewards:
-        # Returns a list of individual rewards that will be used as is.
-        return reward
-
     # Aggregate the list of individual rewards and use a single team_reward.
     team_reward = jnp.sum(reward)
     return jnp.repeat(team_reward, num_agents)

@@ -187,11 +180,11 @@ def __init__(
         self,
         env: LevelBasedForaging,
         add_global_state: bool = False,
-        use_individual_rewards: bool = False,
+        aggregate_rewards: bool = False,
     ):
         super().__init__(env, add_global_state)
         self._env: LevelBasedForaging
-        self._use_individual_rewards = use_individual_rewards
+        self._aggregate_rewards = aggregate_rewards

     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
         """Modify the timestep for Level-Based Foraging environment and update
@@ -204,7 +197,9 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
             step_count=jnp.repeat(timestep.observation.step_count, self.num_agents),
         )
         # Whether or not to aggregate the list of individual rewards.
-        reward = aggregate_rewards(timestep.reward, self.num_agents, self._use_individual_rewards)
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)

         return timestep.replace(observation=modified_observation, reward=reward)
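
The same pass-through-or-aggregate pattern is applied in `ConnectorWrapper` and `VectorConnectorWrapper` below. A usage sketch of the renamed keyword; the `LbfWrapper` class name and the environment id are assumptions, as neither appears in this hunk:

    import jumanji
    from mava.wrappers.jumanji import LbfWrapper  # class name assumed

    env = jumanji.make("LevelBasedForaging-v0")  # registry id illustrative
    env = LbfWrapper(env, add_global_state=False, aggregate_rewards=True)
    # Each step's reward is now the summed team reward, repeated once per agent.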

@@ -249,11 +244,11 @@ class ConnectorWrapper(JumanjiMarlWrapper):
     """

     def __init__(
-        self, env: Connector, add_global_state: bool = False, use_individual_rewards: bool = False
+        self, env: Connector, add_global_state: bool = False, aggregate_rewards: bool = False
     ):
         super().__init__(env, add_global_state)
         self._env: Connector
-        self._use_individual_rewards = use_individual_rewards
+        self._aggregate_rewards = aggregate_rewards
         self.agent_ids = jnp.arange(self.num_agents)

     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
@@ -290,7 +285,9 @@ def create_agents_view(grid: chex.Array) -> chex.Array:
         extras = timestep.extras | {"won_episode": timestep.extras["ratio_connections"] == 1.0}

         # Whether or not to aggregate the list of individual rewards.
-        reward = aggregate_rewards(timestep.reward, self.num_agents, self._use_individual_rewards)
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
         return timestep.replace(observation=Observation(**obs_data), reward=reward, extras=extras)

     def get_global_state(self, obs: Observation) -> chex.Array:
@@ -368,12 +365,12 @@ class VectorConnectorWrapper(JumanjiMarlWrapper):
     """

     def __init__(
-        self, env: Connector, add_global_state: bool = False, use_individual_rewards: bool = False
+        self, env: Connector, add_global_state: bool = False, aggregate_rewards: bool = False
     ):
         self.fov = 2
         super().__init__(env, add_global_state)
         self._env: Connector
-        self._use_individual_rewards = use_individual_rewards
+        self._aggregate_rewards = aggregate_rewards
         self.agent_ids = jnp.arange(self.num_agents)

     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
@@ -439,8 +436,9 @@ def _create_one_agent_view(i: int) -> chex.Array:
         # The episode is won if all agents have connected.
         extras = timestep.extras | {"won_episode": timestep.extras["ratio_connections"] == 1.0}

-        # Whether or not aggregate the list of individual rewards.
-        reward = aggregate_rewards(timestep.reward, self.num_agents, self._use_individual_rewards)
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
         return timestep.replace(observation=Observation(**obs_data), reward=reward, extras=extras)

    @cached_property
4 changes: 2 additions & 2 deletions requirements/requirements.txt
@@ -10,9 +10,9 @@ id-marl-eval
jax==0.4.30
jaxlib==0.4.30
jaxmarl @ git+https://github.com/RuanJohn/JaxMARL@unpin-jax # This only unpins the version of Jax.
-jumanji
+jumanji>=1.1.0
 lbforaging
-matrax @ git+https://github.com/instadeepai/matrax
+matrax>=0.0.5
mujoco==3.1.3
mujoco-mjx==3.1.3
neptune
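
A small post-install sanity check for the new minimum versions; it uses `packaging` for the comparison, which is an assumption here and not itself listed in this file:

    from importlib.metadata import version
    from packaging.version import Version

    # Both lower bounds come from this diff.
    assert Version(version("jumanji")) >= Version("1.1.0")
    assert Version(version("matrax")) >= Version("0.0.5")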
