add an OutputURL type with a streaming download/upload #2021

Open
wants to merge 2 commits into
base: async
26 changes: 18 additions & 8 deletions python/cog/server/clients.py
@@ -11,6 +11,7 @@
Dict,
Mapping,
Optional,
+Union,
cast,
)
from urllib.parse import urlparse
@@ -21,7 +22,7 @@

from .. import types
from ..schema import PredictionResponse, Status, WebhookEvent
-from ..types import Path
+from ..types import OutputURL, Path
from .eventtypes import PredictionInput
from .response_throttler import ResponseThrottler
from .retry_transport import RetryTransport
@@ -199,7 +200,11 @@ async def sender(response: PredictionResponse, event: WebhookEvent) -> None:
# files

async def upload_file(
-self, fh: io.IOBase, *, url: Optional[str], prediction_id: Optional[str]
+self,
+fh: Union[io.IOBase, OutputURL],
+*,
+url: Optional[str],
+prediction_id: Optional[str],
) -> str:
"""put file to signed endpoint"""
log.debug("upload_file")
@@ -213,7 +218,7 @@ async def upload_file(

# this code path happens when running outside replicate without upload-url
# in that case we need to return data uris
-if url is None:
+if url is None and not isinstance(fh, OutputURL):
return file_to_data_uri(fh, content_type)
technillogue marked this conversation as resolved.
assert url
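
For readers unfamiliar with the data-URI fallback in the `url is None` branch above, here is a minimal sketch of what such an encoder might look like. The name `to_data_uri` and its default content type are assumptions for illustration; this is not the project's `file_to_data_uri`, whose implementation is not shown in this diff.

```python
import base64
import io
from typing import BinaryIO


def to_data_uri(fh: BinaryIO, content_type: str = "application/octet-stream") -> str:
    """Hypothetical helper: inline a file's bytes as a base64 data URI."""
    # Reads the whole file into memory, so this only suits small outputs.
    payload = base64.b64encode(fh.read()).decode("ascii")
    return f"data:{content_type};base64,{payload}"


uri = to_data_uri(io.BytesIO(b"hello"), "text/plain")
print(uri)
```

An `OutputURL` has no local bytes to encode, which is presumably why the new condition skips this branch for it.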

@@ -243,11 +248,14 @@ async def upload_file(
url = resp1.headers["Location"]

log.info("doing real upload to %s", url)
-resp = await self.file_client.put(
-    url,
-    content=ChunkFileReader(fh),
-    headers=headers,
-)
+if isinstance(fh, OutputURL):
+    async with self.file_client.stream("GET", fh.url) as resp:
+        content = resp.aiter_bytes()
+        resp = await self.file_client.put(url, content=content, headers=headers)
+else:
+    resp = await self.file_client.put(
+        url, content=ChunkFileReader(fh), headers=headers
+    )
Comment on lines +253 to +260
Contributor

Thanks for pulling this together. I'm trying to understand this, as I'm not very familiar with how async/await works in Python. I understand that by going with the httpx async client plus an iterator we get a non-blocking implementation, but how does the existing ChunkFileReader accomplish the same thing?

If it doesn't, and we still need to solve that issue, we should probably figure out an interface that works for both.

Contributor Author

The PUT method takes either bytes or an async iterator that yields bytes. aiter_bytes qualifies, as does ChunkFileReader, which is implemented earlier in the same file. ChunkFileReader does perform blocking disk reads, but reading 1 MB at a time is likely quick enough that we can do all the other networking we need in between.

If you wanted to be fancy, you could have a single FileIterator that takes a local or remote URI, but that's awkward to do while holding the context manager open for the download request, and this approach is simpler.
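
To illustrate the point about chunked reads, here is a minimal sketch of an async iterator that yields a file in 1 MB pieces. It is not the project's ChunkFileReader (which is not shown in this diff); `iter_file_chunks` and `collect` are hypothetical names used only for this example.

```python
import asyncio
import io
from typing import AsyncIterator, BinaryIO, List

CHUNK_SIZE = 1024 * 1024  # 1 MB, the chunk size mentioned in the comment above


async def iter_file_chunks(
    fh: BinaryIO, chunk_size: int = CHUNK_SIZE
) -> AsyncIterator[bytes]:
    """Hypothetical sketch: yield a file's contents as bounded chunks."""
    while True:
        # This read blocks the event loop, but only for one chunk at a time.
        chunk = fh.read(chunk_size)
        if not chunk:
            break
        # Control returns to the consumer here; an HTTP client would await
        # its network write before asking for the next chunk.
        yield chunk


async def collect(fh: BinaryIO, chunk_size: int = CHUNK_SIZE) -> List[bytes]:
    """Drain the iterator, standing in for an HTTP client consuming it."""
    return [chunk async for chunk in iter_file_chunks(fh, chunk_size)]


data = b"x" * (2 * CHUNK_SIZE + 5)
chunks = asyncio.run(collect(io.BytesIO(data)))
print([len(c) for c in chunks])
```

Because the consumer awaits between chunks, other tasks on the event loop get a chance to run after each read, which is the trade-off the comment describes.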

# TODO: if file size is >1MB, show upload throughput
resp.raise_for_status()

@@ -292,6 +300,8 @@ async def upload_files(
if isinstance(obj, Path):
with obj.open("rb") as f:
return await self.upload_file(f, url=url, prediction_id=prediction_id)
+if isinstance(obj, OutputURL):
+    return await self.upload_file(obj, url=url, prediction_id=prediction_id)
if isinstance(obj, io.IOBase):
with obj:
return await self.upload_file(obj, url=url, prediction_id=prediction_id)
14 changes: 14 additions & 0 deletions python/cog/types.py
@@ -194,6 +194,20 @@ def unlink(self, missing_ok: bool = False) -> None:
raise


+class OutputURL:
+    def __init__(self, url: str, filename: Optional[str] = None) -> None:
+        parsed = urllib.parse.urlparse(url)
+        if parsed.scheme not in {"http", "https"}:
+            raise ValueError(
+                "OutputURL requires URL to conform to HTTP or HTTPS protocol"
+            )
+
+        if not filename:
+            filename = os.path.basename(parsed.path)
+        self.name = filename
+        self.url = url

class URLFile(io.IOBase):
"""
URLFile is a proxy object for a :class:`urllib3.response.HTTPResponse`
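
As a rough illustration of how the OutputURL class added in this diff behaves, the sketch below copies the class into a standalone script and exercises it. The example URLs are placeholders, not values from the PR.

```python
import os
import urllib.parse
from typing import Optional


class OutputURL:
    """Standalone copy of the class from the diff, for illustration only."""

    def __init__(self, url: str, filename: Optional[str] = None) -> None:
        parsed = urllib.parse.urlparse(url)
        if parsed.scheme not in {"http", "https"}:
            raise ValueError(
                "OutputURL requires URL to conform to HTTP or HTTPS protocol"
            )

        if not filename:
            # Derived from the path only, so query strings are ignored.
            filename = os.path.basename(parsed.path)
        self.name = filename
        self.url = url


# Placeholder URL; the filename falls back to the last path segment.
out = OutputURL("https://example.com/outputs/image.png?sig=abc")
print(out.name)

# Non-HTTP schemes are rejected at construction time.
try:
    OutputURL("ftp://example.com/file.bin")
except ValueError as exc:
    print("rejected:", exc)
```

Note that a URL whose path ends in a slash yields an empty `name` unless `filename` is passed explicitly.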