Skip to content

Commit

Permalink
add an OutputURL type with a streaming download/upload
Browse files Browse the repository at this point in the history
  • Loading branch information
technillogue committed Oct 24, 2024
1 parent 426b112 commit 0bd49a9
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
26 changes: 18 additions & 8 deletions python/cog/server/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Dict,
Mapping,
Optional,
Union,
cast,
)
from urllib.parse import urlparse
Expand All @@ -21,7 +22,7 @@

from .. import types
from ..schema import PredictionResponse, Status, WebhookEvent
from ..types import Path
from ..types import OutputURL, Path
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
from .retry_transport import RetryTransport
Expand Down Expand Up @@ -199,7 +200,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
# files

async def upload_file(
self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
self,
fh: Union[io.IOBase, OutputURL],
*,
url: Optional[str],
prediction_id: Optional[str],
) -> str:
"""put file to signed endpoint"""
log.debug("upload_file")
Expand All @@ -213,7 +218,7 @@ async def upload_file(

# this code path happens when running outside replicate without upload-url
# in that case we need to return data uris
if url is None:
if url is None and not isinstance(fh, OutputURL):
return file_to_data_uri(fh, content_type)
assert url

Expand Down Expand Up @@ -243,11 +248,14 @@ async def upload_file(
url = resp1.headers["Location"]

log.info("doing real upload to %s", url)
resp = await self.file_client.put(
url,
content=ChunkFileReader(fh),
headers=headers,
)
if isinstance(fh, OutputURL):
async with self.file_client.stream("GET", fh.url) as resp:
content = resp.aiter_bytes()
resp = await self.file_client.put(url, content=content, headers=headers)
else:
resp = await self.file_client.put(
url, content=ChunkFileReader(fh), headers=headers
)
# TODO: if file size is >1MB, show upload throughput
resp.raise_for_status()

Expand Down Expand Up @@ -292,6 +300,8 @@ async def upload_files(
if isinstance(obj, Path):
with obj.open("rb") as f:
return await self.upload_file(f, url=url, prediction_id=prediction_id)
if isinstance(obj, OutputURL):
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
with obj:
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
Expand Down
17 changes: 17 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,23 @@ def unlink(self, missing_ok: bool = False) -> None:
raise


class OutputURL:
    """Reference to an output file that lives at a remote HTTP(S) URL.

    The constructor only validates the URL scheme and records the URL and a
    display name; it performs no network I/O itself.

    Attributes:
        name: The explicit ``filename`` if one was given, otherwise the last
            path component of ``url`` (may be an empty string when the URL
            path ends in ``/`` or is empty).
        url: The original URL string, stored unmodified.

    Raises:
        ValueError: If the URL scheme is anything other than ``http`` or
            ``https``.
    """

    def __init__(self, url: str, filename: Optional[str] = None) -> None:
        parsed = urllib.parse.urlparse(url)
        if parsed.scheme not in {
            "http",
            "https",
        }:
            # Name this class in the message (it previously said "URLFile",
            # which pointed at the wrong type when debugging).
            raise ValueError(
                "OutputURL requires URL to conform to HTTP or HTTPS protocol"
            )

        # Fall back to the final path segment when no (or an empty) filename
        # is supplied.
        if not filename:
            filename = os.path.basename(parsed.path)
        self.name = filename
        self.url = url


class URLFile(io.IOBase):
"""
URLFile is a proxy object for a :class:`urllib3.response.HTTPResponse`
Expand Down

0 comments on commit 0bd49a9

Please sign in to comment.