Skip to content

Commit

Permalink
Merge pull request #47 from scalableminds/fix-pickle
Browse files Browse the repository at this point in the history
Fix serialization of UPath
  • Loading branch information
andrewfulton9 authored Mar 18, 2022
2 parents 747cc18 + 2cfc277 commit df97e1e
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 12 deletions.
2 changes: 1 addition & 1 deletion noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def install(session):
@nox.session(python=False)
def smoke(session):
session.install(*"pytest aiohttp requests gcsfs".split())
session.run(*"pytest --skiphdfs upath".split())
session.run(*"pytest --skiphdfs -vv upath".split())


@nox.session(python=False)
Expand Down
39 changes: 39 additions & 0 deletions upath/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,3 +325,42 @@ def _from_parsed_parts(self, drv, root, parts, init=True):
if init:
obj._init(**self._kwargs)
return obj

def __truediv__(self, key):
    """Return a new path of the same class with ``key`` appended.

    The child is built by parsing ``key``, joining it onto this path's
    drive/root/parts, and constructing a fresh instance from the formatted
    result. The stored keyword arguments (minus ``_url``) are forwarded to
    the constructor.
    """
    # Add `/` root if not present
    if not self._parts:
        key = f"{self._root}{key}"

    # Adapted from `PurePath._make_child`
    drv, root, parts = self._parse_args((key,))
    drv, root, parts = self._flavour.join_parsed_parts(
        self._drv, self._root, self._parts, drv, root, parts
    )

    # `_url` must not be forwarded to the constructor. Pop with a default
    # so a missing key does not raise `KeyError` (same as `__reduce__`).
    kwargs = self._kwargs.copy()
    kwargs.pop("_url", None)

    # Create a new object
    return self.__class__(
        self._format_parsed_parts(drv, root, parts),
        **kwargs,
    )

def __setstate__(self, state):
    """Restore ``_kwargs`` from the pickled ``state`` and re-run ``_init``.

    ``state`` is expected to carry a ``"_kwargs"`` mapping. It is copied,
    and ``_url`` is filled in from the value already on this instance.
    """
    self._kwargs = {**state["_kwargs"], "_url": self._url}
    # `_init` has to run a second time: when `__new__` invoked it, the
    # `_kwargs` had not been assigned yet.
    self._init()

def __reduce__(self):
    """Describe how to pickle this path.

    Returns a ``(callable, args, state)`` tuple: the class itself, the
    formatted path string as the single constructor argument, and a state
    dict holding the keyword arguments without ``_url``.
    """
    state_kwargs = {
        name: value for name, value in self._kwargs.items() if name != "_url"
    }
    path_str = self._format_parsed_parts(self._drv, self._root, self._parts)
    return self.__class__, (path_str,), {"_kwargs": state_kwargs}
47 changes: 39 additions & 8 deletions upath/tests/cases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import pickle
import sys
from pathlib import Path

import pytest

from upath import UPath


Expand Down Expand Up @@ -37,15 +38,14 @@ def test_glob(self, pathlib_base):
mock_glob = list(self.path.glob("**.txt"))
path_glob = list(pathlib_base.glob("**/*.txt"))

