Skip to content

Commit

Permalink
Improve ProxystoreTransfer class
Browse files Browse the repository at this point in the history
- Reuse stateful Store object between calls to .proxy()
- Do not reinitialize Store if already created in that process
- Improve typing on .proxy()
- Customize pickle behavior to ignore stateful attributes
- Add evict setting for proxies
- Add note on supporting customized serializers in future
- Bump ProxyStore to 0.7.1 and later (should be okay now since flight is
  Python 3.11 and later and globus compute supports pydantic 2 now)
  • Loading branch information
gpauloski committed Oct 31, 2024
1 parent e50a80c commit d84853b
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 15 deletions.
50 changes: 36 additions & 14 deletions flight/engine/data/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,24 @@

import cloudpickle
from proxystore.connectors.endpoint import EndpointConnector
from proxystore.store import Store
from proxystore.store import Store, get_or_create_store, get_store

from ...federation.topologies import Topology
from .base import AbstractTransfer

if t.TYPE_CHECKING:
from proxystore.proxy import Proxy

T = t.TypeVar("T")


class ProxystoreTransfer(AbstractTransfer):
def __init__(self, topo: Topology, name: str = "default") -> None:
def __init__(
self,
topo: Topology,
evict: bool = False,
name: str = "default",
) -> None:
if not topo.proxystore_ready:
raise ValueError(
"Flock is not ready to use ProxyStore (i.e., "
Expand All @@ -30,15 +37,30 @@ def __init__(self, topo: Topology, name: str = "default") -> None:
else:
endpoints.append(node.proxystore_id)

self.name = name
self.connector = EndpointConnector(endpoints=endpoints)
store = Store(
name=name,
connector=self.connector,
serializer=cloudpickle.dumps,
deserializer=cloudpickle.loads,
)
self.config = store.config()

def __call__(self, data: t.Any) -> Proxy[t.Any]:
return Store.from_config(self.config).proxy(data)
store = get_store(name)
if store is None:
store = Store(
name=name,
connector=EndpointConnector(endpoints=endpoints),
# In the future, these could be customized (de)serializers
# that are optimized for Flight/PyTorch models.
serializer=cloudpickle.dumps,
deserializer=cloudpickle.loads,
register=True,
)
self.evict = evict
self.store = store

def __call__(self, data: T) -> Proxy[T]:
# evict=True is only safe when it is guarenteed that the proxy will
# only be used by a single consumer.
return self.store.proxy(data, evict=self.evict)

def __getstate__(self) -> dict[str, t.Any]:
# Customize pickle behavior so that stateful objects are not pickled.
return {"config": self.store.config(), "evict": self.evict}

def __setstate__(self, state: dict[str, t.Any]) -> None:
# Initialized an object from its pickled state.
self.evict = state["evict"]
self.store = get_or_create_store(state["config"], register=True)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"numpy",
"pandas",
"parsl",
"proxystore",
"proxystore>=0.7.1",
"scipy",
"scikit-learn",
"tqdm",
Expand Down

0 comments on commit d84853b

Please sign in to comment.