Skip to content

Commit

Permalink
Feature/urdf to mjcf (#309)
Browse files Browse the repository at this point in the history
* add .tar.gz

* urdf --> mjcf in same zip. mjcf --> urdf not functional

* Update store/app/crud/urdf.py

* cleanup

* small changes / eip

* fix tests

* another fix

* downgrade numpy

---------

Co-authored-by: Ben Bolte <[email protected]>
  • Loading branch information
nathanjzhao and codekansas authored Aug 22, 2024
1 parent 5ccb0c5 commit 78a3036
Show file tree
Hide file tree
Showing 10 changed files with 759 additions and 255 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace_packages = false
module = [
"boto3.*",
"moto.*",
"mujoco.*",
"pybullet_utils.*",
]

ignore_missing_imports = true
Expand Down
2 changes: 1 addition & 1 deletion store/app/crud/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def _upload_mesh(

tmesh = trimesh.load(io.BytesIO(file_data), file_type=artifact_type)
if not isinstance(tmesh, trimesh.Trimesh):
raise BadArtifactError(f"Invalid mesh file: {name}")
raise BadArtifactError(f"Invalid mesh file: {name} ({type(tmesh)})")

out_file = io.BytesIO()
tmesh.export(out_file, file_type="obj")
Expand Down
2 changes: 1 addition & 1 deletion store/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def _add_item(self, item: StoreBaseModel, unique_fields: list[str] | None
condition += " AND " + " AND ".join(f"attribute_not_exists({field})" for field in unique_fields)

# Log the item data before insertion for debugging purposes
logger.info(f"Inserting item into DynamoDB: {item_data}")
logger.info("Inserting item into DynamoDB: %s", item_data)

try:
await table.put_item(
Expand Down
57 changes: 42 additions & 15 deletions store/app/crud/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
from store.app.crud.artifacts import ArtifactsCrud
from store.app.errors import BadArtifactError
from store.app.model import Artifact, CompressedArtifactType, Listing
from store.app.utils.convert import urdf_to_mjcf
from store.utils import save_xml

logger = logging.getLogger(__name__)


URDF_PACKAGE_NAME = "droid.tgz"


def iter_components(file: IO[bytes], compression_type: CompressedArtifactType) -> Iterator[tuple[str, IO[bytes]]]:
def iter_components(file: IO[bytes], compression_type: CompressedArtifactType) -> Iterator[tuple[str, bytes]]:
"""Iterates over the components of a tar file.
Args:
Expand All @@ -42,20 +44,20 @@ def iter_components(file: IO[bytes], compression_type: CompressedArtifactType) -
with zipfile.ZipFile(file) as zipf:
for name in zipf.namelist():
# Fix for MacOS.
if name.startswith("__MACOS"):
if any(p.startswith("__MACOS") or p.startswith(".") for p in Path(name).parts):
continue
# Ignore folders.
if name.endswith("/"):
continue
with zipf.open(name) as zipdata:
yield name, zipdata
yield name, zipdata.read()

case "tgz":
with tarfile.open(fileobj=file, mode="r:gz") as tar:
for member in tar.getmembers():
if member.isfile() and (tardata := tar.extractfile(member)) is not None:
with tardata:
yield member.name, tardata
yield member.name, tardata.read()

case _:
raise ValueError(f"Unknown compression type: {compression_type}")
Expand All @@ -71,6 +73,7 @@ async def set_urdf(
) -> Artifact:
# Unpacks the TAR file, getting meshes and URDFs.
urdf: tuple[str, ET.ElementTree] | None = None
mjcf: tuple[str, ET.ElementTree] | None = None
meshes: list[tuple[str, trimesh.Trimesh]] = []

compressed_data = await file.read()
Expand All @@ -79,40 +82,53 @@ async def set_urdf(

if suffix == "urdf":
if urdf is not None:
raise BadArtifactError("Multiple URDF files found in TAR.")
raise BadArtifactError(f"Multiple URDF files found in TAR: {urdf[0]} and {name}")
try:
urdf_tree = ET.parse(io.BytesIO(data.read()))
urdf_tree = ET.parse(io.BytesIO(data))
except Exception:
raise BadArtifactError("Invalid XML file")
urdf = name, urdf_tree

elif suffix in ("stl", "ply", "obj", "dae"):
try:
tmesh = trimesh.load(data, file_type=suffix)
assert isinstance(tmesh, trimesh.Trimesh)
except Exception:
raise BadArtifactError(f"Invalid mesh file: {name}")
tmesh = trimesh.load(io.BytesIO(data), file_type=suffix)
except Exception as e:
raise BadArtifactError(f"Not a valid mesh: {name} ({e})")
if not isinstance(tmesh, trimesh.Trimesh):
raise BadArtifactError(f"Invalid mesh file: {name} ({type(tmesh)})")
meshes.append((name, tmesh))

else:
raise BadArtifactError(f"Unknown file type: {name}")

if urdf is None:
raise BadArtifactError("No URDF file found in TAR.")
urdf_name, urdf_tree = urdf
raise BadArtifactError("No URDF file found.")

# Attempts to generate an MJCF file from the URDF.
try:
urdf_name, urdf_tree = urdf
mjcf_name = Path(urdf_name).with_suffix(".xml").as_posix()
mjcf_tree = urdf_to_mjcf(urdf_tree, meshes)
mjcf = mjcf_name, mjcf_tree
logger.info("Converting URDF to MJCF: %s -> %s", urdf_name, mjcf_name)
except Exception:
logger.exception("Failed to convert URDF to MJCF")

# Checks that all of the mesh files are referenced.
mesh_names = {Path(name) for name, _ in meshes}
mesh_references = {Path(name): False for name, _ in meshes}
for mesh in urdf_tree.iter("mesh"):
if (filename := mesh.get("filename")) is None:
raise BadArtifactError("Mesh element missing filename attribute.")
filepath = Path(filename).relative_to(".")
if filepath not in mesh_names:
raise BadArtifactError(f"Mesh referenced in URDF was not uploaded: {filepath}")
mesh_names.remove(filepath)
mesh_references[filepath] = True
mesh.set("filename", str(filepath.with_suffix(".obj")))
if mesh_names:
raise BadArtifactError(f"Mesh files uploaded were not referenced: {mesh_names}")

unreferenced_meshes = [name for name, referenced in mesh_references.items() if not referenced]
if unreferenced_meshes:
raise BadArtifactError(f"Mesh files uploaded were not referenced: {unreferenced_meshes}")

# Saves everything to a new TAR file, using OBJ files for meshes.
tgz_out_file = io.BytesIO()
Expand All @@ -125,15 +141,26 @@ async def set_urdf(
info.size = out_file.tell()
out_file.seek(0)
tar.addfile(info, out_file)

urdf_out_file = io.BytesIO()
save_xml(urdf_out_file, urdf_tree)
urdf_out_file.seek(0)
info = tarfile.TarInfo(urdf_name)
info.size = len(urdf_out_file.getbuffer())
tar.addfile(info, urdf_out_file)

if mjcf is not None:
mjcf_name, mjcf_tree = mjcf
mjcf_out_file = io.BytesIO()
save_xml(mjcf_out_file, mjcf_tree)
mjcf_out_file.seek(0)
info = tarfile.TarInfo(mjcf_name)
info.size = len(mjcf_out_file.getbuffer())
tar.addfile(info, mjcf_out_file)

# Saves the TAR file to S3.
tgz_out_file.seek(0)

return await self._upload_and_store(URDF_PACKAGE_NAME, tgz_out_file, listing, "tgz", description)

async def get_urdf(self, listing_id: str) -> Artifact | None:
Expand Down
9 changes: 5 additions & 4 deletions store/app/routers/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ async def set_urdf(
)

# Gets the compression type from the file content type and filename.
compression_type = get_compression_type(file.content_type, file.filename)
if compression_type not in ("tgz", "zip"):
try:
compression_type = get_compression_type(file.content_type, file.filename)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file must be a .tgz or .zip file",
)
detail="The file must be a .zip, .tgz, .tar.gz file",
) from e

# Checks that the listing is valid.
listing = await crud.get_listing(listing_id)
Expand Down
130 changes: 130 additions & 0 deletions store/app/utils/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""This module contains functions to convert between URDF and MJCF formats."""

import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path

import trimesh
from pybullet_utils import bullet_client, urdfEditor

from store.app.utils.formats import mjcf


def urdf_to_mjcf(urdf_tree: ET.ElementTree, meshes: list[tuple[str, trimesh.Trimesh]]) -> ET.ElementTree:
"""Convert a URDF ElementTree to an MJCF ElementTree.
This function converts a URDF file to an MJCF file. It is intended to be
used when a URDF file is provided and an MJCF file is needed. The function
converts the URDF file to an MJCF file and returns the MJCF file as an
ElementTree.
Args:
urdf_tree: The URDF ElementTree to convert.
meshes: A list of tuples containing the mesh file name and the mesh
object.
Returns:
The MJCF ElementTree.
"""
robot_name = urdf_tree.getroot().get("name")
if robot_name is None:
raise ValueError("URDF tree does not contain a robot name.")

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)

try:
# Save the URDF tree to a file
urdf_path = temp_dir_path / f"{robot_name}.urdf"
urdf_tree.write(urdf_path, encoding="utf-8")

# Save the mesh files
for mesh_name, mesh in meshes:
mesh_path = temp_dir_path / mesh_name
mesh_path.parent.mkdir(parents=True, exist_ok=True)
mesh.export(mesh_path)

# Loading the URDF file and adapting it to the MJCF format
mjcf_robot = mjcf.Robot(robot_name, temp_dir, mjcf.Compiler(angle="radian", meshdir="meshes"))
mjcf_robot.adapt_world()

# Save the MJCF file with the base name
mjcf_path = urdf_path.parent / f"{robot_name}.xml"
mjcf_robot.save(mjcf_path)

# Read the MJCF file back into an ElementTree
mjcf_tree = ET.parse(mjcf_path)
except Exception:
raise

return mjcf_tree


def mjcf_to_urdf(mjcf_tree: ET.ElementTree, meshes: list[tuple[str, trimesh.Trimesh]]) -> ET.ElementTree:
"""Convert an MJCF ElementTree to a URDF ElementTree with all parts combined.
This function assumes that the MJCF file contains a single robot with
multiple parts. It combines all parts into a single URDF file.
Note that this function is not particularly good - for example, it can
lose information about the types of joints between parts. It is intended
as a quick and dirty way to convert MJCF files to URDF files for use in
other tools.
Args:
mjcf_tree: The MJCF ElementTree to convert.
meshes: A list of tuples containing the mesh file name and the mesh
object.
Returns:
The URDF ElementTree with all parts combined.
"""
with tempfile.TemporaryDirectory() as temp_dir:
try:
temp_dir_path = Path(temp_dir)

# Save the MJCF tree to a file
mjcf_path = temp_dir_path / "robot_mjcf.xml"
mjcf_tree.write(mjcf_path, encoding="utf-8")

# Save the mesh files in the temporary directory
for mesh_name, mesh in meshes:
mesh_path = temp_dir_path / mesh_name
mesh_path.parent.mkdir(parents=True, exist_ok=True)
mesh.export(mesh_path)

# Set up the Bullet client and load the MJCF file
client = bullet_client.BulletClient()
objs = client.loadMJCF(str(mjcf_path), flags=client.URDF_USE_IMPLICIT_CYLINDER)

# Initialize a single URDF editor to store all parts
combined_urdf_editor = urdfEditor.UrdfEditor()

for obj in objs:
humanoid = obj # Get the current object
part_urdf_editor = urdfEditor.UrdfEditor()
part_urdf_editor.initializeFromBulletBody(humanoid, client._client)

# Add all links from the part URDF editor to the combined editor
for link in part_urdf_editor.urdfLinks:
if link not in combined_urdf_editor.urdfLinks:
combined_urdf_editor.urdfLinks.append(link)

# Add all joints from the part URDF editor to the combined editor
for joint in part_urdf_editor.urdfJoints:
if joint not in combined_urdf_editor.urdfJoints:
combined_urdf_editor.urdfJoints.append(joint)

# Set the output path for the combined URDF file
combined_urdf_path = temp_dir_path / "combined_robot.urdf"

# Save the combined URDF
combined_urdf_editor.saveUrdf(combined_urdf_path)

# Read the combined URDF file back into an ElementTree
urdf_tree = ET.parse(combined_urdf_path)

except Exception:
raise

return urdf_tree
Empty file.
Loading

0 comments on commit 78a3036

Please sign in to comment.