diff --git a/python/cog/types.py b/python/cog/types.py index 83a3374297..29d868c9e7 100644 --- a/python/cog/types.py +++ b/python/cog/types.py @@ -285,10 +285,9 @@ class URLFile(io.IOBase): URL that can survive pickling/unpickling. """ - __slots__ = ("__target__", "__url__") + __slots__ = ("__target__", "__url__", "name") - def __init__(self, url: str) -> None: - object.__setattr__(self, "__url__", url) + def __init__(self, url: str, filename: Optional[str] = None) -> None: parsed = urllib.parse.urlparse(url) if parsed.scheme not in { "http", @@ -298,13 +297,42 @@ def __init__(self, url: str) -> None: "URLFile requires URL to conform to HTTP or HTTPS protocol" ) object.__setattr__(self, "name", os.path.basename(parsed.path)) + object.__setattr__(self, "__url__", url) + + if parsed.scheme not in { + "http", + "https", + }: + raise ValueError( + "URLFile requires URL to conform to HTTP or HTTPS protocol" + ) + + if not filename: + filename = os.path.basename(parsed.path) + + object.__setattr__(self, "name", filename) + object.__setattr__(self, "__url__", url) + + def __del__(self) -> None: + try: + object.__getattribute__(self, "__target__") + except AttributeError: + # Do nothing when tearing down the object if the response object + # hasn't been created yet. + return + + super().__del__() # We provide __getstate__ and __setstate__ explicitly to ensure that the # object is always picklable. def __getstate__(self) -> Dict[str, Any]: - return {"url": object.__getattribute__(self, "__url__")} + return { + "name": object.__getattribute__(self, "name"), + "url": object.__getattribute__(self, "__url__"), + } def __setstate__(self, state: Dict[str, Any]) -> None: + object.__setattr__(self, "name", state["name"]) object.__setattr__(self, "__url__", state["url"]) # Proxy getattr/setattr/delattr through to the response object. @@ -317,6 +345,8 @@ def __setattr__(self, name: str, value: Any) -> None: def __getattr__(self, name: str) -> Any: if name in ("__target__", "__wrapped__", "__url__"): raise AttributeError(name) + elif name == "name": + return object.__getattribute__(self, "name") return getattr(self.__wrapped__, name) def __delattr__(self, name: str) -> None: diff --git a/python/tests/test_types.py b/python/tests/test_types.py index 95f7b237f8..7bc463a858 100644 --- a/python/tests/test_types.py +++ b/python/tests/test_types.py @@ -15,6 +15,11 @@ def test_urlfile_protocol_validation(): URLFile("data:text/plain,hello") +def test_urlfile_custom_filename(): + u = URLFile("https://example.com/some-path", filename="my_file.txt") + assert u.name == "my_file.txt" + + @responses.activate def test_urlfile_acts_like_response(): responses.get( @@ -61,18 +66,23 @@ def test_urlfile_can_be_pickled(): @responses.activate def test_urlfile_can_be_pickled_even_once_loaded(): - responses.get( + mock = responses.get( "https://example.com/some/url", json={"message": "hello world"}, status=200, ) - u = URLFile("https://example.com/some/url") - u.read() + u = URLFile("https://example.com/some/url", filename="my_file.txt") + assert u.name == "my_file.txt" + assert u.read() == b'{"message": "hello world"}' result = pickle.loads(pickle.dumps(u)) assert isinstance(result, URLFile) + assert result.name == "my_file.txt" + assert result.read() == b'{"message": "hello world"}' + + assert mock.call_count == 2 @pytest.mark.parametrize(