From 477e35fa5aba43a879a1bcda946d3c6b60818b5a Mon Sep 17 00:00:00 2001 From: Ben Bolte Date: Sat, 17 Aug 2024 02:21:44 -0700 Subject: [PATCH] upload urdfs with stls (#296) * upload urdfs with stls * format * tests --- .../components/listing/ListingDescription.tsx | 6 +- .../src/components/listing/MeshRenderer.tsx | 14 ++-- frontend/src/components/nav/Sidebar.tsx | 3 - frontend/src/gen/api.ts | 11 +-- store/app/crud/artifacts.py | 36 +++++----- store/app/main.py | 9 +++ store/app/model.py | 69 +++++++++++++------ store/app/routers/artifacts.py | 28 ++++++-- store/app/routers/listings.py | 2 +- tests/assets/sample.urdf | 2 +- tests/test_images.py | 5 +- tests/test_listings.py | 16 ++++- 12 files changed, 130 insertions(+), 71 deletions(-) diff --git a/frontend/src/components/listing/ListingDescription.tsx b/frontend/src/components/listing/ListingDescription.tsx index 94510ceb..335cd96a 100644 --- a/frontend/src/components/listing/ListingDescription.tsx +++ b/frontend/src/components/listing/ListingDescription.tsx @@ -61,13 +61,13 @@ export const RenderDescription = ({ description }: RenderDescriptionProps) => {

{children}

), img: ({ src, alt }) => ( -
src && setImageModal([src, alt ?? ""])} > {alt} - {alt &&

{alt}

} -
+ {alt && {alt}} + ), }} > diff --git a/frontend/src/components/listing/MeshRenderer.tsx b/frontend/src/components/listing/MeshRenderer.tsx index 268b7766..f8fdc446 100644 --- a/frontend/src/components/listing/MeshRenderer.tsx +++ b/frontend/src/components/listing/MeshRenderer.tsx @@ -47,15 +47,9 @@ interface UrdfModelProps { const UrdfModel = ({ url, meshType }: UrdfModelProps) => { const ref = useRef(); - const [robot, setRobot] = useState(); + const geom = useLoader(URDFLoader, url); - const loader = new URDFLoader(); - - loader.load(url, (robot) => { - setRobot(robot); - }); - - return robot ? ( + return ( { > { - ) : null; + ); }; interface StlModelProps { diff --git a/frontend/src/components/nav/Sidebar.tsx b/frontend/src/components/nav/Sidebar.tsx index 1dc3443c..1dd25405 100644 --- a/frontend/src/components/nav/Sidebar.tsx +++ b/frontend/src/components/nav/Sidebar.tsx @@ -4,10 +4,7 @@ import { FaDoorOpen, FaHome, FaKey, - FaLock, FaPen, - FaQuestion, - FaScroll, FaTimes, FaUserCircle, } from "react-icons/fa"; diff --git a/frontend/src/gen/api.ts b/frontend/src/gen/api.ts index 0238176e..636f2f9c 100644 --- a/frontend/src/gen/api.ts +++ b/frontend/src/gen/api.ts @@ -390,7 +390,7 @@ export interface paths { patch?: never; trace?: never; }; - "/artifacts/url/{artifact_type}/{artifact_id}": { + "/artifacts/url/{artifact_type}/{listing_id}/{name}": { parameters: { query?: never; header?: never; @@ -398,7 +398,7 @@ export interface paths { cookie?: never; }; /** Artifact Url */ - get: operations["artifact_url_artifacts_url__artifact_type___artifact_id__get"]; + get: operations["artifact_url_artifacts_url__artifact_type___listing_id___name__get"]; put?: never; post?: never; delete?: never; @@ -630,6 +630,8 @@ export interface components { ListArtifactsItem: { /** Artifact Id */ artifact_id: string; + /** Listing Id */ + listing_id: string; /** Name */ name: string; /** @@ -1491,7 +1493,7 @@ export interface operations { }; }; }; - artifact_url_artifacts_url__artifact_type___artifact_id__get: { + artifact_url_artifacts_url__artifact_type___listing_id___name__get: { parameters: { query?: { size?: "small" | "large"; @@ -1499,7 +1501,8 @@ export interface operations { header?: never; path: { artifact_type: "image" | "urdf" | "mjcf" | "stl"; - artifact_id: string; + listing_id: string; + name: string; }; cookie?: never; }; diff --git a/store/app/crud/artifacts.py b/store/app/crud/artifacts.py index beecae30..a25c3dfe 100644 --- a/store/app/crud/artifacts.py +++ b/store/app/crud/artifacts.py @@ -29,7 +29,7 @@ class ArtifactsCrud(BaseCrud): @classmethod def get_gsis(cls) -> set[str]: - return super().get_gsis().union({"user_id", "listing_id"}) + return super().get_gsis().union({"user_id", "listing_id", "name"}) async def _crop_image(self, image: Image.Image, size: tuple[int, int]) -> io.BytesIO: # Simply squashes the image to the desired size. @@ -59,10 +59,10 @@ async def _crop_image(self, image: Image.Image, size: tuple[int, int]) -> io.Byt image_bytes.seek(0) return image_bytes - async def _upload_cropped_image(self, image: Image.Image, name: str, image_id: str, size: ArtifactSize) -> None: + async def _upload_cropped_image(self, image: Image.Image, artifact: Artifact, size: ArtifactSize) -> None: image_bytes = await self._crop_image(image, SizeMapping[size]) - filename = get_artifact_name(image_id, "image", size) - await self._upload_to_s3(image_bytes, name, filename, "image/png") + filename = get_artifact_name(artifact=artifact, size=size) + await self._upload_to_s3(image_bytes, artifact.name, filename, "image/png") async def _upload_image( self, @@ -86,15 +86,7 @@ async def _upload_image( image = Image.open(file) await asyncio.gather( - *( - self._upload_cropped_image( - image=image, - name=name, - image_id=artifact.id, - size=size, - ) - for size in SizeMapping.keys() - ), + *(self._upload_cropped_image(image=image, artifact=artifact, size=size) for size in SizeMapping.keys()), self._add_item(artifact), ) return artifact @@ -129,7 +121,7 @@ async def _upload_stl( description=description, ) await asyncio.gather( - self._upload_to_s3(out_file, name, get_artifact_name(artifact.id, "stl"), content_type), + self._upload_to_s3(out_file, name, get_artifact_name(artifact=artifact), content_type), self._add_item(artifact), ) return artifact @@ -169,7 +161,7 @@ async def _upload_xml( description=description, ) await asyncio.gather( - self._upload_to_s3(out_file, name, get_artifact_name(artifact.id, artifact_type), content_type), + self._upload_to_s3(out_file, name, get_artifact_name(artifact=artifact), content_type), self._add_item(artifact), ) return artifact @@ -183,6 +175,10 @@ async def upload_artifact( artifact_type: ArtifactType, description: str | None = None, ) -> Artifact: + # Validates that the name is unique. + if await self.has_artifact_named(name): + raise BadArtifactError("An artifact with this name already exists") + match artifact_type: case "image": return await self._upload_image(name, file, listing, user_id, description) @@ -197,20 +193,19 @@ async def _remove_image(self, artifact: Artifact, user_id: str) -> None: if artifact.user_id != user_id: raise NotAuthorizedError("User does not have permission to delete this image") await asyncio.gather( - *(self._delete_from_s3(get_artifact_name(artifact.id, "image", size)) for size in SizeMapping.keys()), + *(self._delete_from_s3(get_artifact_name(artifact=artifact, size=size)) for size in SizeMapping.keys()), self._delete_item(artifact), ) async def _remove_raw_artifact( self, artifact: Artifact, - artifact_type: Literal["urdf", "mjcf", "stl"], user_id: str, ) -> None: if artifact.user_id != user_id: raise NotAuthorizedError("User does not have permission to delete this artifact") await asyncio.gather( - self._delete_from_s3(get_artifact_name(artifact.id, artifact_type)), + self._delete_from_s3(get_artifact_name(artifact=artifact)), self._delete_item(artifact), ) @@ -219,7 +214,7 @@ async def remove_artifact(self, artifact: Artifact, user_id: str) -> None: case "image": await self._remove_image(artifact, user_id) case _: - await self._remove_raw_artifact(artifact, artifact.artifact_type, user_id) + await self._remove_raw_artifact(artifact, user_id) async def get_listing_artifacts(self, listing_id: str) -> list[Artifact]: artifacts = await self._get_items_from_secondary_index("listing_id", listing_id, Artifact) @@ -229,6 +224,9 @@ async def get_listings_artifacts(self, listing_ids: list[str]) -> list[list[Arti artifact_chunks = await self._get_items_from_secondary_index_batch("listing_id", listing_ids, Artifact) return [sorted(artifacts, key=lambda a: a.timestamp) for artifacts in artifact_chunks] + async def has_artifact_named(self, filename: str) -> bool: + return len(await self._get_items_from_secondary_index("name", filename, Artifact)) > 0 + async def edit_artifact( self, artifact_id: str, diff --git a/store/app/main.py b/store/app/main.py index 5e7b3c82..9dd1096e 100644 --- a/store/app/main.py +++ b/store/app/main.py @@ -11,6 +11,7 @@ from store.app.db import create_tables from store.app.errors import ( + BadArtifactError, InternalError, ItemNotFoundError, NotAuthenticatedError, @@ -91,6 +92,14 @@ async def not_authorized_exception_handler(request: Request, exc: NotAuthorizedE ) +@app.exception_handler(BadArtifactError) +async def bad_artifact_exception_handler(request: Request, exc: BadArtifactError) -> JSONResponse: + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content={"message": "Bad artifact.", "detail": str(exc)}, + ) + + @app.get("/") async def read_root() -> bool: return True diff --git a/store/app/model.py b/store/app/model.py index b21c7ebd..8e65d13d 100644 --- a/store/app/model.py +++ b/store/app/model.py @@ -11,6 +11,7 @@ from pydantic import BaseModel +from store.app.errors import InternalError from store.app.utils.password import hash_password from store.settings import settings from store.utils import new_uuid @@ -201,27 +202,6 @@ def get_artifact_type(content_type: str, filename: str) -> ArtifactType: raise ValueError(f"Unknown content type for file: {filename}") -def get_artifact_name(id: str, artifact_type: ArtifactType, size: ArtifactSize = "large") -> str: - match artifact_type: - case "image": - if size is None: - raise ValueError("Image artifacts should have a size") - height, width = SizeMapping[size] - return f"{id}_{size}_{height}x{width}.png" - case "urdf": - return f"{id}.urdf" - case "mjcf": - return f"{id}.xml" - case "stl": - return f"{id}.stl" - case _: - raise ValueError(f"Unknown artifact type: {artifact_type}") - - -def get_artifact_url(id: str, artifact_type: ArtifactType, size: ArtifactSize = "large") -> str: - return f"{settings.site.artifact_base_url}/{get_artifact_name(id, artifact_type, size)}" - - def get_content_type(artifact_type: ArtifactType) -> str: return DOWNLOAD_CONTENT_TYPE[artifact_type] @@ -312,3 +292,50 @@ def create(cls, listing_id: str, tag: str) -> Self: listing_id=listing_id, name=tag, ) + + +def get_artifact_name( + *, + artifact: Artifact | None = None, + listing_id: str | None = None, + name: str | None = None, + artifact_type: ArtifactType | None = None, + size: ArtifactSize = "large", +) -> str: + if artifact: + listing_id = artifact.listing_id + name = artifact.name + artifact_type = artifact.artifact_type + elif not listing_id or not name or not artifact_type: + raise InternalError("Must provide artifact or listing_id, name, and artifact_type") + + match artifact_type: + case "image": + height, width = SizeMapping[size] + return f"{listing_id}/{size}_{height}x{width}_{name}" + case "urdf": + return f"{listing_id}/{name}" + case "mjcf": + return f"{listing_id}/{name}" + case "stl": + return f"{listing_id}/{name}" + case _: + raise ValueError(f"Unknown artifact type: {artifact_type}") + + +def get_artifact_url( + *, + artifact: Artifact | None = None, + artifact_type: ArtifactType | None = None, + listing_id: str | None = None, + name: str | None = None, + size: ArtifactSize = "large", +) -> str: + artifact_name = get_artifact_name( + artifact=artifact, + listing_id=listing_id, + name=name, + artifact_type=artifact_type, + size=size, + ) + return f"{settings.site.artifact_base_url}/{artifact_name}" diff --git a/store/app/routers/artifacts.py b/store/app/routers/artifacts.py index c2892f3a..0cdb6d5d 100644 --- a/store/app/routers/artifacts.py +++ b/store/app/routers/artifacts.py @@ -25,18 +25,27 @@ logger = logging.getLogger(__name__) -@artifacts_router.get("/url/{artifact_type}/{artifact_id}") +@artifacts_router.get("/url/{artifact_type}/{listing_id}/{name}") async def artifact_url( artifact_type: ArtifactType, - artifact_id: str, + listing_id: str, + name: str, size: ArtifactSize = "large", ) -> RedirectResponse: # TODO: Use CloudFront API to return a signed CloudFront URL. - return RedirectResponse(url=get_artifact_url(artifact_id, artifact_type, size)) + return RedirectResponse( + url=get_artifact_url( + artifact_type=artifact_type, + listing_id=listing_id, + name=name, + size=size, + ) + ) class ListArtifactsItem(BaseModel): artifact_id: str + listing_id: str name: str artifact_type: ArtifactType description: str | None @@ -54,11 +63,12 @@ async def list_artifacts(listing_id: str, crud: Annotated[Crud, Depends(Crud.get artifacts=[ ListArtifactsItem( artifact_id=artifact.id, + listing_id=artifact.listing_id, name=artifact.name, artifact_type=artifact.artifact_type, description=artifact.description, timestamp=artifact.timestamp, - url=get_artifact_url(artifact.id, artifact.artifact_type), + url=get_artifact_url(artifact=artifact), ) for artifact in await crud.get_listing_artifacts(listing_id) ], @@ -124,6 +134,13 @@ async def upload( data = UploadArtifactRequest.model_validate_json(metadata) filenames = [validate_file(file) for file in files] + # Makes sure that filenames are unique. + if len(set(filename for filename, _ in filenames)) != len(filenames): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Duplicate filenames were provided", + ) + # Checks that the listing is valid. listing = await crud.get_listing(data.listing_id) if listing is None: @@ -151,11 +168,12 @@ async def upload( artifacts=[ ListArtifactsItem( artifact_id=artifact.id, + listing_id=artifact.listing_id, name=artifact.name, artifact_type=artifact.artifact_type, description=artifact.description, timestamp=artifact.timestamp, - url=get_artifact_url(artifact.id, artifact.artifact_type), + url=get_artifact_url(artifact=artifact), ) for artifact in artifacts ] diff --git a/store/app/routers/listings.py b/store/app/routers/listings.py index 956075dd..62c69817 100644 --- a/store/app/routers/listings.py +++ b/store/app/routers/listings.py @@ -66,7 +66,7 @@ async def get_batch_listing_info( child_ids=listing.child_ids, image_url=next( ( - get_artifact_url(artifact.id, "image", "small") + get_artifact_url(artifact=artifact, size="small") for artifact in artifacts if artifact.artifact_type == "image" ), diff --git a/tests/assets/sample.urdf b/tests/assets/sample.urdf index 7f9a4aed..9df92174 100644 --- a/tests/assets/sample.urdf +++ b/tests/assets/sample.urdf @@ -18,7 +18,7 @@ - + diff --git a/tests/test_images.py b/tests/test_images.py index a826ffe9..2a9c7721 100644 --- a/tests/test_images.py +++ b/tests/test_images.py @@ -45,8 +45,9 @@ async def test_user_auth_functions(app_client: AsyncClient, tmpdir: Path) -> Non assert response.status_code == status.HTTP_200_OK, response.json() data = response.json() assert data["artifacts"] is not None - image_id = data["artifacts"][0]["artifact_id"] + listing_id = data["artifacts"][0]["listing_id"] + name = data["artifacts"][0]["name"] # Gets the URLs for various sizes of images. - response = await app_client.get(f"/artifacts/url/image/{image_id}", params={"size": "small"}) + response = await app_client.get(f"/artifacts/url/image/{listing_id}/{name}", params={"size": "small"}) assert response.status_code == status.HTTP_307_TEMPORARY_REDIRECT, response.json() diff --git a/tests/test_listings.py b/tests/test_listings.py index 426830d2..b53565e8 100644 --- a/tests/test_listings.py +++ b/tests/test_listings.py @@ -59,8 +59,9 @@ async def test_listings(app_client: AsyncClient, tmpdir: Path) -> None: assert data["artifacts"][0]["artifact_id"] is not None # Gets the URDF URL. - artifact_id = data["artifacts"][0]["artifact_id"] - response = await app_client.get(f"/artifacts/url/urdf/{artifact_id}", headers=auth_headers) + listing_id = data["artifacts"][0]["listing_id"] + name = data["artifacts"][0]["name"] + response = await app_client.get(f"/artifacts/url/urdf/{listing_id}/{name}", headers=auth_headers) assert response.status_code == status.HTTP_307_TEMPORARY_REDIRECT, response.content # Uploads an STL. @@ -78,6 +79,17 @@ async def test_listings(app_client: AsyncClient, tmpdir: Path) -> None: data = response.json() assert data["artifacts"][0]["artifact_id"] is not None + # Ensures that trying to upload the same STL again fails. + response = await app_client.post( + "/artifacts/upload", + headers=auth_headers, + files={ + "files": ("teapot.stl", open(stl_path, "rb"), "application/octet-stream"), + "metadata": (None, data_json), + }, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST, response.json() + # Searches for listings. response = await app_client.get( "/listings/search",