From 59c53f36b975d1a1ecf52407af0bedffe16ff664 Mon Sep 17 00:00:00 2001
From: WiemKhlifi
Date: Thu, 28 Nov 2024 15:42:44 +0100
Subject: [PATCH] chore: cleaning based on review

---
 mava/configs/env/connector.yaml        |  4 +--
 mava/configs/env/lbf.yaml              |  4 +--
 mava/configs/env/vector-connector.yaml |  4 +--
 mava/wrappers/jumanji.py               | 34 ++++++++++++--------------
 requirements/requirements.txt          |  4 +--
 5 files changed, 24 insertions(+), 26 deletions(-)

diff --git a/mava/configs/env/connector.yaml b/mava/configs/env/connector.yaml
index c9a2785d2..7d4bd5285 100644
--- a/mava/configs/env/connector.yaml
+++ b/mava/configs/env/connector.yaml
@@ -6,8 +6,8 @@ defaults:
 
 env_name: Connector # Used for logging purposes.
 
-# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
-use_individual_rewards: False # If True, use the list of individual rewards.
+# Choose whether or not to aggregate the individual rewards into a shared team reward.
+aggregate_rewards: False # If True, sum the individual rewards into a single team reward.
 
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
diff --git a/mava/configs/env/lbf.yaml b/mava/configs/env/lbf.yaml
index 5952f1f70..bff933140 100644
--- a/mava/configs/env/lbf.yaml
+++ b/mava/configs/env/lbf.yaml
@@ -6,8 +6,8 @@ defaults:
 
 env_name: LevelBasedForaging # Used for logging purposes.
 
-# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
-use_individual_rewards: False # If True, use the list of individual rewards.
+# Choose whether or not to aggregate the individual rewards into a shared team reward.
+aggregate_rewards: False # If True, sum the individual rewards into a single team reward.
 
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
diff --git a/mava/configs/env/vector-connector.yaml b/mava/configs/env/vector-connector.yaml
index 6a51bcae4..6df2583e0 100644
--- a/mava/configs/env/vector-connector.yaml
+++ b/mava/configs/env/vector-connector.yaml
@@ -6,8 +6,8 @@ defaults:
 
 env_name: VectorConnector # Used for logging purposes.
 
-# Choose whether to aggregate the list of individual rewards and use the team reward (default setting) OR use_individual_rewards=True.
-use_individual_rewards: True # If True, use the list of individual rewards.
+# Choose whether or not to aggregate the individual rewards into a shared team reward.
+aggregate_rewards: True # If True, sum the individual rewards into a single team reward.
 
 # Defines the metric that will be used to evaluate the performance of the agent.
 # This metric is returned at the end of an experiment and can be used for hyperparameter tuning.
diff --git a/mava/wrappers/jumanji.py b/mava/wrappers/jumanji.py
index 33acdcd2d..5bf578845 100644
--- a/mava/wrappers/jumanji.py
+++ b/mava/wrappers/jumanji.py
@@ -40,15 +40,8 @@
 from mava.types import Observation, ObservationGlobalState, State
 
 
-def aggregate_rewards(
-    reward: chex.Array, num_agents: int, use_individual_rewards: bool = False
-) -> chex.Array:
+def aggregate_rewards(reward: chex.Array, num_agents: int) -> chex.Array:
     """Aggregate individual rewards across agents."""
-    if use_individual_rewards:
-        # Returns a list of individual rewards that will be used as is.
-        return reward
-
-    # Aggregate the list of individual rewards and use a single team_reward.
     team_reward = jnp.sum(reward)
     return jnp.repeat(team_reward, num_agents)
 
@@ -187,11 +180,11 @@ def __init__(
         self,
         env: LevelBasedForaging,
         add_global_state: bool = False,
-        use_individual_rewards: bool = False,
+        aggregate_rewards: bool = False,
     ):
         super().__init__(env, add_global_state)
         self._env: LevelBasedForaging
-        self._use_individual_rewards = use_individual_rewards
+        self._aggregate_rewards = aggregate_rewards
 
     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
         """Modify the timestep for Level-Based Foraging environment and update
@@ -204,7 +197,9 @@ def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
             step_count=jnp.repeat(timestep.observation.step_count, self.num_agents),
         )
         # Whether or not aggregate the list of individual rewards.
-        reward = aggregate_rewards(timestep.reward, self.num_agents, self._use_individual_rewards)
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
 
         return timestep.replace(observation=modified_observation, reward=reward)
 
@@ -249,11 +244,11 @@ class ConnectorWrapper(JumanjiMarlWrapper):
     """
 
     def __init__(
-        self, env: Connector, add_global_state: bool = False, use_individual_rewards: bool = False
+        self, env: Connector, add_global_state: bool = False, aggregate_rewards: bool = False
    ):
         super().__init__(env, add_global_state)
         self._env: Connector
-        self._use_individual_rewards = use_individual_rewards
+        self._aggregate_rewards = aggregate_rewards
         self.agent_ids = jnp.arange(self.num_agents)
 
     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
@@ -290,7 +285,9 @@ def create_agents_view(grid: chex.Array) -> chex.Array:
         extras = timestep.extras | {"won_episode": timestep.extras["ratio_connections"] == 1.0}
 
         # Whether or not aggregate the list of individual rewards.
-        reward = aggregate_rewards(timestep.reward, self.num_agents, self._use_individual_rewards)
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
         return timestep.replace(observation=Observation(**obs_data), reward=reward, extras=extras)
 
     def get_global_state(self, obs: Observation) -> chex.Array:
@@ -368,12 +365,12 @@ class VectorConnectorWrapper(JumanjiMarlWrapper):
     """
 
     def __init__(
-        self, env: Connector, add_global_state: bool = False, use_individual_rewards: bool = False
+        self, env: Connector, add_global_state: bool = False, aggregate_rewards: bool = False
     ):
         self.fov = 2
         super().__init__(env, add_global_state)
         self._env: Connector
-        self._use_individual_rewards = use_individual_rewards
+        self._aggregate_rewards = aggregate_rewards
         self.agent_ids = jnp.arange(self.num_agents)
 
     def modify_timestep(self, timestep: TimeStep) -> TimeStep[Observation]:
@@ -439,8 +436,9 @@ def _create_one_agent_view(i: int) -> chex.Array:
         # The episode is won if all agents have connected.
         extras = timestep.extras | {"won_episode": timestep.extras["ratio_connections"] == 1.0}
 
-        # Whether or not aggregate the list of individual rewards.
-        reward = aggregate_rewards(timestep.reward, self.num_agents, self._use_individual_rewards)
+        reward = timestep.reward
+        if self._aggregate_rewards:
+            reward = aggregate_rewards(reward, self.num_agents)
         return timestep.replace(observation=Observation(**obs_data), reward=reward, extras=extras)
 
     @cached_property
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 19cd23ae9..4a0c1ef39 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -10,9 +10,9 @@ id-marl-eval
 jax==0.4.30
 jaxlib==0.4.30
 jaxmarl @ git+https://github.com/RuanJohn/JaxMARL@unpin-jax # This only unpins the version of Jax.
-jumanji
+jumanji>=1.1.0
 lbforaging
-matrax @ git+https://github.com/instadeepai/matrax
+matrax>=0.0.5
 mujoco==3.1.3
 mujoco-mjx==3.1.3
 neptune
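
Note: a minimal sketch (not part of the patch; the reward values below are made up) of what the new aggregate_rewards flag does. When True, the wrappers sum the per-agent rewards and repeat the team reward for every agent, as in the aggregate_rewards helper in mava/wrappers/jumanji.py; when False, the individual rewards pass through unchanged.

import jax.numpy as jnp

# Hypothetical per-agent rewards returned by the environment.
individual_rewards = jnp.array([0.0, 1.0, 0.5])
num_agents = individual_rewards.shape[0]

# aggregate_rewards=True: sum into a shared team reward, repeated for each agent.
team_reward = jnp.repeat(jnp.sum(individual_rewards), num_agents)  # [1.5, 1.5, 1.5]

# aggregate_rewards=False: the individual rewards are used as-is.
reward = individual_rewards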