diff --git a/store/app/crud/artifacts.py b/store/app/crud/artifacts.py index 66721e01..6cbe63df 100644 --- a/store/app/crud/artifacts.py +++ b/store/app/crud/artifacts.py @@ -4,11 +4,13 @@ import io import logging from typing import Any, BinaryIO, Literal +from xml.etree import ElementTree as ET from PIL import Image +from stl import Mode as STLMode, mesh as stlmesh from store.app.crud.base import BaseCrud, ItemNotFoundError -from store.app.errors import NotAuthorizedError +from store.app.errors import BadArtifactError, NotAuthorizedError from store.app.model import ( Artifact, ArtifactSize, @@ -19,6 +21,7 @@ get_content_type, ) from store.settings import settings +from store.utils import save_xml logger = logging.getLogger(__name__) @@ -28,9 +31,7 @@ class ArtifactsCrud(BaseCrud): def get_gsis(cls) -> set[str]: return super().get_gsis().union({"user_id", "listing_id"}) - def _crop_image(self, image: Image.Image, size: tuple[int, int]) -> io.BytesIO: - image_bytes = io.BytesIO() - + async def _crop_image(self, image: Image.Image, size: tuple[int, int]) -> io.BytesIO: # Simply squashes the image to the desired size. # image_resized = image.resize(size, resample=Image.Resampling.BICUBIC) # Finds a bounding box of the image and crops it to the desired size. @@ -47,18 +48,19 @@ def _crop_image(self, image: Image.Image, size: tuple[int, int]) -> io.BytesIO: upper = (image.height - new_height) // 2 right = left + new_width lower = upper + new_height - image_resized = image.crop((left, upper, right, lower)) + image = image.crop((left, upper, right, lower)) # Resize the image to the desired size. - image_resized = image_resized.resize(size, resample=Image.Resampling.BICUBIC) + image = image.resize(size, resample=Image.Resampling.BICUBIC) # Save the image to a byte stream. - image_resized.save(image_bytes, format="PNG", optimize=True, quality=settings.image.quality) + image_bytes = io.BytesIO() + image.save(image_bytes, format="PNG", optimize=True, quality=settings.artifact.quality) image_bytes.seek(0) return image_bytes async def _upload_cropped_image(self, image: Image.Image, name: str, image_id: str, size: ArtifactSize) -> None: - image_bytes = self._crop_image(image, SizeMapping[size]) + image_bytes = await self._crop_image(image, SizeMapping[size]) filename = get_artifact_name(image_id, "image", size) await self._upload_to_s3(image_bytes, name, filename, "image/png") @@ -100,7 +102,39 @@ async def _upload_image( async def get_raw_artifact(self, artifact_id: str) -> Artifact | None: return await self._get_item(artifact_id, Artifact) - async def _upload_raw_artifact( + async def _upload_stl( + self, + name: str, + file: io.BytesIO | BinaryIO, + listing: Listing, + user_id: str, + description: str | None = None, + ) -> Artifact: + if listing.user_id != user_id: + raise NotAuthorizedError("User does not have permission to upload artifacts to this listing") + + # Converts the mesh to a binary STL file. + mesh = stlmesh.Mesh.from_file(None, calculate_normals=True, fh=file) + out_file = io.BytesIO() + mesh.save(name, fh=out_file, mode=STLMode.BINARY) + out_file.seek(0) + + # Saves the artifact to S3. + content_type = get_content_type("stl") + artifact = Artifact.create( + user_id=user_id, + listing_id=listing.id, + name=name, + artifact_type="stl", + description=description, + ) + await asyncio.gather( + self._upload_to_s3(out_file, name, get_artifact_name(artifact.id, "stl"), content_type), + self._add_item(artifact), + ) + return artifact + + async def _upload_xml( self, name: str, file: io.BytesIO | BinaryIO, @@ -111,6 +145,21 @@ async def _upload_raw_artifact( ) -> Artifact: if listing.user_id != user_id: raise NotAuthorizedError("User does not have permission to upload artifacts to this listing") + + # Standardizes the XML file. + try: + tree = ET.parse(file) + except Exception: + raise BadArtifactError("Invalid XML file") + + # TODO: Remap the STL or OBJ file paths. + + # Converts to byte stream. + out_file = io.BytesIO() + save_xml(out_file, tree) + out_file.seek(0) + + # Saves the artifact to S3. content_type = get_content_type(artifact_type) artifact = Artifact.create( user_id=user_id, @@ -137,8 +186,12 @@ async def upload_artifact( match artifact_type: case "image": return await self._upload_image(name, file, listing, user_id, description) + case "stl": + return await self._upload_stl(name, file, listing, user_id, description) + case "urdf" | "mjcf": + return await self._upload_xml(name, file, listing, user_id, artifact_type, description) case _: - return await self._upload_raw_artifact(name, file, listing, user_id, artifact_type, description) + raise BadArtifactError(f"Invalid artifact type: {artifact_type}") async def _remove_image(self, artifact: Artifact, user_id: str) -> None: if artifact.user_id != user_id: @@ -151,7 +204,7 @@ async def _remove_image(self, artifact: Artifact, user_id: str) -> None: async def _remove_raw_artifact( self, artifact: Artifact, - artifact_type: Literal["urdf", "mjcf"], + artifact_type: Literal["urdf", "mjcf", "stl"], user_id: str, ) -> None: if artifact.user_id != user_id: diff --git a/store/app/errors.py b/store/app/errors.py index 2195f9cd..5d4a7d6d 100644 --- a/store/app/errors.py +++ b/store/app/errors.py @@ -11,3 +11,6 @@ class ItemNotFoundError(ValueError): ... class InternalError(RuntimeError): ... + + +class BadArtifactError(Exception): ... diff --git a/store/app/model.py b/store/app/model.py index a405b358..3e8a4df9 100644 --- a/store/app/model.py +++ b/store/app/model.py @@ -88,23 +88,31 @@ def create( ArtifactSize = Literal["small", "large"] -ArtifactType = Literal["image", "urdf", "mjcf"] +ArtifactType = Literal["image", "urdf", "mjcf", "stl"] UPLOAD_CONTENT_TYPE_OPTIONS: dict[ArtifactType, set[str]] = { + # Image "image": {"image/png", "image/jpeg", "image/jpg"}, - "urdf": {"application/gzip", "application/x-gzip"}, - "mjcf": {"application/gzip", "application/x-gzip"}, + # XML + "urdf": {"application/xml"}, + "mjcf": {"application/xml"}, + # Binary or text + "stl": {"application/octet-stream", "text/xml"}, } DOWNLOAD_CONTENT_TYPE: dict[ArtifactType, str] = { + # Image "image": "image/png", - "urdf": "application/gzip", - "mjcf": "application/gzip", + # XML + "urdf": "application/xml", + "mjcf": "application/xml", + # Binary + "stl": "application/octet-stream", } SizeMapping: dict[ArtifactSize, tuple[int, int]] = { - "large": settings.image.large_image_size, - "small": settings.image.small_image_size, + "large": settings.artifact.large_image_size, + "small": settings.artifact.small_image_size, } @@ -116,9 +124,11 @@ def get_artifact_name(id: str, artifact_type: ArtifactType, size: ArtifactSize = height, width = SizeMapping[size] return f"{id}_{size}_{height}x{width}.png" case "urdf": - return f"{id}.tar.gz" + return f"{id}.urdf" case "mjcf": - return f"{id}.tar.gz" + return f"{id}.xml" + case "stl": + return f"{id}.stl" case _: raise ValueError(f"Unknown artifact type: {artifact_type}") diff --git a/store/app/routers/artifacts.py b/store/app/routers/artifacts.py index a8c197d9..e56a4251 100644 --- a/store/app/routers/artifacts.py +++ b/store/app/routers/artifacts.py @@ -72,15 +72,15 @@ def validate_file(file: UploadFile, artifact_type: ArtifactType) -> str: status_code=status.HTTP_400_BAD_REQUEST, detail="Artifact size was not provided", ) - if file.size < settings.image.min_bytes: + if file.size < settings.artifact.min_bytes: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Artifact size is too small; file size {file.size} is less than {settings.image.min_bytes} bytes", + detail=f"Artifact size is too small; {file.size} is less than {settings.artifact.min_bytes} bytes", ) - if file.size > settings.image.max_bytes: + if file.size > settings.artifact.max_bytes: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail=f"Artifact size is too large; file size {file.size} is less than {settings.image.max_bytes} bytes", + detail=f"Artifact size is too large; {file.size} is greater than {settings.artifact.max_bytes} bytes", ) if (content_type := file.content_type) is None: raise HTTPException( diff --git a/store/requirements.txt b/store/requirements.txt index 3cd127e6..43b6b6fe 100644 --- a/store/requirements.txt +++ b/store/requirements.txt @@ -23,5 +23,8 @@ httpx # Deployment dependencies. uvicorn[standard] +# Processing dependencies. +numpy-stl + # Types types-aioboto3[dynamodb, s3] diff --git a/store/settings/environment.py b/store/settings/environment.py index 9005e3c5..091ee1f1 100644 --- a/store/settings/environment.py +++ b/store/settings/environment.py @@ -69,7 +69,7 @@ class EnvironmentSettings: user: UserSettings = field(default_factory=UserSettings) crypto: CryptoSettings = field(default_factory=CryptoSettings) email: EmailSettings = field(default_factory=EmailSettings) - image: ArtifactSettings = field(default_factory=ArtifactSettings) + artifact: ArtifactSettings = field(default_factory=ArtifactSettings) s3: S3Settings = field(default_factory=S3Settings) dynamo: DynamoSettings = field(default_factory=DynamoSettings) site: SiteSettings = field(default_factory=SiteSettings) diff --git a/store/utils.py b/store/utils.py index 47c63f2d..7dc22848 100644 --- a/store/utils.py +++ b/store/utils.py @@ -3,9 +3,12 @@ import datetime import functools import hashlib +import io import uuid from collections import OrderedDict +from pathlib import Path from typing import Awaitable, Callable, Generic, Hashable, ParamSpec, TypeVar, overload +from xml.etree import ElementTree as ET Tk = TypeVar("Tk", bound=Hashable) Tv = TypeVar("Tv") @@ -149,3 +152,26 @@ def new_uuid() -> str: SHA-256 hash of a UUID4 value. """ return hashlib.sha256(str(uuid.uuid4()).encode()).hexdigest()[:16] + + +def save_xml(path: str | Path | io.BytesIO, tree: ET.ElementTree) -> None: + root = tree.getroot() + + def indent(elem: ET.Element, level: int = 0) -> ET.Element: + i = "\n" + level * " " + if len(elem): + if not elem.text or not elem.text.strip(): + elem.text = i + " " + if not elem.tail or not elem.tail.strip(): + elem.tail = i + for e in elem: + indent(e, level + 1) + if not elem.tail or not elem.tail.strip(): + elem.tail = i + else: # noqa: PLR5501 + if level and (not elem.tail or not elem.tail.strip()): + elem.tail = i + return elem + + indent(root) + tree.write(path, encoding="utf-8", xml_declaration=True, method="xml") diff --git a/tests/assets/sample.urdf b/tests/assets/sample.urdf new file mode 100644 index 00000000..7f9a4aed --- /dev/null +++ b/tests/assets/sample.urdf @@ -0,0 +1,266 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/assets/teapot.stl b/tests/assets/teapot.stl new file mode 100644 index 00000000..d2259187 Binary files /dev/null and b/tests/assets/teapot.stl differ diff --git a/tests/test_listings.py b/tests/test_listings.py index 914358a2..55961f8b 100644 --- a/tests/test_listings.py +++ b/tests/test_listings.py @@ -46,6 +46,30 @@ async def test_listings(app_client: AsyncClient, tmpdir: Path) -> None: data = response.json() assert data["artifact"]["artifact_id"] is not None + # Uploads a URDF. + urdf_path = Path(__file__).parent / "assets" / "sample.urdf" + data_json = json.dumps({"artifact_type": "urdf", "listing_id": listing_id}) + response = await app_client.post( + "/artifacts/upload", + headers=auth_headers, + files={"file": ("box.urdf", open(urdf_path, "rb"), "application/xml"), "metadata": (None, data_json)}, + ) + assert response.status_code == status.HTTP_200_OK, response.json() + data = response.json() + assert data["artifact"]["artifact_id"] is not None + + # Uploads an STL. + stl_path = Path(__file__).parent / "assets" / "teapot.stl" + data_json = json.dumps({"artifact_type": "stl", "listing_id": listing_id}) + response = await app_client.post( + "/artifacts/upload", + headers=auth_headers, + files={"file": ("teapot.stl", open(stl_path, "rb"), "application/octet-stream"), "metadata": (None, data_json)}, + ) + assert response.status_code == status.HTTP_200_OK, response.json() + data = response.json() + assert data["artifact"]["artifact_id"] is not None + # Searches for listings. response = await app_client.get( "/listings/search",