Skip to content

Commit

Permalink
Add support for different vector autoreset modes (#1227)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Nov 28, 2024
1 parent 7e2062d commit 8a46c3a
Show file tree
Hide file tree
Showing 21 changed files with 854 additions and 63 deletions.
3 changes: 2 additions & 1 deletion gymnasium/envs/classic_control/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from gymnasium import logger, spaces
from gymnasium.envs.classic_control import utils
from gymnasium.error import DependencyNotInstalled
from gymnasium.vector import VectorEnv
from gymnasium.vector import AutoresetMode, VectorEnv
from gymnasium.vector.utils import batch_space


Expand Down Expand Up @@ -355,6 +355,7 @@ class CartPoleVectorEnv(VectorEnv):
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(
Expand Down
3 changes: 2 additions & 1 deletion gymnasium/envs/functional_jax_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from gymnasium.envs.registration import EnvSpec
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import seeding
from gymnasium.vector import AutoresetMode
from gymnasium.vector.utils import batch_space


Expand Down Expand Up @@ -115,7 +116,7 @@ def __init__(
"""Initialize the environment from a FuncEnv."""
super().__init__()
if metadata is None:
metadata = {}
metadata = {"autoreset_mode": AutoresetMode.NEXT_STEP}
self.func_env = func_env
self.num_envs = num_envs

Expand Down
8 changes: 7 additions & 1 deletion gymnasium/envs/phys2d/cartpole.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode


RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock"] # type: ignore # noqa: F821
Expand Down Expand Up @@ -272,7 +273,12 @@ def __init__(self, render_mode: str | None = None, **kwargs: Any):
class CartPoleJaxVectorEnv(FunctionalJaxVectorEnv, EzPickle):
"""Jax-based implementation of the vectorized CartPole environment."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 50, "jax": True}
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 50,
"jax": True,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(
self,
Expand Down
8 changes: 7 additions & 1 deletion gymnasium/envs/phys2d/pendulum.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode


RenderStateType = Tuple["pygame.Surface", "pygame.time.Clock", Optional[float]] # type: ignore # noqa: F821
Expand Down Expand Up @@ -225,7 +226,12 @@ def get_default_params(self, **kwargs) -> PendulumParams:
class PendulumJaxEnv(FunctionalJaxEnv, EzPickle):
"""Jax-based pendulum environment using the functional version as base."""

metadata = {"render_modes": ["rgb_array"], "render_fps": 30, "jax": True}
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 30,
"jax": True,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def __init__(self, render_mode: str | None = None, **kwargs: Any):
"""Constructor where the kwargs are passed to the base environment to modify the parameters."""
Expand Down
11 changes: 11 additions & 0 deletions gymnasium/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import gymnasium as gym
from gymnasium import Env, Wrapper, error, logger
from gymnasium.logger import warn
from gymnasium.vector import AutoresetMode


if sys.version_info < (3, 10):
Expand Down Expand Up @@ -976,6 +978,15 @@ def create_single_env() -> Env:
copied_id_spec.kwargs["wrappers"] = wrappers
env.unwrapped.spec = copied_id_spec

if "autoreset_mode" not in env.metadata:
warn(
f"The VectorEnv ({env}) is missing AutoresetMode metadata, metadata={env.metadata}"
)
elif not isinstance(env.metadata["autoreset_mode"], AutoresetMode):
warn(
f"The VectorEnv ({env}) metadata['autoreset_mode'] is not an instance of AutoresetMode, {type(env.metadata['autoreset_mode'])}."
)

return env


Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/tabular/blackjack.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle, seeding
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering


Expand Down Expand Up @@ -239,6 +240,7 @@ class BlackjackFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"autoreseet-mode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/envs/tabular/cliffwalking.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from gymnasium.error import DependencyNotInstalled
from gymnasium.experimental.functional import ActType, FuncEnv, StateType
from gymnasium.utils import EzPickle
from gymnasium.vector import AutoresetMode
from gymnasium.wrappers import HumanRendering


Expand Down Expand Up @@ -136,6 +137,7 @@ class CliffWalkingFunctional(
metadata = {
"render_modes": ["rgb_array"],
"render_fps": 4,
"autoreset_mode": AutoresetMode.NEXT_STEP,
}

def transition(
Expand Down
2 changes: 2 additions & 0 deletions gymnasium/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from gymnasium.vector.async_vector_env import AsyncVectorEnv
from gymnasium.vector.sync_vector_env import SyncVectorEnv
from gymnasium.vector.vector_env import (
AutoresetMode,
VectorActionWrapper,
VectorEnv,
VectorObservationWrapper,
Expand All @@ -21,4 +22,5 @@
"SyncVectorEnv",
"AsyncVectorEnv",
"utils",
"AutoresetMode",
]
87 changes: 78 additions & 9 deletions gymnasium/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
read_from_shared_memory,
write_to_shared_memory,
)
from gymnasium.vector.vector_env import ArrayType, VectorEnv
from gymnasium.vector.vector_env import ArrayType, AutoresetMode, VectorEnv


__all__ = ["AsyncVectorEnv", "AsyncState"]
Expand Down Expand Up @@ -101,6 +101,7 @@ def __init__(
| None
) = None,
observation_mode: str | Space = "same",
autoreset_mode: str | AutoresetMode = AutoresetMode.NEXT_STEP,
):
"""Vectorized environment that runs multiple environments in parallel.
Expand All @@ -120,6 +121,7 @@ def __init__(
'different' defines that there can be multiple observation spaces with different parameters though requires the same shape and dtype,
warning, may raise unexpected errors. Passing a ``Tuple[Space, Space]`` object allows defining a custom ``single_observation_space`` and
``observation_space``, warning, may raise unexpected errors.
autoreset_mode: The Autoreset Mode used, see todo for more details.
Warnings:
worker is an advanced mode option. It provides a high degree of flexibility and a high chance
Expand All @@ -135,7 +137,15 @@ def __init__(
self.env_fns = env_fns
self.shared_memory = shared_memory
self.copy = copy
self.context = context
self.daemon = daemon
self.worker = worker
self.observation_mode = observation_mode
self.autoreset_mode = (
autoreset_mode
if isinstance(autoreset_mode, AutoresetMode)
else AutoresetMode(autoreset_mode)
)

self.num_envs = len(env_fns)

Expand All @@ -145,6 +155,7 @@ def __init__(

# As we support `make_vec(spec)` then we can't include a `spec = dummy_env.spec` as this doesn't guarantee we can actual recreate the vector env.
self.metadata = dummy_env.metadata
self.metadata["autoreset_mode"] = self.autoreset_mode
self.render_mode = dummy_env.render_mode

self.single_action_space = dummy_env.action_space
Expand Down Expand Up @@ -211,6 +222,7 @@ def __init__(
parent_pipe,
_obs_buffer,
self.error_queue,
self.autoreset_mode,
),
)

Expand Down Expand Up @@ -287,9 +299,32 @@ def reset_async(
str(self._state.value),
)

for pipe, env_seed in zip(self.parent_pipes, seed):
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))
if options is not None and "reset_mask" in options:
reset_mask = options.pop("reset_mask")
assert isinstance(
reset_mask, np.ndarray
), f"`options['reset_mask': mask]` must be a numpy array, got {type(reset_mask)}"
assert reset_mask.shape == (
self.num_envs,
), f"`options['reset_mask': mask]` must have shape `({self.num_envs},)`, got {reset_mask.shape}"
assert (
reset_mask.dtype == np.bool_
), f"`options['reset_mask': mask]` must have `dtype=np.bool_`, got {reset_mask.dtype}"
assert np.any(
reset_mask
), f"`options['reset_mask': mask]` must contain a boolean array, got reset_mask={reset_mask}"

for pipe, env_seed, env_reset in zip(self.parent_pipes, seed, reset_mask):
if env_reset:
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))
else:
pipe.send(("reset-noop", None))
else:
for pipe, env_seed in zip(self.parent_pipes, seed):
env_kwargs = {"seed": env_seed, "options": options}
pipe.send(("reset", env_kwargs))

self._state = AsyncState.WAITING_RESET

def reset_wait(
Expand Down Expand Up @@ -688,11 +723,13 @@ def _async_worker(
parent_pipe: Connection,
shared_memory: multiprocessing.Array | dict[str, Any] | tuple[Any, ...],
error_queue: Queue,
autoreset_mode: AutoresetMode,
):
env = env_fn()
observation_space = env.observation_space
action_space = env.action_space
autoreset = False
observation = None

parent_pipe.close()

Expand All @@ -709,19 +746,51 @@ def _async_worker(
observation = None
autoreset = False
pipe.send(((observation, info), True))
elif command == "reset-noop":
pipe.send(((observation, {}), True))
elif command == "step":
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
if autoreset_mode == AutoresetMode.NEXT_STEP:
if autoreset:
observation, info = env.reset()
reward, terminated, truncated = 0, False, False
else:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated
elif autoreset_mode == AutoresetMode.SAME_STEP:
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
autoreset = terminated or truncated

if terminated or truncated:
reset_observation, reset_info = env.reset()

info = {
"final_info": info,
"final_obs": observation,
**reset_info,
}
observation = reset_observation
elif autoreset_mode == AutoresetMode.DISABLED:
assert autoreset is False
(
observation,
reward,
terminated,
truncated,
info,
) = env.step(data)
else:
raise ValueError(f"Unexpected autoreset_mode: {autoreset_mode}")

if shared_memory:
write_to_shared_memory(
Expand Down
Loading

0 comments on commit 8a46c3a

Please sign in to comment.