diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4b0eda1..f0792a0 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -35,7 +35,7 @@ jobs: - name: install requirements run: | python -m pip install --upgrade pip setuptools - python -m pip install './flox' + python -m pip install '.[all]' python -m pip install safety - name: run safety check run: safety check diff --git a/flox/backends/transfer/proxystore.py b/flox/backends/transfer/proxystore.py index 1b76ba9..1e2ac15 100644 --- a/flox/backends/transfer/proxystore.py +++ b/flox/backends/transfer/proxystore.py @@ -1,3 +1,6 @@ +from typing import cast +from uuid import UUID + from proxystore.connectors.endpoint import EndpointConnector from proxystore.proxy import Proxy from proxystore.store import Store @@ -17,10 +20,10 @@ def __init__(self, flock: Flock, store: str = "endpoint", name: str = "default") ) self.connector = EndpointConnector( - endpoints=[node.proxystore_endpoint for node in flock.nodes()] + endpoints=[cast(UUID, node.proxystore_endpoint) for node in flock.nodes()] ) - store = Store(name=name, connector=self.connector) - self.config = store.config() + store_instance = Store(name=name, connector=self.connector) + self.config = store_instance.config() def report( self, node_state, node_idx, node_kind, state_dict, history diff --git a/flox/data/utils.py b/flox/data/utils.py index cb8987e..3f564af 100644 --- a/flox/data/utils.py +++ b/flox/data/utils.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import numpy as np from matplotlib.axes import Axes -from scipy import stats +from scipy import stats # type: ignore from torch.utils.data import DataLoader, Dataset, Subset from flox.data import FederatedSubsets diff --git a/flox/run/fit_sync.py b/flox/run/fit_sync.py index cb5437d..90e9ce3 100644 --- a/flox/run/fit_sync.py +++ b/flox/run/fit_sync.py @@ -5,6 +5,7 @@ from typing import Literal, TypeAlias import pandas as pd +from torch.utils.data import Dataset, Subset from tqdm import tqdm from flox.backends.launcher import GlobusComputeLauncher, LocalLauncher @@ -139,6 +140,7 @@ def sync_flock_traverse( # If the current node is a worker node, then Launch the LOCAL FITTING job. if flock.get_kind(node) is FlockNodeKind.WORKER: + dataset: Dataset | Subset if isinstance(transfer, ProxyStoreTransfer): dataset = transfer.proxy(datasets[node.idx]) else: diff --git a/pyproject.toml b/pyproject.toml index 5c5ebc8..2232f3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,8 @@ dependencies = [ [project.optional-dependencies] dev = ["black", "coverage", "jupyterlab", "matplotlib", "numpy", "pytest", "seaborn", "tensorboard", "torchvision", "matplotlib-stubs", "pandas-stubs", "networkx-stubs"] monitoring = ["tensorboard"] +proxystore = ["proxystore"] +all = ["flox[dev,monitoring,proxystore]"] [tool.pytest.ini_options] addopts = [ diff --git a/tox.ini b/tox.ini index 1b5067d..583b8c1 100644 --- a/tox.ini +++ b/tox.ini @@ -16,7 +16,9 @@ commands = [testenv:mypy] deps = mypy>=1.6.1 -extras = dev +extras = + dev + proxystore commands = mypy --install-types --non-interactive -p flox {posargs}