diff --git a/flight/engine/data/proxy.py b/flight/engine/data/proxy.py index 03e47ba..f19787a 100644 --- a/flight/engine/data/proxy.py +++ b/flight/engine/data/proxy.py @@ -5,7 +5,7 @@ import cloudpickle from proxystore.connectors.endpoint import EndpointConnector -from proxystore.store import Store +from proxystore.store import Store, get_or_create_store, get_store from ...federation.topologies import Topology from .base import AbstractTransfer @@ -13,9 +13,16 @@ if t.TYPE_CHECKING: from proxystore.proxy import Proxy +T = t.TypeVar("T") + class ProxystoreTransfer(AbstractTransfer): - def __init__(self, topo: Topology, name: str = "default") -> None: + def __init__( + self, + topo: Topology, + evict: bool = False, + name: str = "default", + ) -> None: if not topo.proxystore_ready: raise ValueError( "Flock is not ready to use ProxyStore (i.e., " @@ -30,15 +37,30 @@ def __init__(self, topo: Topology, name: str = "default") -> None: else: endpoints.append(node.proxystore_id) - self.name = name - self.connector = EndpointConnector(endpoints=endpoints) - store = Store( - name=name, - connector=self.connector, - serializer=cloudpickle.dumps, - deserializer=cloudpickle.loads, - ) - self.config = store.config() - - def __call__(self, data: t.Any) -> Proxy[t.Any]: - return Store.from_config(self.config).proxy(data) + store = get_store(name) + if store is None: + store = Store( + name=name, + connector=EndpointConnector(endpoints=endpoints), + # In the future, these could be customized (de)serializers + # that are optimized for Flight/PyTorch models. + serializer=cloudpickle.dumps, + deserializer=cloudpickle.loads, + register=True, + ) + self.evict = evict + self.store = store + + def __call__(self, data: T) -> Proxy[T]: + # evict=True is only safe when it is guarenteed that the proxy will + # only be used by a single consumer. + return self.store.proxy(data, evict=self.evict) + + def __getstate__(self) -> dict[str, t.Any]: + # Customize pickle behavior so that stateful objects are not pickled. + return {"config": self.store.config(), "evict": self.evict} + + def __setstate__(self, state: dict[str, t.Any]) -> None: + # Initialized an object from its pickled state. + self.evict = state["evict"] + self.store = get_or_create_store(state["config"], register=True) diff --git a/pyproject.toml b/pyproject.toml index 30e1d33..47cd36c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "numpy", "pandas", "parsl", - "proxystore", + "proxystore>=0.7.1", "scipy", "scikit-learn", "tqdm",