From 927796beefa6934d645eafc00a9cc6200399e0c9 Mon Sep 17 00:00:00 2001
From: Wei Ouyang
Date: Sat, 23 Nov 2024 11:37:30 -0800
Subject: [PATCH] Improve s3 proxy and fix zenodo upload

---
 hypha/VERSION           |   2 +-
 hypha/artifact.py       |  10 ++++
 hypha/s3.py             | 128 +++++-----------------------------------
 hypha/utils/__init__.py | 107 +++++++++++++++++++++++++++++++++-
 hypha/utils/zenodo.py   |  24 ++------
 requirements.txt        |   1 +
 setup.py                |   1 +
 tests/test_artifact.py  |   2 +-
 8 files changed, 142 insertions(+), 133 deletions(-)

diff --git a/hypha/VERSION b/hypha/VERSION
index 7b255f05..4d519499 100644
--- a/hypha/VERSION
+++ b/hypha/VERSION
@@ -1,3 +1,3 @@
 {
-    "version": "0.20.39.post16"
+    "version": "0.20.39.post17"
 }
diff --git a/hypha/artifact.py b/hypha/artifact.py
index 8b4cb814..539564bb 100644
--- a/hypha/artifact.py
+++ b/hypha/artifact.py
@@ -984,6 +984,15 @@ async def create(
                 deposition_info["conceptrecid"]
             )
             config["zenodo"] = deposition_info
+        config["publish_to"] = publish_to
+
+        if publish_to not in ["zenodo", "sandbox_zenodo"]:
+            assert (
+                "{zenodo_id}" not in alias
+            ), "Alias cannot contain the '{zenodo_id}' placeholder; set publish_to to 'zenodo' or 'sandbox_zenodo'."
+            assert (
+                "{zenodo_conceptrecid}" not in alias
+            ), "Alias cannot contain the '{zenodo_conceptrecid}' placeholder; set publish_to to 'zenodo' or 'sandbox_zenodo'."
 
         if parent_artifact and parent_artifact.config:
             id_parts.update(parent_artifact.config.get("id_parts", {}))
@@ -2331,6 +2340,7 @@ async def publish(
         assert "description" in manifest, "Manifest must have a description."
 
         config = artifact.config or {}
+        to = to or config.get("publish_to")
         zenodo_client = self._get_zenodo_client(
             artifact, parent_artifact, publish_to=to
         )
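
Note on the artifact.py change above: create() now records publish_to in the artifact
config, and publish() falls back to it when `to` is omitted, so aliases containing
Zenodo placeholders only pass validation when a Zenodo target is configured. A minimal
client-side sketch of the new guard (the exact create() signature here is an
assumption, modeled on the calls in tests/test_artifact.py):

    # Accepted: publish_to is set, so {zenodo_id} can be resolved at publish time
    dataset = await artifact_manager.create(
        type="dataset",
        alias="my-dataset-{zenodo_id}",
        manifest={"name": "Example", "description": "An example dataset"},
        publish_to="sandbox_zenodo",
    )

    # Rejected by the new assertion: placeholder used without a Zenodo publish target
    await artifact_manager.create(
        type="dataset",
        alias="my-dataset-{zenodo_id}",
        manifest={"name": "Example", "description": "An example dataset"},
    )
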
diff --git a/hypha/s3.py b/hypha/s3.py
index 3b8bba63..6d15f1a7 100644
--- a/hypha/s3.py
+++ b/hypha/s3.py
@@ -16,10 +16,13 @@
 import botocore
 from aiobotocore.session import get_session
 from botocore.exceptions import ClientError
-from fastapi import APIRouter, Depends, Request, HTTPException
+from fastapi import APIRouter, Depends, Request
 from fastapi.responses import FileResponse, Response, StreamingResponse, JSONResponse
 from starlette.datastructures import Headers
 from starlette.types import Receive, Scope, Send
+from asgiproxy.simple_proxy import make_simple_proxy_app
+from asgiproxy.context import ProxyContext
+from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
 
 from hypha.core import UserInfo, WorkspaceInfo, UserPermission
 from hypha.minio import MinioClient
@@ -367,118 +370,19 @@ async def get_zip_file_content(
 
         if self.enable_s3_proxy:
 
-            @router.api_route(
-                "/s3/{path:path}",
-                methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"],
-            )
-            async def proxy_s3_request(
-                path: str,
-                request: Request,
-            ):
-                """
-                Proxy request to S3, forwarding only essential headers such as Range,
-                and handling different methods (GET, POST, PUT).
-                """
-                # Extract the query parameters from the client request
-                query_params = dict(request.query_params)
-
-                # Construct the S3 presigned URL using the internal endpoint_url
-                if query_params:
-                    s3_url = f"{self.endpoint_url}/{path}?" + urlencode(query_params)
-                else:
-                    s3_url = f"{self.endpoint_url}/{path}"
-
-                # Define the method
-                method = request.method
-
-                # Create a case-insensitive headers dictionary
-                incoming_headers = dict(request.headers)
-                normalized_headers = {
-                    key.lower(): value for key, value in incoming_headers.items()
-                }
-
-                # Keep only the essential headers for the request
-                essential_headers = {}
-                if "range" in normalized_headers:
-                    essential_headers["Range"] = normalized_headers["range"]
-
-                # Add content-type and content-length only for upload requests (POST/PUT)
-                if method in ["POST", "PUT"]:
-                    if "content-type" in normalized_headers:
-                        essential_headers["Content-Type"] = normalized_headers[
-                            "content-type"
-                        ]
-                    if "content-length" in normalized_headers:
-                        essential_headers["Content-Length"] = normalized_headers[
-                            "content-length"
-                        ]
-
-                # Stream data to/from S3
-                async with httpx.AsyncClient() as client:
-                    try:
-                        # For methods like POST/PUT, pass the request body
-                        if method in ["POST", "PUT", "PATCH"]:
-                            # Read and stream the request body in chunks
-                            async def request_body_stream():
-                                async for chunk in request.stream():
-                                    yield chunk
-
-                            response = await client.request(
-                                method,
-                                s3_url,
-                                content=request_body_stream(),  # Stream the request body to S3
-                                headers=essential_headers,  # Forward essential headers
-                                timeout=None,  # Remove timeout for large file uploads
-                            )
-                        else:
-                            response = await client.request(
-                                method,
-                                s3_url,
-                                headers=essential_headers,  # Forward essential headers
-                                timeout=None,
-                            )
-
-                        # Return the response, stream data for GET requests
-                        if method == "GET":
-                            return StreamingResponse(
-                                response.iter_bytes(),  # Async iterator of response body chunks
-                                status_code=response.status_code,
-                                headers={
-                                    k: v
-                                    for k, v in response.headers.items()
-                                    if k.lower()
-                                    not in ["content-encoding", "transfer-encoding"]
-                                },
-                            )
-
-                        elif method in ["POST", "PATCH", "PUT", "DELETE"]:
-                            return Response(
-                                content=response.content,  # Raw response content
-                                status_code=response.status_code,
-                                headers=response.headers,  # Pass raw headers from the response
-                            )
-
-                        elif method == "HEAD":
-                            return Response(
-                                status_code=response.status_code,
-                                headers=response.headers,  # No content for HEAD, just forward headers
-                            )
-
-                        else:
-                            return Response(
-                                status_code=405, content="Method Not Allowed"
-                            )
-
-                    except httpx.HTTPStatusError as exc:
-                        raise HTTPException(
-                            status_code=exc.response.status_code,
-                            detail=f"Error while proxying to S3: {exc.response.text}",
-                        )
-                    except Exception as exc:
-                        raise HTTPException(
-                            status_code=500,
-                            detail=f"Internal server error: {str(exc)}",
-                        )
+            class S3ProxyConfig(BaseURLProxyConfigMixin, ProxyConfig):
+                # Set your S3 root endpoint
+                upstream_base_url = self.endpoint_url
+                rewrite_host_header = (
+                    self.endpoint_url.replace("https://", "")
+                    .replace("http://", "")
+                    .split("/")[0]
+                )
+
+            config = S3ProxyConfig()
+            context = ProxyContext(config=config)
+            s3_app = make_simple_proxy_app(context, proxy_websocket_handler=None)
+            self.store.mount_app("/s3", s3_app, "s3-proxy")
 
         @router.get("/{workspace}/files/{path:path}")
         @router.delete("/{workspace}/files/{path:path}")
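
Note on the s3.py change above: roughly 100 lines of hand-rolled httpx proxying
(method dispatch, header filtering, bidirectional streaming) are replaced by a mounted
asgiproxy app, which handles all of that generically. For reviewers unfamiliar with
the library, a self-contained sketch of the same pattern (the upstream URL is a
placeholder; hypha mounts the app through self.store.mount_app rather than Starlette's
Mount):

    from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
    from asgiproxy.context import ProxyContext
    from asgiproxy.simple_proxy import make_simple_proxy_app
    from starlette.applications import Starlette
    from starlette.routing import Mount

    class UpstreamConfig(BaseURLProxyConfigMixin, ProxyConfig):
        # Every request under the mounted path is rewritten onto this base URL
        upstream_base_url = "http://127.0.0.1:9000"  # placeholder S3/MinIO endpoint
        # S3-style servers route by Host header, so rewrite it to match the upstream
        rewrite_host_header = "127.0.0.1:9000"

    proxy_app = make_simple_proxy_app(
        ProxyContext(config=UpstreamConfig()),
        proxy_websocket_handler=None,  # plain HTTP proxying only, as in this patch
    )
    app = Starlette(routes=[Mount("/s3", app=proxy_app)])
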
diff --git a/hypha/utils/__init__.py b/hypha/utils/__init__.py
index 899eff05..433eae02 100644
--- a/hypha/utils/__init__.py
+++ b/hypha/utils/__init__.py
@@ -8,9 +8,10 @@
 import string
 import time
 from datetime import datetime
-from typing import Callable, List, Optional
+from typing import Callable, List, Optional, Dict
 import shortuuid
 import friendlywords as fw
+import logging
 import httpx
 
 from fastapi.routing import APIRoute
@@ -488,3 +489,107 @@ async def _example_hypha_startup(server):
             "test": lambda x: x + 22,
         }
     )
+
+
+async def chunked_transfer_remote_file(
+    source_url: str,
+    target_url: str,
+    source_params: Optional[Dict] = None,
+    target_params: Optional[Dict] = None,
+    source_headers: Optional[Dict] = None,
+    target_headers: Optional[Dict] = None,
+    logger: Optional[logging.Logger] = None,
+    chunk_size: int = 2048,
+    timeout: int = 60,
+):
+    """
+    Transfer a file from a remote source to a remote target without fully buffering it in memory.
+
+    Args:
+        source_url (str): The URL of the source file.
+        target_url (str): The URL where the file will be uploaded.
+        source_params (Dict, optional): Additional query parameters for the source request.
+        target_params (Dict, optional): Additional query parameters for the target request.
+        source_headers (Dict, optional): Headers for the source request.
+        target_headers (Dict, optional): Headers for the target request.
+        logger (logging.Logger, optional): Logger for logging messages.
+        chunk_size (int, optional): Size of chunks for streaming. Default is 2048 bytes.
+        timeout (int, optional): Timeout for the HTTP requests in seconds. Default is 60.
+    """
+    source_params = source_params or {}
+    target_params = target_params or {}
+    source_headers = source_headers or {}
+    target_headers = target_headers or {}
+    logger = logger or logging.getLogger(__name__)
+
+    async with httpx.AsyncClient(
+        headers={"Connection": "close"}, timeout=timeout
+    ) as client:
+        # Step 1: Get the file size from the source via a one-byte range request
+        range_headers = {"Range": "bytes=0-0"}
+        range_headers.update(source_headers)
+
+        logger.info(f"Fetching file size from {source_url}...")
+        range_response = await client.get(
+            source_url, headers=range_headers, params=source_params
+        )
+        if range_response.status_code not in [200, 206]:
+            logger.error(
+                f"Failed to fetch file size. Status code: {range_response.status_code}"
+            )
+            logger.error(f"Response: {range_response.text}")
+            return
+
+        content_range = range_response.headers.get("Content-Range")
+        if not content_range or "/" not in content_range:
+            logger.error("Content-Range header is missing or invalid.")
+            return
+
+        file_size = int(content_range.split("/")[-1])
+        logger.info(f"File size: {file_size} bytes")
+
+        # Step 2: Stream the file from the source
+        async def s3_response_chunk_reader(response):
+            async for chunk in response.aiter_bytes(chunk_size):
+                yield chunk
+
+        async with client.stream(
+            "GET", source_url, headers=source_headers, params=source_params
+        ) as response:
+            if response.status_code != 200:
+                info = await response.aread()
+                info = info.decode("utf-8") if info else ""
+                logger.error(
+                    f"Failed to download the file. Status code: {response.status_code}, {info}"
+                )
+                raise Exception(
+                    f"Failed to download the file. Status code: {response.status_code}, {info}"
+                )
+
+            # Step 3: Upload to the target with an explicit content-length
+            target_headers.update(
+                {
+                    "Content-Type": "application/octet-stream",
+                    "Content-Length": str(file_size),
+                }
+            )
+
+            async def upload_generator():
+                async for chunk in s3_response_chunk_reader(response):
+                    yield chunk
+
+            logger.info(f"Uploading file to {target_url}...")
+            put_response = await client.put(
+                target_url,
+                headers=target_headers,
+                content=upload_generator(),
+                params=target_params,
+            )
+
+            if put_response.status_code >= 400:
+                logger.error(
+                    f"Failed to upload the file. Status code: {put_response.status_code}, {put_response.text}"
+                )
+                raise Exception(
+                    f"Failed to upload the file. Status code: {put_response.status_code}, {put_response.text}"
+                )
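
Usage sketch for the new helper (the URLs and token below are placeholders; a
presigned S3 GET URL works as-is because its auth lives in the query string). Note
that the source must answer range requests, since the helper derives the upload's
Content-Length from the Content-Range header of a one-byte probe:

    import asyncio
    from hypha.utils import chunked_transfer_remote_file

    async def main():
        await chunked_transfer_remote_file(
            source_url="https://minio.example.com/bucket/data.bin?X-Amz-Signature=...",
            target_url="https://sandbox.zenodo.org/api/files/<bucket-id>/data.bin",
            target_params={"access_token": "<zenodo-token>"},
        )

    asyncio.run(main())
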
diff --git a/hypha/utils/zenodo.py b/hypha/utils/zenodo.py
index e306af00..e8f429a0 100644
--- a/hypha/utils/zenodo.py
+++ b/hypha/utils/zenodo.py
@@ -3,7 +3,7 @@
 from typing import Dict, Any, List, Optional
 import aiofiles
 from pathlib import PurePosixPath
-
+from hypha.utils import chunked_transfer_remote_file
 
 ZENODO_TIMEOUT = 30  # seconds
 
@@ -100,24 +100,12 @@ async def file_chunk_reader(file_path: str, chunk_size: int = 1024):
             data=file_chunk_reader(file_path),
         )
 
-    async def import_file(self, deposition_info, name, target_url):
+    async def import_file(self, deposition_info, name, source_url):
         bucket_url = deposition_info["links"]["bucket"]
-        async with httpx.AsyncClient(
-            headers={"Connection": "close"}, timeout=ZENODO_TIMEOUT
-        ) as client:
-            async with client.stream("GET", target_url) as response:
-
-                async def s3_response_chunk_reader(response, chunk_size: int = 2048):
-                    async for chunk in response.aiter_bytes(chunk_size):
-                        yield chunk
-
-                put_response = await self.client.put(
-                    f"{bucket_url}/{name}",
-                    params=self.params,
-                    headers={"Content-Type": "application/octet-stream"},
-                    data=s3_response_chunk_reader(response),
-                )
-                put_response.raise_for_status()
+        target_url = f"{bucket_url}/{name}"
+        await chunked_transfer_remote_file(
+            source_url, target_url, target_params=self.params
+        )
 
     async def delete_deposition(self, deposition_id: str) -> None:
         """Deletes a deposition. Only unpublished depositions can be deleted."""
diff --git a/requirements.txt b/requirements.txt
index d628f213..6c257984 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -37,3 +37,4 @@ hrid==0.2.4
 qdrant-client==1.12.1
 ollama==0.3.3
 fastembed==0.4.2
+asgiproxy==0.1.1
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 5f5cbea1..81316851 100644
--- a/setup.py
+++ b/setup.py
@@ -43,6 +43,7 @@
     "sqlmodel>=0.0.22",
     "alembic>=1.14.0",
     "hrid>=0.2.4",
+    "asgiproxy>=0.1.1",
 ]
 
 ROOT_DIR = Path(__file__).parent.resolve()
diff --git a/tests/test_artifact.py b/tests/test_artifact.py
index b8767214..695ea27b 100644
--- a/tests/test_artifact.py
+++ b/tests/test_artifact.py
@@ -715,7 +715,7 @@ async def test_publish_artifact(minio_server, fastapi_server, test_user_token):
     )
     source = "file contents of example.txt"
     response = requests.put(put_url, data=source)
-    assert response.ok
+    assert response.ok, response.text
 
     # Commit the artifact after adding the file
     await artifact_manager.commit(artifact_id=dataset.id)
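
End-to-end, the updated test exercises the full chain this patch touches: upload via a
presigned URL (now through the asgiproxy-backed /s3 route when the proxy is enabled),
commit, then publish, which routes through import_file and the new chunked transfer
helper. A hedged sketch of that flow (method names follow tests/test_artifact.py; the
exact put_file/publish signatures are assumptions):

    put_url = await artifact_manager.put_file(
        artifact_id=dataset.id, file_path="example.txt"
    )
    response = requests.put(put_url, data="file contents of example.txt")
    assert response.ok, response.text  # include the body so S3 errors surface in test logs

    await artifact_manager.commit(artifact_id=dataset.id)
    # `to` may be omitted when create() stored publish_to in the artifact config
    await artifact_manager.publish(artifact_id=dataset.id)
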