Skip to content

Commit

Permalink
Merge pull request #46 from gpauloski/issue-30
Browse files Browse the repository at this point in the history
Refactor ProxystoreTransfer class
  • Loading branch information
nathaniel-hudson authored Nov 15, 2024
2 parents 4d285e9 + d84853b commit e912660
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 23 deletions.
50 changes: 36 additions & 14 deletions flight/engine/data/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@

import cloudpickle
from proxystore.connectors.endpoint import EndpointConnector
from proxystore.store import Store
from proxystore.store import Store, get_or_create_store, get_store

from ...federation.topologies import Topology
from .base import AbstractTransfer

if t.TYPE_CHECKING:
from proxystore.proxy import Proxy

T = t.TypeVar("T")


class ProxystoreTransfer(AbstractTransfer):
def __init__(self, topo: Topology, name: str = "default") -> None:
def __init__(
self,
topo: Topology,
evict: bool = False,
name: str = "default",
) -> None:
if not topo.proxystore_ready:
raise ValueError(
"Flock is not ready to use ProxyStore (i.e., "
Expand All @@ -30,15 +37,30 @@ def __init__(self, topo: Topology, name: str = "default") -> None:
else:
endpoints.append(node.proxystore_id)

self.name = name
self.connector = EndpointConnector(endpoints=endpoints)
store = Store(
name=name,
connector=self.connector,
serializer=cloudpickle.dumps,
deserializer=cloudpickle.loads,
)
self.config = store.config()

def __call__(self, data: t.Any) -> Proxy[t.Any]:
return Store.from_config(self.config).proxy(data)
store = get_store(name)
if store is None:
store = Store(
name=name,
connector=EndpointConnector(endpoints=endpoints),
# In the future, these could be customized (de)serializers
# that are optimized for Flight/PyTorch models.
serializer=cloudpickle.dumps,
deserializer=cloudpickle.loads,
register=True,
)
self.evict = evict
self.store = store

def __call__(self, data: T) -> Proxy[T]:
# evict=True is only safe when it is guarenteed that the proxy will
# only be used by a single consumer.
return self.store.proxy(data, evict=self.evict)

def __getstate__(self) -> dict[str, t.Any]:
# Customize pickle behavior so that stateful objects are not pickled.
return {"config": self.store.config(), "evict": self.evict}

def __setstate__(self, state: dict[str, t.Any]) -> None:
# Initialized an object from its pickled state.
self.evict = state["evict"]
self.store = get_or_create_store(state["config"], register=True)
32 changes: 24 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"matplotlib",
"numpy",
"pandas",
"proxystore",
"parsl",
"proxystore>=0.7.1",
"scipy",
"scikit-learn",
"tqdm",
Expand Down Expand Up @@ -73,16 +74,31 @@ docs = [
"mkdocstrings-python",
]

[tool.mypy]
plugins = [
"numpy.typing.mypy_plugin"
]
[tool.coverage.run]
omit = ["*/_remote_module_non_scriptable.py"]

[tool.pytest.ini_options]
addopts = [
"--import-mode=importlib",
[tool.coverage.report]
show_missing = true
skip_covered = true
exclude_also = [
# a more strict default pragma
"\\# pragma: no cover\\b",
# allow defensive code
"^\\s*raise AssertionError\\b",
"^\\s*raise NotImplementedError\\b",
"^\\s*return NotImplemented\\b",
"^\\s*raise$",
# typing-related code
"^\\s*if (False|TYPE_CHECKING):",
": \\.\\.\\.(\\s*#.*)?$",
"^ +\\.\\.\\.$",
]

[tool.mypy]
plugins = [
"numpy.typing.mypy_plugin",
"proxystore.mypy_plugin",
]

[tool.setuptools.packages.find]
include = ["flight*"]
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ allowlist_externals = coverage
commands =
coverage erase
coverage run -m pytest tests {posargs}
coverage report
coverage report --ignore-errors

[testenv:pre-commit]
skip_install = true
Expand Down

0 comments on commit e912660

Please sign in to comment.