Skip to content

Commit

Permalink
feat: upgrade gym wrapper to gymnasium (#264)
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Nov 22, 2024
1 parent 7979c36 commit 1556cd9
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ combinatorial problems.
- 🍬 **Wrappers**: easily connect to your favourite RL frameworks and libraries such as
[Acme](https://github.com/deepmind/acme),
[Stable Baselines3](https://github.com/DLR-RM/stable-baselines3),
[RLlib](https://docs.ray.io/en/latest/rllib/index.html), [OpenAI Gym](https://github.com/openai/gym)
[RLlib](https://docs.ray.io/en/latest/rllib/index.html), [Gymnasium](https://github.com/Farama-Foundation/Gymnasium)
and [DeepMind-Env](https://github.com/deepmind/dm_env) through our `dm_env` and `gym` wrappers.
- 🎓 **Examples**: guides to facilitate Jumanji's adoption and highlight the added value of
JAX-based environments.
Expand Down
10 changes: 5 additions & 5 deletions docs/guides/wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@ next_timestep = dm_env.step(action)
...
```

## Jumanji To Gym
We can also convert our Jumanji environments to a [Gym](https://github.com/openai/gym) environment!
Below is an example of how to convert a Jumanji environment into a Gym environment.
## Jumanji To Gymnasium
We can also convert our Jumanji environments to a [Gymnasium](https://github.com/Farama-Foundation/Gymnasium) environment!
Below is an example of how to convert a Jumanji environment into a Gymnasium environment.

```python
import jumanji.wrappers

env = jumanji.make("Snake-6x6-v0")
gym_env = jumanji.wrappers.JumanjiToGymWrapper(env)

obs = gym_env.reset()
obs, info = gym_env.reset()
action = gym_env.action_space.sample()
observation, reward, done, extra = gym_env.step(action)
observation, reward, term, trunc, info = gym_env.step(action)
...
```

Expand Down
2 changes: 1 addition & 1 deletion jumanji/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

import chex
import dm_env.specs
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion jumanji/specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import chex
import dm_env.specs
import gym.spaces
import gymnasium as gym
import jax.numpy as jnp
import numpy as np
import pytest
Expand Down
41 changes: 18 additions & 23 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from __future__ import annotations

from functools import cached_property
from typing import Any, Callable, ClassVar, Dict, Generic, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generic, Optional, Tuple, TypeAlias, Union

import chex
import dm_env.specs
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
Expand All @@ -28,7 +28,7 @@
from jumanji.types import TimeStep

# Type alias that corresponds to ObsType in the Gym API
GymObservation = Any
GymObservation: TypeAlias = chex.ArrayNumpy | Dict[str, Union[chex.ArrayNumpy, "GymObservation"]]


class Wrapper(Environment[State, ActionSpec, Observation], Generic[State, ActionSpec, Observation]):
Expand Down Expand Up @@ -584,10 +584,6 @@ def render(self, state: State) -> Any:
class JumanjiToGymWrapper(gym.Env, Generic[State, ActionSpec, Observation]):
"""A wrapper that converts a Jumanji `Environment` to one that follows the `gym.Env` API."""

# Flag that prevents `gym.register` from misinterpreting the `_step` and
# `_reset` as signs of a deprecated gym Env API.
_gym_disable_underscore_compat: ClassVar[bool] = True

def __init__(
self,
env: Environment[State, ActionSpec, Observation],
Expand Down Expand Up @@ -618,21 +614,21 @@ def reset(key: chex.PRNGKey) -> Tuple[State, Observation, Optional[Dict]]:

def step(
state: State, action: chex.Array
) -> Tuple[State, Observation, chex.Array, bool, Optional[Any]]:
) -> Tuple[State, Observation, chex.Array, chex.Array, chex.Array, Optional[Any]]:
"""Step function of a Jumanji environment to be jitted."""
state, timestep = self._env.step(state, action)
done = jnp.bool_(timestep.last())
return state, timestep.observation, timestep.reward, done, timestep.extras
term = timestep.discount.astype(bool)
trunc = timestep.last().astype(bool)
return state, timestep.observation, timestep.reward, term, trunc, timestep.extras

self._step = jax.jit(step, backend=self.backend)

def reset(
self,
*,
seed: Optional[int] = None,
return_info: bool = False,
options: Optional[dict] = None,
) -> Union[GymObservation, Tuple[GymObservation, Optional[Any]]]:
) -> Tuple[GymObservation, Dict[str, Any]]:
"""Resets the environment to an initial state by starting a new sequence
and returns the first `Observation` of this sequence.
Expand All @@ -648,13 +644,11 @@ def reset(
# Convert the observation to a numpy array or a nested dict thereof
obs = jumanji_to_gym_obs(obs)

if return_info:
info = jax.tree_util.tree_map(np.asarray, extras)
return obs, info
else:
return obs # type: ignore
return obs, jax.device_get(extras)

def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Optional[Any]]:
def step(
self, action: chex.ArrayNumpy
) -> Tuple[GymObservation, float, bool, bool, Dict[str, Any]]:
"""Updates the environment according to the action and returns an `Observation`.
Args:
Expand All @@ -667,16 +661,17 @@ def step(self, action: chex.ArrayNumpy) -> Tuple[GymObservation, float, bool, Op
info: contains supplementary information such as metrics.
"""

action = jnp.array(action) # Convert input numpy array to JAX array
self._state, obs, reward, done, extras = self._step(self._state, action)
action_jax = jnp.asarray(action) # Convert input numpy array to JAX array
self._state, obs, reward, term, trunc, extras = self._step(self._state, action_jax)

# Convert to get the correct signature
obs = jumanji_to_gym_obs(obs)
reward = float(reward)
terminated = bool(done)
info = jax.tree_util.tree_map(np.asarray, extras)
terminated = bool(term)
truncated = bool(trunc)
info = jax.device_get(extras)

return obs, reward, terminated, info
return obs, reward, terminated, truncated, info

def seed(self, seed: int = 0) -> None:
"""Function which sets the seed for the environment's random number generator(s).
Expand Down
15 changes: 10 additions & 5 deletions jumanji/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import chex
import dm_env.specs
import gym
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -265,15 +265,18 @@ def test_jumanji_environment_to_gym_env__reset(
self, fake_gym_env: FakeJumanjiToGymWrapper
) -> None:
"""Validates reset function of the wrapped environment."""
observation1 = fake_gym_env.reset()
observation1, info1 = fake_gym_env.reset()
state1 = fake_gym_env._state
observation2 = fake_gym_env.reset()
observation2, info2 = fake_gym_env.reset()
state2 = fake_gym_env._state

# Observation is typically numpy array
assert isinstance(observation1, chex.ArrayNumpy)
assert isinstance(observation2, chex.ArrayNumpy)

assert isinstance(info1, dict)
assert isinstance(info2, dict)

# Check that the observations are equal
chex.assert_trees_all_equal(observation1, observation2)
assert_trees_are_different(state1, state2)
Expand All @@ -282,12 +285,14 @@ def test_jumanji_environment_to_gym_env__step(
self, fake_gym_env: FakeJumanjiToGymWrapper
) -> None:
"""Validates step function of the wrapped environment."""
observation = fake_gym_env.reset()
observation, _ = fake_gym_env.reset()
action = fake_gym_env.action_space.sample()
next_observation, reward, terminated, info = fake_gym_env.step(action)
next_observation, reward, terminated, truncated, info = fake_gym_env.step(action)
assert_trees_are_different(observation, next_observation)
assert isinstance(reward, float)
assert isinstance(terminated, bool)
assert isinstance(truncated, bool)
assert isinstance(info, dict)

def test_jumanji_environment_to_gym_env__observation_space(
self, fake_gym_env: FakeJumanjiToGymWrapper
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
chex>=0.1.3
dm-env>=1.5
gym>=0.22.0
gymnasium>=1.0
huggingface-hub
jax>=0.2.26
matplotlib~=3.7.4
Expand Down

0 comments on commit 1556cd9

Please sign in to comment.