diff --git a/flight/engine/data/proxy.py b/flight/engine/data/proxy.py index 03e47ba..f19787a 100644 --- a/flight/engine/data/proxy.py +++ b/flight/engine/data/proxy.py @@ -5,7 +5,7 @@ import cloudpickle from proxystore.connectors.endpoint import EndpointConnector -from proxystore.store import Store +from proxystore.store import Store, get_or_create_store, get_store from ...federation.topologies import Topology from .base import AbstractTransfer @@ -13,9 +13,16 @@ if t.TYPE_CHECKING: from proxystore.proxy import Proxy +T = t.TypeVar("T") + class ProxystoreTransfer(AbstractTransfer): - def __init__(self, topo: Topology, name: str = "default") -> None: + def __init__( + self, + topo: Topology, + evict: bool = False, + name: str = "default", + ) -> None: if not topo.proxystore_ready: raise ValueError( "Flock is not ready to use ProxyStore (i.e., " @@ -30,15 +37,30 @@ def __init__(self, topo: Topology, name: str = "default") -> None: else: endpoints.append(node.proxystore_id) - self.name = name - self.connector = EndpointConnector(endpoints=endpoints) - store = Store( - name=name, - connector=self.connector, - serializer=cloudpickle.dumps, - deserializer=cloudpickle.loads, - ) - self.config = store.config() - - def __call__(self, data: t.Any) -> Proxy[t.Any]: - return Store.from_config(self.config).proxy(data) + store = get_store(name) + if store is None: + store = Store( + name=name, + connector=EndpointConnector(endpoints=endpoints), + # In the future, these could be customized (de)serializers + # that are optimized for Flight/PyTorch models. + serializer=cloudpickle.dumps, + deserializer=cloudpickle.loads, + register=True, + ) + self.evict = evict + self.store = store + + def __call__(self, data: T) -> Proxy[T]: + # evict=True is only safe when it is guarenteed that the proxy will + # only be used by a single consumer. + return self.store.proxy(data, evict=self.evict) + + def __getstate__(self) -> dict[str, t.Any]: + # Customize pickle behavior so that stateful objects are not pickled. + return {"config": self.store.config(), "evict": self.evict} + + def __setstate__(self, state: dict[str, t.Any]) -> None: + # Initialized an object from its pickled state. + self.evict = state["evict"] + self.store = get_or_create_store(state["config"], register=True) diff --git a/pyproject.toml b/pyproject.toml index c1aca6b..47cd36c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "matplotlib", "numpy", "pandas", - "proxystore", + "parsl", + "proxystore>=0.7.1", "scipy", "scikit-learn", "tqdm", @@ -73,16 +74,31 @@ docs = [ "mkdocstrings-python", ] -[tool.mypy] -plugins = [ - "numpy.typing.mypy_plugin" -] +[tool.coverage.run] +omit = ["*/_remote_module_non_scriptable.py"] -[tool.pytest.ini_options] -addopts = [ - "--import-mode=importlib", +[tool.coverage.report] +show_missing = true +skip_covered = true +exclude_also = [ + # a more strict default pragma + "\\# pragma: no cover\\b", + # allow defensive code + "^\\s*raise AssertionError\\b", + "^\\s*raise NotImplementedError\\b", + "^\\s*return NotImplemented\\b", + "^\\s*raise$", + # typing-related code + "^\\s*if (False|TYPE_CHECKING):", + ": \\.\\.\\.(\\s*#.*)?$", + "^ +\\.\\.\\.$", ] +[tool.mypy] +plugins = [ + "numpy.typing.mypy_plugin", + "proxystore.mypy_plugin", +] [tool.setuptools.packages.find] include = ["flight*"] diff --git a/tests/engine/control/__init__.py b/tests/engine/control/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tox.ini b/tox.ini index fa22b7b..3639291 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,7 @@ allowlist_externals = coverage commands = coverage erase coverage run -m pytest tests {posargs} - coverage report + coverage report --ignore-errors [testenv:pre-commit] skip_install = true