diff --git a/deformable_gym/envs/mujoco/base_mjenv.py b/deformable_gym/envs/mujoco/base_mjenv.py index 02ffdcc..469922f 100644 --- a/deformable_gym/envs/mujoco/base_mjenv.py +++ b/deformable_gym/envs/mujoco/base_mjenv.py @@ -60,6 +60,8 @@ class BaseMJEnv(gym.Env, ABC): The space representing possible actions that can be taken by the agent. """ + metadata = {"render_modes": ["human", "rgb_array", "depth_array"]} + def __init__( self, robot_name: str, @@ -89,16 +91,38 @@ def __init__( self.mocap = MocapControl() self.init_frame = init_frame self.max_sim_time = max_sim_time + if render_mode is not None: + assert render_mode in self.metadata["render_modes"] self.render_mode = render_mode self.camera_name = camera_name self.camera_id = camera_id - self.renderer = MujocoRenderer( - self.model, self.data, default_cam_config - ) + self.renderer = self._get_renderer(default_cam_config) self.observation_space = self._get_observation_space() self.action_space = self._get_action_space() + def _get_renderer(self, default_cam_config: Dict[str, Any] | None): + """ + Returns the renderer object for the environment. + + Returns: + MujocoRenderer or mujoco.viewer: The renderer object for the environment. + """ + if self.render_mode == "human": + renderer = mujoco.viewer.launch_passive( + self.model, self.data, show_left_ui=False, show_right_ui=False + ) + if default_cam_config is not None: + for attr, value in default_cam_config.items(): + setattr(renderer.cam, attr, value) + elif ( + self.render_mode == "rgb_array" or self.render_mode == "depth_array" + ): + renderer = MujocoRenderer(self.model, self.data, default_cam_config) + else: + renderer = None + return renderer + def _get_action_space(self) -> spaces.Box: """ Defines the action space for the environment based on the robot's control range. @@ -220,13 +244,16 @@ def render(self) -> None: """ Render a frame from the MuJoCo simulation as specified by the render_mode. """ - - return self.renderer.render( - self.render_mode, self.camera_id, self.camera_name - ) + if self.render_mode == "human": + self.renderer.sync() + elif ( + self.render_mode == "rgb_array" or self.render_mode == "depth_array" + ): + return self.renderer.render(self.render_mode) def close(self) -> None: """ Close the environment. """ - self.renderer.close() + if self.renderer is not None: + self.renderer.close()