Skip to content

Commit

Permalink
Add doc string
Browse files Browse the repository at this point in the history
  • Loading branch information
xkiixkii committed Aug 16, 2024
1 parent 50b1a0d commit 13e97be
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 26 deletions.
55 changes: 55 additions & 0 deletions deformable_gym/envs/mujoco/asset_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@


class AssetManager:
"""
The AssetManager class manages the loading and combining of MuJoCo models (robots and objects)
into a single simulation scene. It provides methods to load individual robot or object models,
create complex scenes, and save the resulting XML configuration.
"""

def __init__(self) -> None:
self.assets_dir = os.path.join(os.path.dirname(__file__), "assets")
Expand All @@ -22,6 +27,19 @@ def __init__(self) -> None:
self.objects = OBJECTS

def load_asset(self, name: str) -> mujoco.MjModel:
"""
Loads a asset (robot or object mujoco Model) based on its name.
Args:
name (str): The name of the robot or object to be loaded. Must be a key in either
self.robots or self.objects.
Returns:
mujoco.MjModel: The loaded MuJoCo model.
Raises:
AssertionError: If the specified name is not found in either self.robots or self.objects.
"""
assert (
name in self.robots or name in self.objects
), f"Model {name} not found.\n available: {list(self.robots.keys()) + list(self.objects.keys())}"
Expand All @@ -31,6 +49,20 @@ def load_asset(self, name: str) -> mujoco.MjModel:
return model

def create_scene(self, robot_name: str, obj_name: str) -> str:
"""
Creates an MJCF string representing a MuJoCo scene that includes a robot and an object.
Args:
robot_name (str): The name of the robot to include in the scene. Must be a key in self.robots.
obj_name (str): The name of the object to include in the scene. Must be a key in self.objects.
Returns:
str: A MJCF string representing the combined MuJoCo scene.
Raises:
AssertionError: If the specified robot_name is not found in self.robots.
AssertionError: If the specified obj_name is not found in self.objects.
"""
assert (
robot_name in self.robots
), f"Robot {robot_name} not found.\n available: {list(self.robots.keys())}"
Expand Down Expand Up @@ -63,6 +95,15 @@ def create_scene(self, robot_name: str, obj_name: str) -> str:
# return scene

def _get_full_path(self, file: str) -> str:
"""
Generates the full path to a file within the assets directory.
Args:
file (str): The file name (with extension) for which to generate the full path.
Returns:
str: The full path to the specified file within the assets directory.
"""
return os.path.join(self.assets_dir, file)

