Skip to content

Commit

Permalink
chore: use new jaxmarl fork with unpinned jax
Browse files Browse the repository at this point in the history
  • Loading branch information
WiemKhlifi committed Nov 7, 2024
1 parent 36a9c8d commit 0adf61d
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 15 deletions.
6 changes: 1 addition & 5 deletions mava/utils/make_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,7 @@

# Registry mapping environment names directly to the corresponding wrapper classes.
_matrax_registry = {"Matrax": MatraxWrapper}
_jaxmarl_registry = {
"Smax": SmaxWrapper,
"MaBrax": MabraxWrapper,
"MPE": MPEWrapper,
}
_jaxmarl_registry = {"Smax": SmaxWrapper, "MaBrax": MabraxWrapper, "MPE": MPEWrapper}
_gigastep_registry = {"Gigastep": GigastepWrapper}


Expand Down
11 changes: 1 addition & 10 deletions mava/wrappers/jaxmarl.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,12 +253,6 @@ def _create_observation(
return Observation(**obs_data)

def observation_spec(self) -> specs.Spec:
if isinstance(self._env, SimpleSpreadMPE):
obs, _ = self._env.reset(jax.random.PRNGKey(0))
# The shape provided in the `observation_spaces` isn't always the same as
# that given after constructing the full agent's obs.
self._env.observation_spaces["agent_0"].shape = obs["agent_0"].shape

agents_view = jaxmarl_space_to_jumanji_spec(merge_space(self._env.observation_spaces))

action_mask = specs.BoundedArray(
Expand Down Expand Up @@ -438,10 +432,7 @@ def action_dim(self) -> chex.Array:
@cached_property
def state_size(self) -> chex.Array:
"Get the state size of the global observation"
# The shape provided in the `observation_spaces` isn't always the same as
# that given after constructing the full agent's obs
obs, _ = self._env.reset(jax.random.PRNGKey(0))
return obs["agent_0"].shape[0] * self.num_agents
return self._env.observation_space(self.agents[0]).shape[0] * self.num_agents

def action_mask(self, wrapped_env_state: Any) -> Array:
"""Get action mask for each agent."""
Expand Down

0 comments on commit 0adf61d

Please sign in to comment.