Skip to content

Commit

Permalink
Support custom filename to be provided to URLFile
Browse files Browse the repository at this point in the history
This commit works around an issue where the basename of the URL many not
actually contain a file extension and the uploader logic cannot infer
the mime type for the file.

We stash the name when pickling and extract it again when unpickling.
The __getattr__ function then supports returning the underlying name
value rather than proxying to the underlying request object.

I also ran into a small bug whereby the __del__ method was triggering
a network request because of some private attributes being accessed
during teardown would trigger the __wrapper__ code. I've overridden
the super class to disable this. Though I'm unclear if this is just the
test suite doing this cleanup.
  • Loading branch information
aron committed Oct 18, 2024
1 parent 1aa30df commit 4138b82
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 8 deletions.
36 changes: 32 additions & 4 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ class URLFile(io.IOBase):
URL that can survive pickling/unpickling.
"""

__slots__ = ("__target__", "__url__")
__slots__ = ("__target__", "__url__", "name")

def __init__(self, url: str) -> None:
def __init__(self, url: str, filename: Optional[str] = None) -> None:
parsed = urllib.parse.urlparse(url)
if parsed.scheme not in {
"http",
Expand All @@ -296,15 +296,41 @@ def __init__(self, url: str) -> None:
raise ValueError(
"URLFile requires URL to conform to HTTP or HTTPS protocol"
)
object.__setattr__(self, "name", os.path.basename(parsed.path))

if parsed.scheme not in {
"http",
"https",
}:
raise ValueError(
"URLFile requires URL to conform to HTTP or HTTPS protocol"
)

if not filename:
filename = os.path.basename(parsed.path)

object.__setattr__(self, "name", filename)
object.__setattr__(self, "__url__", url)

def __del__(self) -> None:
try:
object.__getattribute__(self, "__target__")
except AttributeError:
# Do nothing when tearing down the object if the response object
# hasn't been created yet.
return

super().__del__()

# We provide __getstate__ and __setstate__ explicitly to ensure that the
# object is always picklable.
def __getstate__(self) -> Dict[str, Any]:
return {"url": object.__getattribute__(self, "__url__")}
return {
"name": object.__getattribute__(self, "name"),
"url": object.__getattribute__(self, "__url__"),
}

def __setstate__(self, state: Dict[str, Any]) -> None:
object.__setattr__(self, "name", state["name"])
object.__setattr__(self, "__url__", state["url"])

# Proxy getattr/setattr/delattr through to the response object.
Expand All @@ -317,6 +343,8 @@ def __setattr__(self, name: str, value: Any) -> None:
def __getattr__(self, name: str) -> Any:
if name in ("__target__", "__wrapped__", "__url__"):
raise AttributeError(name)
elif name == "name":
return object.__getattribute__(self, "name")
return getattr(self.__wrapped__, name)

def __delattr__(self, name: str) -> None:
Expand Down
16 changes: 13 additions & 3 deletions python/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def test_urlfile_protocol_validation():
URLFile("data:text/plain,hello")


def test_urlfile_custom_filename():
u = URLFile("https://example.com/some-path", filename="my_file.txt")
assert u.name == "my_file.txt"


@responses.activate
def test_urlfile_acts_like_response():
responses.get(
Expand Down Expand Up @@ -61,18 +66,23 @@ def test_urlfile_can_be_pickled():

@responses.activate
def test_urlfile_can_be_pickled_even_once_loaded():
responses.get(
mock = responses.get(
"https://example.com/some/url",
json={"message": "hello world"},
status=200,
)

u = URLFile("https://example.com/some/url")
u.read()
u = URLFile("https://example.com/some/url", filename="my_file.txt")
assert u.name == "my_file.txt"
assert u.read() == b'{"message": "hello world"}'

result = pickle.loads(pickle.dumps(u))

assert isinstance(result, URLFile)
assert result.name == "my_file.txt"
assert result.read() == b'{"message": "hello world"}'

assert mock.call_count == 2


@pytest.mark.parametrize(
Expand Down
1 change: 0 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ source =
.tox/*/lib/python*/site-packages/cog

[pytest]
addopts = --timeout=20

[testenv]
package = wheel
Expand Down

0 comments on commit 4138b82

Please sign in to comment.