@staticmethod
Expand All @@ -72,6 +113,20 @@ def include_mjcf(
*,
meshdir: Optional[str] = None,
) -> str:
"""
Generates an XML string for a MuJoCo scene by including additional MJCF files within a base MJCF file.
Args:
base_path (str): The file path to the base MJCF file.
include_path (Union[str, Sequence[str]]): A string or list of strings representing file paths
to MJCF files to be included in the base file.
meshdir (Optional[str]): A string representing the path to the directory containing mesh files.
If provided, this path is added to the meshdir attribute of the compiler
element in the MJCF XML.
Returns:
str: An XML string representing the combined MJCF file.
"""
tree = ET.parse(base_path)
root = tree.getroot()
if isinstance(include_path, list) or isinstance(include_path, tuple):
Expand Down
45 changes: 42 additions & 3 deletions deformable_gym/envs/mujoco/base_mjenv.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,41 @@ def __init__(
self.action_space = self._get_action_space()

def _create_scene(self, robot_name: str, obj_name: str) -> str:
manager = AssetManager()
return manager.create_scene(robot_name, obj_name)
"""
Creates the simulation scene by combining the robot and object models.
Args:
robot_name (str): The name of the robot to include in the scene.
obj_name (str): The name of the object to include in the scene.
Returns:
str: The MJCF XML string representing the combined simulation scene.
"""

return AssetManager().create_scene(robot_name, obj_name)

def _get_action_space(self) -> spaces.Box:
"""
Defines the action space for the environment based on the robot's control range.
Returns:
spaces.Box: A continuous space representing the possible actions the agent can take.
"""

n_act = self.robot.nact
low = self.robot.ctrl_range[:, 0].copy()
high = self.robot.ctrl_range[:, 1].copy()
return spaces.Box(low=low, high=high, shape=(n_act,), dtype=np.float64)

def _get_observation_space(self) -> spaces.Box:
"""
Defines the observation space for the environment, including the robot's joint positions
and, optionally, the object's position.
Returns:
spaces.Box: A continuous space representing the state observations available to the agent.
"""

nq = self.robot.nq
low = -np.inf # TODO: joint space range
high = np.inf
Expand All @@ -65,8 +90,16 @@ def reset(
options: Optional[dict] = None,
) -> tuple[NDArray[np.float64], dict[str, Any]]:
"""
Reset the environment to the initial state.
Resets the environment to its initial state.
Args:
seed (Optional[int], optional): A random seed for resetting the environment. Default is None.
options (Optional[dict], optional): Additional options for resetting the environment. Default is None.
Returns:
tuple: A tuple containing the initial observation and an empty info dictionary.
"""

super().reset(seed=seed, options=options)
self.model, _ = mju.load_model_from_string(self.scene)
mujoco.mj_resetData(self.model, self.data)
Expand All @@ -93,6 +126,12 @@ def _set_state(
mujoco.mj_forward(self.model, self.data)

def _load_keyframe(self, frame_name: str):
"""
Loads a predefined keyframe and sets the environment's state to it.
Args:
frame_name (str): The name of the keyframe to load.
"""
frame = self.model.keyframe(frame_name)
qpos = frame.qpos.copy()
qvel = frame.qvel.copy()
Expand Down
71 changes: 57 additions & 14 deletions deformable_gym/envs/mujoco/grasp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@


class GraspEnv(BaseMJEnv):
"""
A custom MuJoCo environment for a grasping task, where a robot attempts to grasp an object.
"""

def __init__(
self,
Expand All @@ -35,14 +38,18 @@ def __init__(

def reset(self, *, seed=None, options=None) -> Tuple[NDArray, Dict]:
super().reset(seed=seed, options=options)
self.robot.set_pose(
self.model, self.data, self.robot.init_pose[self.object.name]
)
observation = self._get_observation()
info = self._get_info()
return observation, info

def _get_observation(self) -> NDArray:
"""
The observation includes the robot's joint qpos in their generalized coordinates,
and optionally, the object's position.
Returns:
NDArray: A numpy array containing the observation.
"""
robot_qpos = self.robot.get_qpos(self.model, self.data)
if self.observable_object_pos:
obj_pos = self.object.get_current_com(self.data)
Expand All @@ -52,13 +59,19 @@ def _get_observation(self) -> NDArray:
return obs

def _pause_simulation(self, time: float) -> None:
"""Step Mujoco Simulation for a given time.
"""
Steps the MuJoCo simulation for a specified amount of time.
certain constraints of object will be disabled and all joints will be freezed
before stepping the simulation to allow the robot to grasp the object without
external influences.
Args:
time (float): simulation time in seconds
time (float): The duration in seconds for which to step the simulation.
"""
# mju.remove_geom(self.model, self.data, "platform") # might be off here...
self.object.disable_eq_constraint(self.model, self.data)
mju.disable_equality_constraint(
self.model, self.data, *self.object.eq_constraints_to_disable
)
mju.disable_joint(self.model, self.data, *self.robot.joints)
start_time = self.data.time
while self.data.time - start_time < time:
Expand All @@ -67,17 +80,22 @@ def _pause_simulation(self, time: float) -> None:
self.render()

def _get_reward(self, terminated: bool) -> int:
"""Calculate reward by removing the platform and check if object falls to the ground.
0 reward: max_sim_time is not reached yet
-1 reward: object falls to the ground after removing the platform
1 reward: object is grasped successfully by the robot hand
"""
Calculates the reward based on the robot's success in grasping the object.
The reward is calculated after removing all fixed constraints and checking
if the object remains grasped by the robot.
Args:
terminated (bool): if episode is terminated
terminated (bool): Whether the episode has terminated.
Returns:
int: reward gotten per step
int: The reward for the current step. Possible values are:
- 0: The episode has not yet terminated.
- 1: The object is successfully grasped by the robot.
- -1: The object falls to the ground.
"""

if not terminated:
return 0
self._pause_simulation(1)
Expand All @@ -88,17 +106,42 @@ def _get_reward(self, terminated: bool) -> int:
return -1

def _is_terminated(self, sim_time: float) -> bool:
"""
Determines whether the episode has terminated based on the simulation time.
Args:
sim_time (float): The current simulation time.
Returns:
bool: True if the simulation time has exceeded the maximum allowed time, otherwise False.
"""
return sim_time >= self.max_sim_time

def _is_truncated(self) -> bool:
return False

def _get_info(self) -> Dict:
if self.viewer is not None:
"""
If the GUI viewer is running, this method will return a dictionary indicating that.
Returns:
Dict: A dictionary containing information about the environment.
"""
if self.gui and self.viewer is not None:
return {"is_viewer_running": self.viewer.is_running()}
return {}

def step(self, ctrl: ArrayLike) -> Tuple[NDArray, int, bool, bool, Dict]:
"""
Advances the simulation applying the given control input to the robot
Args:
ctrl (ArrayLike): The control input to be applied to the robot.
Returns:
Tuple[NDArray, int, bool, bool, Dict]: observation, reward, termination flag,
truncation flag, and an info.
"""
sim_time = self.data.time
self.robot.set_ctrl(self.model, self.data, ctrl)
mujoco.mj_step(self.model, self.data)
Expand Down
55 changes: 46 additions & 9 deletions deformable_gym/objects/mj_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,20 @@


class MJObject:
"""
A class representing a physical object in a MuJoCo simulation.
It provides methods to set the object's pose, get the center of mass, and manage equality constraints.
Attributes:
name (str): The name of the object, corresponding to its definition in the XML model.
model (mujoco.MjModel): The MuJoCo model object representing the object in the simulation.
eq_constraints (List[str]): A list of equality constraints associated with this object.
These constraints are typically used to enforce specific relationships between bodies,
such as keeping them at a fixed distance or maintaining an orientation.
eq_constraints_to_disable (List[Union[str, None]]): A list of equality constraints that are marked to be disabled for this object
in a specific occasion such as making the object free to move in the simulation for floating objects.
This list can be customized by subclasses to define which constraints should be ignored or temporarily disabled during the simulation.
"""

def __init__(self, name: str) -> None:
self.name = name
Expand All @@ -23,28 +37,51 @@ def eq_constraints_to_disable(self) -> List[Union[str, None]]:
return []

def _load_model(self, name: str) -> mujoco.MjModel:
manager = AssetManager()
return manager.load_asset(name)
"""
Load the MuJoCo model for the object.
This method uses the `AssetManager` to load the XML model file corresponding to the object.
Args:
name (str): The name of the object to load.
Returns:
mujoco.MjModel: The loaded MuJoCo model for the object.
"""
return AssetManager().load_asset(name)

def set_pose(
self,
model: mujoco.MjModel,
data: mujoco.MjData,
pose: Pose,
) -> None:
"""
Set the pose (position and orientation) of the object in the simulation.
This method updates the position and orientation of the object's body in the simulation
and recalculates the simulation state.
Args:
model (mujoco.MjModel): The MuJoCo model object containing the object's configuration.
data (mujoco.MjData): The MuJoCo data object containing the current simulation state.
pose (Pose): A Pose object containing the desired position and orientation for the object.
"""
model.body(self.name).pos[:] = pose.position
model.body(self.name).quat[:] = pose.orientation
mujoco.mj_forward(model, data)

def get_current_com(self, data: mujoco.MjData) -> NDArray:
return data.body(self.name).ipos
"""
Get the current center of mass (COM) position of the object.
def disable_eq_constraint(
self, model: mujoco.MjModel, data: mujoco.MjData, *name: str
) -> None:
if len(name) == 0:
name = self.eq_constraints_to_disable
mju.disable_equality_constraint(model, data, *name)
Args:
data (mujoco.MjData): The MuJoCo data object containing the current simulation state.
Returns:
NDArray: An array representing the position of the object's center of mass.
"""
return data.body(self.name).ipos


class InsoleFixed(MJObject):
Expand Down
Loading

0 comments on commit 13e97be

Please sign in to comment.