Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/urdf to mjcf #309

Merged
merged 8 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ namespace_packages = false
module = [
"boto3.*",
"moto.*",
"mujoco.*",
"pybullet_utils.*",
]

ignore_missing_imports = true
Expand Down
2 changes: 1 addition & 1 deletion store/app/crud/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def _upload_mesh(

tmesh = trimesh.load(io.BytesIO(file_data), file_type=artifact_type)
if not isinstance(tmesh, trimesh.Trimesh):
raise BadArtifactError(f"Invalid mesh file: {name}")
raise BadArtifactError(f"Invalid mesh file: {name} ({type(tmesh)})")

out_file = io.BytesIO()
tmesh.export(out_file, file_type="obj")
Expand Down
2 changes: 1 addition & 1 deletion store/app/crud/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def _add_item(self, item: StoreBaseModel, unique_fields: list[str] | None
condition += " AND " + " AND ".join(f"attribute_not_exists({field})" for field in unique_fields)

# Log the item data before insertion for debugging purposes
logger.info(f"Inserting item into DynamoDB: {item_data}")
logger.info("Inserting item into DynamoDB: %s", item_data)

try:
await table.put_item(
Expand Down
57 changes: 42 additions & 15 deletions store/app/crud/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
from store.app.crud.artifacts import ArtifactsCrud
from store.app.errors import BadArtifactError
from store.app.model import Artifact, CompressedArtifactType, Listing
from store.app.utils.convert import urdf_to_mjcf
from store.utils import save_xml

logger = logging.getLogger(__name__)


URDF_PACKAGE_NAME = "droid.tgz"


def iter_components(file: IO[bytes], compression_type: CompressedArtifactType) -> Iterator[tuple[str, IO[bytes]]]:
def iter_components(file: IO[bytes], compression_type: CompressedArtifactType) -> Iterator[tuple[str, bytes]]:
"""Iterates over the components of a tar file.

Args:
Expand All @@ -42,20 +44,20 @@ def iter_components(file: IO[bytes], compression_type: CompressedArtifactType) -
with zipfile.ZipFile(file) as zipf:
for name in zipf.namelist():
# Fix for MacOS.
if name.startswith("__MACOS"):
if any(p.startswith("__MACOS") or p.startswith(".") for p in Path(name).parts):
continue
# Ignore folders.
if name.endswith("/"):
continue
with zipf.open(name) as zipdata:
yield name, zipdata
yield name, zipdata.read()

case "tgz":
with tarfile.open(fileobj=file, mode="r:gz") as tar:
for member in tar.getmembers():
if member.isfile() and (tardata := tar.extractfile(member)) is not None:
with tardata:
yield member.name, tardata
yield member.name, tardata.read()

case _:
raise ValueError(f"Unknown compression type: {compression_type}")
Expand All @@ -71,6 +73,7 @@ async def set_urdf(
) -> Artifact:
# Unpacks the TAR file, getting meshes and URDFs.
urdf: tuple[str, ET.ElementTree] | None = None
mjcf: tuple[str, ET.ElementTree] | None = None
meshes: list[tuple[str, trimesh.Trimesh]] = []

compressed_data = await file.read()
Expand All @@ -79,40 +82,53 @@ async def set_urdf(

if suffix == "urdf":
if urdf is not None:
raise BadArtifactError("Multiple URDF files found in TAR.")
raise BadArtifactError(f"Multiple URDF files found in TAR: {urdf[0]} and {name}")
try:
urdf_tree = ET.parse(io.BytesIO(data.read()))
urdf_tree = ET.parse(io.BytesIO(data))
except Exception:
raise BadArtifactError("Invalid XML file")
urdf = name, urdf_tree

elif suffix in ("stl", "ply", "obj", "dae"):
try:
tmesh = trimesh.load(data, file_type=suffix)
assert isinstance(tmesh, trimesh.Trimesh)
except Exception:
raise BadArtifactError(f"Invalid mesh file: {name}")
tmesh = trimesh.load(io.BytesIO(data), file_type=suffix)
except Exception as e:
raise BadArtifactError(f"Not a valid mesh: {name} ({e})")
if not isinstance(tmesh, trimesh.Trimesh):
raise BadArtifactError(f"Invalid mesh file: {name} ({type(tmesh)})")
meshes.append((name, tmesh))

else:
raise BadArtifactError(f"Unknown file type: {name}")

if urdf is None:
raise BadArtifactError("No URDF file found in TAR.")
urdf_name, urdf_tree = urdf
raise BadArtifactError("No URDF file found.")

# Attempts to generate an MJCF file from the URDF.
try:
urdf_name, urdf_tree = urdf
mjcf_name = Path(urdf_name).with_suffix(".xml").as_posix()
mjcf_tree = urdf_to_mjcf(urdf_tree, meshes)
mjcf = mjcf_name, mjcf_tree
logger.info("Converting URDF to MJCF: %s -> %s", urdf_name, mjcf_name)
except Exception:
logger.exception("Failed to convert URDF to MJCF")

# Checks that all of the mesh files are referenced.
mesh_names = {Path(name) for name, _ in meshes}
mesh_references = {Path(name): False for name, _ in meshes}
for mesh in urdf_tree.iter("mesh"):
if (filename := mesh.get("filename")) is None:
raise BadArtifactError("Mesh element missing filename attribute.")
filepath = Path(filename).relative_to(".")
if filepath not in mesh_names:
raise BadArtifactError(f"Mesh referenced in URDF was not uploaded: {filepath}")
mesh_names.remove(filepath)
mesh_references[filepath] = True
mesh.set("filename", str(filepath.with_suffix(".obj")))
if mesh_names:
raise BadArtifactError(f"Mesh files uploaded were not referenced: {mesh_names}")

unreferenced_meshes = [name for name, referenced in mesh_references.items() if not referenced]
if unreferenced_meshes:
raise BadArtifactError(f"Mesh files uploaded were not referenced: {unreferenced_meshes}")

# Saves everything to a new TAR file, using OBJ files for meshes.
tgz_out_file = io.BytesIO()
Expand All @@ -125,15 +141,26 @@ async def set_urdf(
info.size = out_file.tell()
out_file.seek(0)
tar.addfile(info, out_file)

urdf_out_file = io.BytesIO()
save_xml(urdf_out_file, urdf_tree)
urdf_out_file.seek(0)
info = tarfile.TarInfo(urdf_name)
info.size = len(urdf_out_file.getbuffer())
tar.addfile(info, urdf_out_file)

if mjcf is not None:
mjcf_name, mjcf_tree = mjcf
mjcf_out_file = io.BytesIO()
save_xml(mjcf_out_file, mjcf_tree)
mjcf_out_file.seek(0)
info = tarfile.TarInfo(mjcf_name)
info.size = len(mjcf_out_file.getbuffer())
tar.addfile(info, mjcf_out_file)

# Saves the TAR file to S3.
tgz_out_file.seek(0)

return await self._upload_and_store(URDF_PACKAGE_NAME, tgz_out_file, listing, "tgz", description)

async def get_urdf(self, listing_id: str) -> Artifact | None:
Expand Down
9 changes: 5 additions & 4 deletions store/app/routers/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,13 @@ async def set_urdf(
)

# Gets the compression type from the file content type and filename.
compression_type = get_compression_type(file.content_type, file.filename)
if compression_type not in ("tgz", "zip"):
try:
compression_type = get_compression_type(file.content_type, file.filename)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file must be a .tgz or .zip file",
)
detail="The file must be a .zip, .tgz, .tar.gz file",
) from e

# Checks that the listing is valid.
listing = await crud.get_listing(listing_id)
Expand Down
130 changes: 130 additions & 0 deletions store/app/utils/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""This module contains functions to convert between URDF and MJCF formats."""

import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path

import trimesh
from pybullet_utils import bullet_client, urdfEditor

from store.app.utils.formats import mjcf


def urdf_to_mjcf(urdf_tree: ET.ElementTree, meshes: list[tuple[str, trimesh.Trimesh]]) -> ET.ElementTree:
"""Convert a URDF ElementTree to an MJCF ElementTree.

This function converts a URDF file to an MJCF file. It is intended to be
used when a URDF file is provided and an MJCF file is needed. The function
converts the URDF file to an MJCF file and returns the MJCF file as an
ElementTree.

Args:
urdf_tree: The URDF ElementTree to convert.
meshes: A list of tuples containing the mesh file name and the mesh
object.

Returns:
The MJCF ElementTree.
"""
robot_name = urdf_tree.getroot().get("name")
if robot_name is None:
raise ValueError("URDF tree does not contain a robot name.")

with tempfile.TemporaryDirectory() as temp_dir:
temp_dir_path = Path(temp_dir)

try:
# Save the URDF tree to a file
urdf_path = temp_dir_path / f"{robot_name}.urdf"
urdf_tree.write(urdf_path, encoding="utf-8")

# Save the mesh files
for mesh_name, mesh in meshes:
mesh_path = temp_dir_path / mesh_name
mesh_path.parent.mkdir(parents=True, exist_ok=True)
mesh.export(mesh_path)

# Loading the URDF file and adapting it to the MJCF format
mjcf_robot = mjcf.Robot(robot_name, temp_dir, mjcf.Compiler(angle="radian", meshdir="meshes"))
mjcf_robot.adapt_world()

# Save the MJCF file with the base name
mjcf_path = urdf_path.parent / f"{robot_name}.xml"
mjcf_robot.save(mjcf_path)

# Read the MJCF file back into an ElementTree
mjcf_tree = ET.parse(mjcf_path)
except Exception:
raise

return mjcf_tree


def mjcf_to_urdf(mjcf_tree: ET.ElementTree, meshes: list[tuple[str, trimesh.Trimesh]]) -> ET.ElementTree:
"""Convert an MJCF ElementTree to a URDF ElementTree with all parts combined.

This function assumes that the MJCF file contains a single robot with
multiple parts. It combines all parts into a single URDF file.

Note that this function is not particularly good - for example, it can
lose information about the types of joints between parts. It is intended
as a quick and dirty way to convert MJCF files to URDF files for use in
other tools.

Args:
mjcf_tree: The MJCF ElementTree to convert.
meshes: A list of tuples containing the mesh file name and the mesh
object.

Returns:
The URDF ElementTree with all parts combined.
"""
with tempfile.TemporaryDirectory() as temp_dir:
try:
temp_dir_path = Path(temp_dir)

# Save the MJCF tree to a file
mjcf_path = temp_dir_path / "robot_mjcf.xml"
mjcf_tree.write(mjcf_path, encoding="utf-8")

# Save the mesh files in the temporary directory
for mesh_name, mesh in meshes:
mesh_path = temp_dir_path / mesh_name
mesh_path.parent.mkdir(parents=True, exist_ok=True)
mesh.export(mesh_path)

# Set up the Bullet client and load the MJCF file
client = bullet_client.BulletClient()
objs = client.loadMJCF(str(mjcf_path), flags=client.URDF_USE_IMPLICIT_CYLINDER)

# Initialize a single URDF editor to store all parts
combined_urdf_editor = urdfEditor.UrdfEditor()

for obj in objs:
humanoid = obj # Get the current object
part_urdf_editor = urdfEditor.UrdfEditor()
part_urdf_editor.initializeFromBulletBody(humanoid, client._client)

# Add all links from the part URDF editor to the combined editor
for link in part_urdf_editor.urdfLinks:
if link not in combined_urdf_editor.urdfLinks:
combined_urdf_editor.urdfLinks.append(link)

# Add all joints from the part URDF editor to the combined editor
for joint in part_urdf_editor.urdfJoints:
if joint not in combined_urdf_editor.urdfJoints:
combined_urdf_editor.urdfJoints.append(joint)

# Set the output path for the combined URDF file
combined_urdf_path = temp_dir_path / "combined_robot.urdf"

# Save the combined URDF
combined_urdf_editor.saveUrdf(combined_urdf_path)

# Read the combined URDF file back into an ElementTree
urdf_tree = ET.parse(combined_urdf_path)

except Exception:
raise

return urdf_tree
Empty file.
Loading
Loading