diff --git a/server/main.py b/server/main.py index 42ef452..c786164 100644 --- a/server/main.py +++ b/server/main.py @@ -1,7 +1,7 @@ -from typing import List, Literal +from typing import List, TypedDict from fastapi import FastAPI, HTTPException, Request from fastapi.param_functions import Depends -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, FileResponse from fastapi.security import HTTPBearer, http from pydantic import BaseModel from requests.models import HTTPError @@ -29,18 +29,7 @@ def read_root(request: Request): return {"message": f"Hello world, {request.client.host}"} -class TapisFile(BaseModel): - path: str - type: Literal["file", "folder"] - - -class ZipRequest(BaseModel): - system: str - path: str - files: List[TapisFile] - - -def get_base_path(system: str, path: str): +def get_system_root(system: str) -> str: match system: case "designsafe.storage.default": root_dir = "/corral-repl/tacc/NHERI/shared" @@ -51,39 +40,60 @@ def get_base_path(system: str, path: str): case prj_system if system.startswith("project-"): project_id = prj_system.split("-", 1)[1] root_dir = os.path.join("/corral-repl/tacc/NHERI/projects", project_id) + case _: + raise HTTPException(status_code=404, detail="Invalid storage system ID.") + + return root_dir + - return os.path.join(root_dir, path.strip("/")) +def raise_for_size(size: int, max_size: int = 2e9) -> None: + if size > max_size: + raise HTTPException(status_code=413, detail="Archive size is limited to 2Gb.") -def walk_paths(base_path: str, files: List[TapisFile]): +class Archive(TypedDict): + fs: str # Represents the absolute path to a file on the host machine. + n: str # Represents the path to a file relative to the zip archive root. + + +def walk_archive_paths(base_path: str, file_paths: List[str]) -> List[Archive]: base = Path(base_path) - paths = [] + zip_paths = [] size = 0 - for file in files: - mount_path = base / file.path.strip("/") - if file.type == "file": - paths.append({"fs": str(mount_path), "n": mount_path.name}) - elif file.type == "folder": - for file_path in filter(lambda f: f.is_file(), mount_path.glob("**/*")): - paths.append( - {"fs": str(file_path), "n": str(file_path.relative_to(mount_path))} + for file in file_paths: + full_path = base / file.strip("/") + + if full_path.is_file(): + zip_paths.append({"fs": str(full_path), "n": full_path.name}) + size += full_path.stat().st_size + raise_for_size(size) + + elif full_path.is_dir(): + for file_path in filter(lambda f: f.is_file(), full_path.glob("**/*")): + zip_paths.append( + { + "fs": str(file_path), + "n": str(file_path.relative_to(full_path.parent)), + } ) size += file_path.stat().st_size - if size > 2e9: - raise HTTPException( - status_code=413, detail="Archive size limited to 2Gb." - ) + raise_for_size(size) - return paths + return zip_paths class CheckResponse(BaseModel): key: str +class CheckRequest(BaseModel): + system: str + paths: List[str] + + @app.put("/check", response_model=CheckResponse) def check_downloadable( - request: ZipRequest, + request: CheckRequest, auth: http.HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), ): PUBLIC_SYSTEMS = ["designsafe.storage.community", "designsafe.storage.published"] @@ -114,8 +124,8 @@ def check_downloadable( print(resp.content) raise HTTPException(status_code=403, detail=resp.json()) """ - base_path = get_base_path(request.system, request.path) - paths = walk_paths(base_path, request.files) + system_root = get_system_root(request.system) + paths = walk_archive_paths(system_root, request.paths) key = str(uuid4()) r.set(key, json.dumps(paths)) @@ -138,13 +148,20 @@ def download_file(key: str): if not key_json: raise HTTPException(status_code=404, detail="Invalid download link.") - paths = json.loads(key_json) + paths: List[Archive] = json.loads(key_json) r.delete(key) + if len(paths) == 1: + # If there's only 1 file to return, download it directly instead of zipping it. + return FileResponse( + paths[0]["fs"], + headers={"Content-Disposition": f"attachment; filename={paths[0]['n']}"}, + ) + zfly = zipfly.ZipFly(paths=paths) generator = zfly.generator() return StreamingResponse( generator, headers={"Content-Disposition": "attachment; filename=download.zip"}, - media_type="application/octet-stream", + media_type="application/zip", ) diff --git a/server/tests/util_test.py b/server/tests/util_test.py index 9473225..8c71986 100644 --- a/server/tests/util_test.py +++ b/server/tests/util_test.py @@ -1,51 +1,63 @@ -from server.main import get_base_path, walk_paths, TapisFile +import pytest import os +from fastapi.exceptions import HTTPException +from server.main import get_system_root, walk_archive_paths -def test_get_base_path(): +def test_get_system_root(): assert ( - get_base_path("designsafe.storage.default", "/path/to/file") - == "/corral-repl/tacc/NHERI/shared/path/to/file" + get_system_root("designsafe.storage.default") + == "/corral-repl/tacc/NHERI/shared" ) assert ( - get_base_path("designsafe.storage.default", "path/to/file") - == "/corral-repl/tacc/NHERI/shared/path/to/file" + get_system_root("designsafe.storage.community") + == "/corral-repl/tacc/NHERI/community" ) assert ( - get_base_path("designsafe.storage.community", "/path/to/file") - == "/corral-repl/tacc/NHERI/community/path/to/file" + get_system_root("designsafe.storage.published") + == "/corral-repl/tacc/NHERI/published" ) assert ( - get_base_path("designsafe.storage.published", "/path/to/file") - == "/corral-repl/tacc/NHERI/published/path/to/file" + get_system_root("project-7448086614930166251-242ac113-0001-012") + == "/corral-repl/tacc/NHERI/projects/7448086614930166251-242ac113-0001-012" ) - assert ( - get_base_path("project-7448086614930166251-242ac113-0001-012", "/path/to/file") - == "/corral-repl/tacc/NHERI/projects/7448086614930166251-242ac113-0001-012/path/to/file" - ) - - -def test_walk_paths(tmp_path): + with pytest.raises(HTTPException): + get_system_root("not-a-real-system") + + +def test_walk_archive(tmp_path): + """ + Generate and walk the following directory structure at tmp_path: + . + ├── f1.txt + └── sub_path/ + ├── f2.txt + └── sub_path2/ + └── f3.txt + """ base = tmp_path - sub = base / "sub_path" - os.mkdir(sub) f1 = base / "f1.txt" f1.write_text("CONTENT 1") + sub = base / "sub_path" + os.mkdir(sub) f2 = sub / "f2.txt" f2.write_text("CONTENT 2") - requested_files = [ - TapisFile(**{ - "path": "", - "type": "folder" - }) + sub2 = sub / "sub_path2" + os.mkdir(sub2) + f3 = sub2 / "f3.txt" + f3.write_text("CONTENT 3") + + requested_files = ["f1.txt", "sub_path"] + walk_result = walk_archive_paths(base, requested_files) + + assert walk_result == [ + {"fs": str(base / "f1.txt"), "n": "f1.txt"}, + {"fs": str(base / "sub_path" / "f2.txt"), "n": "sub_path/f2.txt"}, + { + "fs": str(base / "sub_path" / "sub_path2" / "f3.txt"), + "n": "sub_path/sub_path2/f3.txt", + }, ] - - paths2 = walk_paths(base, requested_files) - assert paths2 == [ - {'fs': str(base / "f1.txt"), 'n': 'f1.txt'}, - {'fs': str(base / "sub_path" / "f2.txt"), 'n': "sub_path/f2.txt"} - ] -