assert len(mock_glob) == len(path_glob)
assert all(
map(
lambda m: m.path
in [str(p).replace("\\", "/") for p in path_glob],
mock_glob,
)
root = "/" if sys.platform.startswith("win") else ""
mock_glob_normalized = sorted([a.path for a in mock_glob])
path_glob_normalized = sorted(
[f"{root}{a}".replace("\\", "/") for a in path_glob]
)

assert mock_glob_normalized == path_glob_normalized

def test_group(self):
with pytest.raises(NotImplementedError):
self.path.group()
Expand Down Expand Up @@ -228,3 +228,34 @@ def test_fsspec_compat(self):
upath2 = UPath(p2)
assert upath2.read_bytes() == content
upath2.unlink()

def test_pickling(self):
    """The fixture path must survive a pickle round trip."""
    path = self.path
    recovered_path = pickle.loads(pickle.dumps(path))

    # Compare classes by identity rather than equality.
    assert type(path) is type(recovered_path)
    assert str(path) == str(recovered_path)
    assert path.fs.storage_options == recovered_path.fs.storage_options

def test_pickling_child_path(self):
    """A path built with ``/`` must survive a pickle round trip."""
    path = self.path / "subfolder" / "subsubfolder"
    recovered_path = pickle.loads(pickle.dumps(path))

    # Compare classes by identity rather than equality.
    assert type(path) is type(recovered_path)
    assert str(path) == str(recovered_path)
    assert path._drv == recovered_path._drv
    assert path._root == recovered_path._root
    assert path._parts == recovered_path._parts
    assert path.fs.storage_options == recovered_path.fs.storage_options

def test_child_path(self):
    """Joining with ``/`` must match constructing the child path directly."""
    constructed = UPath(f"{self.path}/folder")
    joined = self.path / "folder"

    assert str(constructed) == str(joined)
    for attr in ("_root", "_drv", "_parts", "_url"):
        assert getattr(constructed, attr) == getattr(joined, attr)
assert path_a._url == path_b._url
45 changes: 42 additions & 3 deletions upath/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import sys
import pathlib
import pickle
import sys
import warnings

import pytest

from upath import UPath
from upath.implementations.s3 import S3Path
from upath.tests.cases import BaseTests
Expand Down Expand Up @@ -38,7 +38,12 @@ class TestUpath(BaseTests):
def path(self, local_testdir):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
self.path = UPath(f"mock:{local_testdir}")

# On Windows the path needs to be prefixed with `/`, because
# `UPath` implements `_posix_flavour`, which requires a `/` root
# in order to correctly deserialize pickled objects
root = "/" if sys.platform.startswith("win") else ""
self.path = UPath(f"mock:{root}{local_testdir}")

def test_fsspec_compat(self):
    """Intentionally a no-op for this test class.

    NOTE(review): presumably overrides the test of the same name inherited
    from ``BaseTests`` so that it is not run here; confirm the reason.
    """
Expand Down Expand Up @@ -137,3 +142,37 @@ def test_create_from_type(path, storage_options, module, object_type):
except (ImportError, ModuleNotFoundError):
# fs failed to import
pass


def test_child_path():
    """Joining with ``/`` must match constructing the child path directly."""
    constructed = UPath("gcs://bucket/folder")
    joined = UPath("gcs://bucket") / "folder"

    assert str(constructed) == str(joined)
    for attr in ("_root", "_drv", "_parts", "_url"):
        assert getattr(constructed, attr) == getattr(joined, attr)


def test_pickling():
    """A ``gcs://`` path must survive a pickle round trip."""
    # NOTE(review): the sibling child-path test passes `anon=True` directly,
    # while this one nests it under `storage_options`. Confirm which form
    # `UPath` expects.
    path = UPath("gcs://bucket/folder", storage_options={"anon": True})
    recovered_path = pickle.loads(pickle.dumps(path))

    # Compare classes by identity rather than equality.
    assert type(path) is type(recovered_path)
    assert str(path) == str(recovered_path)
    assert path.fs.storage_options == recovered_path.fs.storage_options


def test_pickling_child_path():
    """A ``gcs://`` path built with ``/`` must survive a pickle round trip."""
    path = UPath("gcs://bucket", anon=True) / "subfolder" / "subsubfolder"
    recovered_path = pickle.loads(pickle.dumps(path))

    # Compare classes by identity rather than equality.
    assert type(path) is type(recovered_path)
    assert str(path) == str(recovered_path)
    assert path._drv == recovered_path._drv
    assert path._root == recovered_path._root
    assert path._parts == recovered_path._parts
    assert path.fs.storage_options == recovered_path.fs.storage_options

0 comments on commit df97e1e

Please sign in to comment.