Improve s3 proxy and fix zenodo upload
oeway committed Nov 23, 2024
1 parent 2298041 commit 927796b
Showing 8 changed files with 142 additions and 133 deletions.
2 changes: 1 addition & 1 deletion hypha/VERSION
@@ -1,3 +1,3 @@
{
"version": "0.20.39.post16"
"version": "0.20.39.post17"
}
10 changes: 10 additions & 0 deletions hypha/artifact.py
@@ -984,6 +984,15 @@ async def create(
deposition_info["conceptrecid"]
)
config["zenodo"] = deposition_info
config["publish_to"] = publish_to

if publish_to not in ["zenodo", "sandbox_zenodo"]:
assert (
"{zenodo_id}" not in alias
), "Alias cannot contain the '{zenodo_id}' placeholder, set publish_to to 'zenodo' or 'sandbox_zenodo'."
assert (
"{zenodo_conceptrecid}" not in alias
), "Alias cannot contain the '{zenodo_conceptrecid}' placeholder, set publish_to to 'zenodo' or 'sandbox_zenodo'."

if parent_artifact and parent_artifact.config:
id_parts.update(parent_artifact.config.get("id_parts", {}))
@@ -2331,6 +2340,7 @@ async def publish(
assert "description" in manifest, "Manifest must have a description."

config = artifact.config or {}
to = to or config.get("publish_to")
zenodo_client = self._get_zenodo_client(
artifact, parent_artifact, publish_to=to
)
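The change above records `publish_to` in the artifact's config at creation time, rejects aliases that use the `{zenodo_id}` or `{zenodo_conceptrecid}` placeholders unless the target is `zenodo` or `sandbox_zenodo`, and lets `publish()` fall back to the stored target when none is passed. A rough client-side sketch of the intended flow, assuming an `artifact_manager` service handle like the one used in the tests (keyword names are illustrative, not the verified API):

# Hypothetical usage sketch; parameter names follow the tests and may differ.
dataset = await artifact_manager.create(
    type="dataset",
    alias="{zenodo_conceptrecid}",  # placeholder is only accepted with a Zenodo target
    manifest={"name": "demo-dataset", "description": "An example dataset"},
    config={"publish_to": "sandbox_zenodo"},  # stored in config by the new create() logic
)

# ... upload files via presigned URLs, then commit ...
await artifact_manager.commit(artifact_id=dataset.id)

# "to" can now be omitted: publish() reads config["publish_to"] as the default target.
await artifact_manager.publish(artifact_id=dataset.id)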
128 changes: 16 additions & 112 deletions hypha/s3.py
@@ -16,10 +16,13 @@
import botocore
from aiobotocore.session import get_session
from botocore.exceptions import ClientError
from fastapi import APIRouter, Depends, Request, HTTPException
from fastapi import APIRouter, Depends, Request
from fastapi.responses import FileResponse, Response, StreamingResponse, JSONResponse
from starlette.datastructures import Headers
from starlette.types import Receive, Scope, Send
from asgiproxy.simple_proxy import make_simple_proxy_app
from asgiproxy.context import ProxyContext
from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig

from hypha.core import UserInfo, WorkspaceInfo, UserPermission
from hypha.minio import MinioClient
@@ -367,118 +370,19 @@ async def get_zip_file_content(

if self.enable_s3_proxy:

@router.api_route(
"/s3/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS"],
)
async def proxy_s3_request(
path: str,
request: Request,
):
"""
Proxy request to S3, forwarding only essential headers such as Range,
and handling different methods (GET, POST, PUT).
"""
# Extract the query parameters from the client request
query_params = dict(request.query_params)

# Construct the S3 presigned URL using the internal endpoint_url
if query_params:
s3_url = f"{self.endpoint_url}/{path}?" + urlencode(query_params)
else:
s3_url = f"{self.endpoint_url}/{path}"

# Define the method
method = request.method

# Create a case-insensitive headers dictionary
incoming_headers = dict(request.headers)
normalized_headers = {
key.lower(): value for key, value in incoming_headers.items()
}

# Keep only the essential headers for the request
essential_headers = {}
if "range" in normalized_headers:
essential_headers["Range"] = normalized_headers["range"]

# Add content-type and content-length only for upload requests (POST/PUT)
if method in ["POST", "PUT"]:
if "content-type" in normalized_headers:
essential_headers["Content-Type"] = normalized_headers[
"content-type"
]
if "content-length" in normalized_headers:
essential_headers["Content-Length"] = normalized_headers[
"content-length"
]

# Stream data to/from S3
async with httpx.AsyncClient() as client:
try:
# For methods like POST/PUT, pass the request body
if method in ["POST", "PUT", "PATCH"]:
# Read and stream the request body in chunks
async def request_body_stream():
async for chunk in request.stream():
yield chunk

response = await client.request(
method,
s3_url,
content=request_body_stream(), # Stream the request body to S3
headers=essential_headers, # Forward essential headers
timeout=None, # Remove timeout for large file uploads
)
else:
response = await client.request(
method,
s3_url,
headers=essential_headers, # Forward essential headers
timeout=None,
)

# Return the response, stream data for GET requests
if method == "GET":
return StreamingResponse(
response.iter_bytes(), # Async iterator of response body chunks
status_code=response.status_code,
headers={
k: v
for k, v in response.headers.items()
if k.lower()
not in ["content-encoding", "transfer-encoding"]
},
)

elif method in ["POST", "PATCH", "PUT", "DELETE"]:
return Response(
content=response.content, # Raw response content
status_code=response.status_code,
headers=response.headers, # Pass raw headers from the response
)

elif method == "HEAD":
return Response(
status_code=response.status_code,
headers=response.headers, # No content for HEAD, just forward headers
)

else:
return Response(
status_code=405, content="Method Not Allowed"
)
class S3ProxyConfig(BaseURLProxyConfigMixin, ProxyConfig):
# Set your S3 root endpoint
upstream_base_url = self.endpoint_url
rewrite_host_header = (
self.endpoint_url.replace("https://", "")
.replace("http://", "")
.split("/")[0]
)

except httpx.HTTPStatusError as exc:
raise HTTPException(
status_code=exc.response.status_code,
detail=f"Error while proxying to S3: {exc.response.text}",
)
except Exception as exc:
raise HTTPException(
status_code=500,
detail=f"Internal server error: {str(exc)}",
)
config = S3ProxyConfig()
context = ProxyContext(config=config)
s3_app = make_simple_proxy_app(context, proxy_websocket_handler=None)
self.store.mount_app("/s3", s3_app, "s3-proxy")

@router.get("/{workspace}/files/{path:path}")
@router.delete("/{workspace}/files/{path:path}")
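The hand-rolled httpx forwarding route above is replaced by asgiproxy: a `ProxyConfig` subclass points `upstream_base_url` at the S3 endpoint and rewrites the Host header, and the resulting ASGI app is mounted under `/s3`. A minimal standalone sketch of the same pattern, assuming a placeholder MinIO endpoint and a plain FastAPI mount (hypha itself mounts the app through `store.mount_app`):

# Standalone sketch of the asgiproxy setup used above (endpoint URL is a placeholder).
from fastapi import FastAPI
from asgiproxy.config import BaseURLProxyConfigMixin, ProxyConfig
from asgiproxy.context import ProxyContext
from asgiproxy.simple_proxy import make_simple_proxy_app

S3_ENDPOINT = "http://127.0.0.1:9000"  # assumed internal MinIO/S3 endpoint


class S3ProxyConfig(BaseURLProxyConfigMixin, ProxyConfig):
    upstream_base_url = S3_ENDPOINT  # where /s3/* requests are forwarded
    # Host header should match the upstream host so presigned-URL signatures stay valid
    rewrite_host_header = (
        S3_ENDPOINT.replace("https://", "").replace("http://", "").split("/")[0]
    )


proxy_app = make_simple_proxy_app(
    ProxyContext(config=S3ProxyConfig()), proxy_websocket_handler=None
)

app = FastAPI()
app.mount("/s3", proxy_app)  # e.g. GET /s3/<bucket>/<key>?<presigned-params> is proxied to S3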
107 changes: 106 additions & 1 deletion hypha/utils/__init__.py
@@ -8,9 +8,10 @@
import string
import time
from datetime import datetime
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Dict
import shortuuid
import friendlywords as fw
import logging

import httpx
from fastapi.routing import APIRoute
@@ -488,3 +489,107 @@ async def _example_hypha_startup(server):
"test": lambda x: x + 22,
}
)


async def chunked_transfer_remote_file(
source_url: str,
target_url: str,
source_params: Optional[Dict] = None,
target_params: Optional[Dict] = None,
source_headers: Optional[Dict] = None,
target_headers: Optional[Dict] = None,
logger: Optional[logging.Logger] = None,
chunk_size: int = 2048,
timeout: int = 60,
):
"""
Transfers a file from a remote source to a remote target without fully buffering in memory.
Args:
source_url (str): The URL of the source file.
target_url (str): The URL where the file will be uploaded.
source_params (Dict, optional): Additional query parameters for the source request.
target_params (Dict, optional): Additional query parameters for the target request.
source_headers (Dict, optional): Headers for the source request.
target_headers (Dict, optional): Headers for the target request.
logger (logging.Logger, optional): Logger for logging messages.
chunk_size (int, optional): Size of chunks for streaming. Default is 2048 bytes.
timeout (int, optional): Timeout for the HTTP requests in seconds. Default is 60.
"""
source_params = source_params or {}
target_params = target_params or {}
source_headers = source_headers or {}
target_headers = target_headers or {}
logger = logger or logging.getLogger(__name__)

async with httpx.AsyncClient(
headers={"Connection": "close"}, timeout=timeout
) as client:
# Step 1: Get file size from the source
range_headers = {"Range": "bytes=0-0"}
range_headers.update(source_headers)

logger.info(f"Fetching file size from {source_url}...")
range_response = await client.get(
source_url, headers=range_headers, params=source_params
)
if range_response.status_code not in [200, 206]:
logger.error(
f"Failed to fetch file size. Status code: {range_response.status_code}"
)
logger.error(f"Response: {range_response.text}")
return

content_range = range_response.headers.get("Content-Range")
if not content_range or "/" not in content_range:
logger.error("Content-Range header is missing or invalid.")
return

file_size = int(content_range.split("/")[-1])
logger.info(f"File size: {file_size} bytes")

# Step 2: Stream the file from source
async def s3_response_chunk_reader(response):
async for chunk in response.aiter_bytes(chunk_size):
yield chunk

async with client.stream(
"GET", source_url, headers=source_headers, params=source_params
) as response:
if response.status_code != 200:
info = await response.aread()
info = info.decode("utf-8") if info else ""
logger.error(
f"Failed to download the file. Status code: {response.status_code}, {info}"
)
raise Exception(
f"Failed to download the file. Status code: {response.status_code}, {info}"
)
# Step 3: Upload to target with content-length
target_headers.update(
{
"Content-Type": "application/octet-stream",
"Content-Length": str(file_size),
}
)

async def upload_generator():
async for chunk in s3_response_chunk_reader(response):
yield chunk

logger.info(f"Uploading file to {target_url}...")
put_response = await client.put(
target_url,
headers=target_headers,
data=upload_generator(),
params=target_params,
)

if put_response.status_code >= 400:
logger.error(
f"Failed to upload the file. Status code: {put_response.status_code}, {put_response.text}"
)
raise Exception(
f"Failed to upload the file. Status code: {put_response.status_code}, {put_response.text}"
)
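The new `chunked_transfer_remote_file` helper first issues a 1-byte range request to learn the file size from the `Content-Range` header, then streams the body from the source URL to the target URL in chunks so the upload can carry an explicit `Content-Length` without buffering the whole file. A usage sketch with placeholder URLs (for example, a presigned S3 GET URL and a remote PUT endpoint):

import asyncio
import logging

from hypha.utils import chunked_transfer_remote_file


async def main():
    # Placeholder URLs: any HTTP source that supports Range requests, and any PUT target.
    await chunked_transfer_remote_file(
        source_url="https://example.com/presigned-get/model.zip",
        target_url="https://example.com/upload/model.zip",
        target_headers={"Authorization": "Bearer <token>"},  # hypothetical auth header
        chunk_size=1024 * 1024,  # 1 MiB chunks instead of the 2 KiB default
        timeout=120,
        logger=logging.getLogger("transfer"),
    )


asyncio.run(main())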
24 changes: 6 additions & 18 deletions hypha/utils/zenodo.py
@@ -3,7 +3,7 @@
from typing import Dict, Any, List, Optional
import aiofiles
from pathlib import PurePosixPath

from hypha.utils import chunked_transfer_remote_file

ZENODO_TIMEOUT = 30 # seconds

@@ -100,24 +100,12 @@ async def file_chunk_reader(file_path: str, chunk_size: int = 1024):
data=file_chunk_reader(file_path),
)

async def import_file(self, deposition_info, name, target_url):
async def import_file(self, deposition_info, name, source_url):
bucket_url = deposition_info["links"]["bucket"]
async with httpx.AsyncClient(
headers={"Connection": "close"}, timeout=ZENODO_TIMEOUT
) as client:
async with client.stream("GET", target_url) as response:

async def s3_response_chunk_reader(response, chunk_size: int = 2048):
async for chunk in response.aiter_bytes(chunk_size):
yield chunk

put_response = await self.client.put(
f"{bucket_url}/{name}",
params=self.params,
headers={"Content-Type": "application/octet-stream"},
data=s3_response_chunk_reader(response),
)
put_response.raise_for_status()
target_url = f"{bucket_url}/{name}"
await chunked_transfer_remote_file(
source_url, target_url, target_params=self.params
)

async def delete_deposition(self, deposition_id: str) -> None:
"""Deletes a deposition. Only unpublished depositions can be deleted."""
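`import_file` now delegates to that helper: it targets `{bucket_url}/{name}` on the deposition's bucket and forwards the Zenodo access-token query parameters via `target_params`, instead of re-implementing the streaming copy inline. A rough call-site sketch (how the client and deposition were created is outside this diff, and the URLs are placeholders):

# Hypothetical call site; zenodo_client and deposition_info come from an earlier
# deposition-creation call (deposition_info["links"]["bucket"] must be present).
presigned_source = "https://minio.example.com/workspace/datasets/model.zip?signature=..."  # placeholder

# Streams the remote object into the deposition bucket as "model.zip".
await zenodo_client.import_file(deposition_info, "model.zip", presigned_source)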
1 change: 1 addition & 0 deletions requirements.txt
@@ -37,3 +37,4 @@ hrid==0.2.4
qdrant-client==1.12.1
ollama==0.3.3
fastembed==0.4.2
asgiproxy==0.1.1
1 change: 1 addition & 0 deletions setup.py
@@ -43,6 +43,7 @@
"sqlmodel>=0.0.22",
"alembic>=1.14.0",
"hrid>=0.2.4",
"asgiproxy>=0.1.1",
]

ROOT_DIR = Path(__file__).parent.resolve()
2 changes: 1 addition & 1 deletion tests/test_artifact.py
@@ -715,7 +715,7 @@ async def test_publish_artifact(minio_server, fastapi_server, test_user_token):
)
source = "file contents of example.txt"
response = requests.put(put_url, data=source)
assert response.ok
assert response.ok, response.text

# Commit the artifact after adding the file
await artifact_manager.commit(artifact_id=dataset.id)
