From 0bd49a950369d6e059da7960dbd1a7feeb562cbb Mon Sep 17 00:00:00 2001
From: technillogue
Date: Wed, 23 Oct 2024 22:36:04 -0400
Subject: [PATCH] add an OutputURL type with a streaming download/upload

---
 python/cog/server/clients.py | 26 ++++++++++++++++++--------
 python/cog/types.py          | 24 ++++++++++++++++++++++++
 2 files changed, 42 insertions(+), 8 deletions(-)

diff --git a/python/cog/server/clients.py b/python/cog/server/clients.py
index ae30f019e7..5187185c9a 100644
--- a/python/cog/server/clients.py
+++ b/python/cog/server/clients.py
@@ -11,6 +11,7 @@
     Dict,
     Mapping,
     Optional,
+    Union,
     cast,
 )
 from urllib.parse import urlparse
@@ -21,7 +22,7 @@
 
 from .. import types
 from ..schema import PredictionResponse, Status, WebhookEvent
-from ..types import Path
+from ..types import OutputURL, Path
 from .eventtypes import PredictionInput
 from .response_throttler import ResponseThrottler
 from .retry_transport import RetryTransport
@@ -199,7 +200,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
     # files
 
     async def upload_file(
-        self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
+        self,
+        fh: Union[io.IOBase, OutputURL],
+        *,
+        url: Optional[str],
+        prediction_id: Optional[str],
     ) -> str:
         """put file to signed endpoint"""
         log.debug("upload_file")
@@ -213,7 +218,7 @@ async def upload_file(
 
         # this code path happens when running outside replicate without upload-url
         # in that case we need to return data uris
-        if url is None:
+        if url is None and not isinstance(fh, OutputURL):
             return file_to_data_uri(fh, content_type)
         assert url
 
@@ -243,11 +248,14 @@ async def upload_file(
             url = resp1.headers["Location"]
 
         log.info("doing real upload to %s", url)
-        resp = await self.file_client.put(
-            url,
-            content=ChunkFileReader(fh),
-            headers=headers,
-        )
+        if isinstance(fh, OutputURL):
+            async with self.file_client.stream("GET", fh.url) as resp:
+                content = resp.aiter_bytes()
+                resp = await self.file_client.put(url, content=content, headers=headers)
+        else:
+            resp = await self.file_client.put(
+                url, content=ChunkFileReader(fh), headers=headers
+            )
         # TODO: if file size is >1MB, show upload throughput
         resp.raise_for_status()
 
@@ -292,6 +300,8 @@ async def upload_files(
         if isinstance(obj, Path):
             with obj.open("rb") as f:
                 return await self.upload_file(f, url=url, prediction_id=prediction_id)
+        if isinstance(obj, OutputURL):
+            return await self.upload_file(obj, url=url, prediction_id=prediction_id)
         if isinstance(obj, io.IOBase):
             with obj:
                 return await self.upload_file(obj, url=url, prediction_id=prediction_id)
diff --git a/python/cog/types.py b/python/cog/types.py
index 75a3d451b0..8597387cea 100644
--- a/python/cog/types.py
+++ b/python/cog/types.py
@@ -194,6 +194,30 @@ def unlink(self, missing_ok: bool = False) -> None:
                     raise
 
 
+class OutputURL:
+    """
+    OutputURL points at a remote file by its HTTP(S) URL.
+
+    ``name`` falls back to the last path segment of the URL when no
+    filename is passed. Raises ValueError for any other URL scheme.
+    """
+
+    def __init__(self, url: str, filename: Optional[str] = None) -> None:
+        parsed = urllib.parse.urlparse(url)
+        if parsed.scheme not in {
+            "http",
+            "https",
+        }:
+            raise ValueError(
+                "OutputURL requires URL to conform to HTTP or HTTPS protocol"
+            )
+
+        if not filename:
+            filename = os.path.basename(parsed.path)
+        self.name = filename
+        self.url = url
+
+
 class URLFile(io.IOBase):
     """
     URLFile is a proxy object for a :class:`urllib3.response.HTTPResponse`