diff --git a/deformable_gym/helpers/mj_utils.py b/deformable_gym/helpers/mj_utils.py index 9546a07..f627339 100644 --- a/deformable_gym/helpers/mj_utils.py +++ b/deformable_gym/helpers/mj_utils.py @@ -1,9 +1,9 @@ -from dataclasses import dataclass, field -from typing import List, Tuple +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple import mujoco import numpy as np -from mujoco import MjData, MjModel, mjtJoint +from mujoco import MjData, MjModel, MjSpec, mjtJoint from numpy.typing import ArrayLike, NDArray # fmt: off @@ -278,6 +278,38 @@ def rotate_quat_by_euler(quat: ArrayLike, euler: ArrayLike) -> NDArray: return result +# -------------------------------- SPEC UTILS --------------------------------# +def make_spec_from_file(path: str) -> MjSpec: + spec = mujoco.MjSpec() + spec.from_file(path) + return spec + + +def make_spec_from_string(xml: str) -> MjSpec: + spec = mujoco.MjSpec() + spec.from_string(xml) + return spec + + +def spec2model(spec: MjSpec) -> MjModel: + return spec.compile() + + +def add_geom2body( + spec: MjSpec, body_name: str, geom_attr: Dict[str, Any] +) -> None: + body = spec.find_body(body_name) + geom = body.add_geom() + for attr_name, value in geom_attr.items(): + setattr(geom, attr_name, value) + + +def add_flex(spec: MjSpec, flex_attr: Dict[str, Any]) -> None: + flex = spec.add_flex() + for attr_name, value in flex_attr.items(): + setattr(flex, attr_name, value) + + # -------------------------------- OTHER UTILS --------------------------------# def id2name(model: MjModel, id: int, t: str = "body") -> str: assert (