From a7b298d42c1de7bc28c3da73e5c21907cb89e4f0 Mon Sep 17 00:00:00 2001 From: Chris Janidlo Date: Fri, 23 Feb 2024 15:45:53 -0600 Subject: [PATCH] fix mypy (wip) --- flox/backends/launcher/impl_base.py | 10 +++++++- flox/backends/launcher/impl_globus.py | 8 ++---- flox/backends/launcher/impl_local.py | 9 ++++--- flox/backends/launcher/impl_parsl.py | 10 +++++--- flox/data/__init__.py | 4 +-- flox/data/core.py | 20 +++++++++++---- flox/data/utils.py | 19 ++++++++------ flox/flock/flock.py | 15 +++++------ flox/flock/node.py | 4 +-- flox/nn/logger/base.py | 2 +- flox/nn/model.py | 6 +++-- flox/nn/trainer.py | 7 ++++-- flox/run/fit.py | 2 +- flox/run/fit_async.py | 28 +++++++++++---------- flox/run/fit_sync.py | 9 +++---- flox/run/jobs.py | 8 +++--- flox/strategies/base.py | 19 ++++++++------ flox/strategies/commons/averaging.py | 8 +++--- flox/strategies/commons/worker_selection.py | 24 ++++++++++++------ flox/strategies/registry/fedavg.py | 12 +++++---- flox/strategies/registry/fedprox.py | 4 ++- flox/strategies/registry/fedsgd.py | 18 ++++++++----- flox/utils/random/flock.py | 3 ++- pyproject.toml | 2 +- tox.ini | 1 + 25 files changed, 153 insertions(+), 99 deletions(-) diff --git a/flox/backends/launcher/impl_base.py b/flox/backends/launcher/impl_base.py index 511d056..187ac54 100644 --- a/flox/backends/launcher/impl_base.py +++ b/flox/backends/launcher/impl_base.py @@ -1,9 +1,15 @@ from abc import ABC, abstractmethod from concurrent.futures import Future +from typing import Any, Protocol from flox.flock import FlockNode +class LauncherFunction(Protocol): + def __call__(self, node: FlockNode, *args: Any, **kwargs: Any) -> Any: + ... + + class Launcher(ABC): """ Base class for launching functions in an FL process. @@ -14,7 +20,9 @@ def __init__(self): pass @abstractmethod - def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future: + def submit( + self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs + ) -> Future: raise NotImplementedError() @abstractmethod diff --git a/flox/backends/launcher/impl_globus.py b/flox/backends/launcher/impl_globus.py index 3eecb48..9573b71 100644 --- a/flox/backends/launcher/impl_globus.py +++ b/flox/backends/launcher/impl_globus.py @@ -1,10 +1,8 @@ -from collections.abc import Callable from concurrent.futures import Future -from typing import Any import globus_compute_sdk -from flox.backends.launcher.impl_base import Launcher +from flox.backends.launcher.impl_base import Launcher, LauncherFunction from flox.flock import FlockNode @@ -13,15 +11,13 @@ class GlobusComputeLauncher(Launcher): Class that executes tasks on Globus Compute. """ - _globus_compute_executor: globus_compute_sdk.Executor | None = None - def __init__(self): super().__init__() if self._globus_compute_executor is None: self._globus_compute_executor = globus_compute_sdk.Executor() def submit( - self, fn: Callable[[FlockNode, ...], Any], node: FlockNode, /, *args, **kwargs + self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs ) -> Future: endpoint_id = node.globus_compute_endpoint self._globus_compute_executor.endpoint_id = endpoint_id diff --git a/flox/backends/launcher/impl_local.py b/flox/backends/launcher/impl_local.py index 2aa0139..507a314 100644 --- a/flox/backends/launcher/impl_local.py +++ b/flox/backends/launcher/impl_local.py @@ -1,6 +1,6 @@ -from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor +from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor -from flox.backends.launcher.impl_base import Launcher +from flox.backends.launcher.impl_base import Launcher, LauncherFunction from flox.flock import FlockNode @@ -12,6 +12,7 @@ class LocalLauncher(Launcher): def __init__(self, pool: str, n_workers: int = 1): super().__init__() self.n_workers = n_workers + self.pool: Executor if pool == "process": self.pool = ProcessPoolExecutor(n_workers) elif pool == "thread": @@ -21,7 +22,9 @@ def __init__(self, pool: str, n_workers: int = 1): "Illegal value for argument `pool`. Must be either 'pool' or 'thread'." ) - def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future: + def submit( + self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs + ) -> Future: return self.pool.submit(fn, node, *args, **kwargs) def collect(self): diff --git a/flox/backends/launcher/impl_parsl.py b/flox/backends/launcher/impl_parsl.py index 227b4b5..e7287af 100644 --- a/flox/backends/launcher/impl_parsl.py +++ b/flox/backends/launcher/impl_parsl.py @@ -1,6 +1,6 @@ from concurrent.futures import Future -from flox.backends.launcher.impl_base import Launcher +from flox.backends.launcher.impl_base import Launcher, LauncherFunction from flox.flock import FlockNode @@ -13,8 +13,10 @@ def __init__(self): super().__init__() raise NotImplementedError(f"{self.__name__} yet implemented") - def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future: - pass + def submit( + self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs + ) -> Future: + raise NotImplementedError() def collect(self): - pass + raise NotImplementedError() diff --git a/flox/data/__init__.py b/flox/data/__init__.py index 5bf644c..63cad44 100644 --- a/flox/data/__init__.py +++ b/flox/data/__init__.py @@ -61,7 +61,7 @@ FLoX includes utility functions to simplify the conversion from a standard, centralized PyTorch dataset to a simulated, decentralized dataset. """ -from flox.data.core import FloxDataset +from flox.data.core import FederatedSubsets, FloxDataset from flox.data.utils import fed_barplot, federated_split -__all__ = ["FloxDataset", "fed_barplot", "federated_split"] +__all__ = ["FloxDataset", "FederatedSubsets", "fed_barplot", "federated_split"] diff --git a/flox/data/core.py b/flox/data/core.py index b7a7b40..26ec0b1 100644 --- a/flox/data/core.py +++ b/flox/data/core.py @@ -1,5 +1,6 @@ +from collections.abc import Mapping from enum import IntEnum, auto -from typing import NewType, TypeVar, Union +from typing import NewType, Union, get_args from torch.utils.data import Dataset, Subset @@ -16,11 +17,21 @@ class FloxDatasetKind(IntEnum): def from_obj(obj) -> "FloxDatasetKind": if isinstance(obj, Dataset): return FloxDatasetKind.STANDARD - elif isinstance(obj, FederatedSubsets): + elif FloxDatasetKind.is_federated_dataset(obj): return FloxDatasetKind.FEDERATED else: return FloxDatasetKind.INVALID + @staticmethod + def is_federated_dataset(obj) -> bool: + if not isinstance(obj, Mapping): + return False + + return all( + isinstance(k, get_args(FlockNodeID)) and isinstance(v, (Dataset, Subset)) + for k, v in obj.items() + ) + def flox_compatible_data(obj) -> bool: kind = FloxDatasetKind.from_obj(obj) @@ -29,9 +40,8 @@ def flox_compatible_data(obj) -> bool: return True -T_co = TypeVar("T_co", covariant=True) FederatedSubsets = NewType( - "FederatedSubsets", dict[FlockNodeID, Union[Dataset[T_co], Subset[T_co]]] + "FederatedSubsets", Mapping[FlockNodeID, Union[Dataset, Subset]] ) @@ -41,4 +51,4 @@ def __init__(self, state: NodeState, /, *args, **kwargs): self.state = state -FloxDataset = NewType("FloxDataset", Union[MyFloxDataset, FederatedSubsets]) +FloxDataset = Union[MyFloxDataset, FederatedSubsets] diff --git a/flox/data/utils.py b/flox/data/utils.py index 551cf25..cb8987e 100644 --- a/flox/data/utils.py +++ b/flox/data/utils.py @@ -1,5 +1,5 @@ import warnings -from collections import defaultdict +from collections import Counter, defaultdict import matplotlib.pyplot as plt import numpy as np @@ -9,6 +9,7 @@ from flox.data import FederatedSubsets from flox.flock import Flock +from flox.flock.node import FlockNodeID # TODO: Implement something similar for regression-based data. @@ -59,16 +60,20 @@ def federated_split( sample_distr = stats.dirichlet(np.full(num_workers, samples_alpha)) label_distr = stats.dirichlet(np.full(num_classes, labels_alpha)) - num_samples_for_workers = (sample_distr.rvs()[0] * len(data)).astype(int) + # pytorch intentionally doesn't define an empty __len__ for DataSet, even though + # most subclasses implement it + data_count = len(data) # type: ignore + + num_samples_for_workers = (sample_distr.rvs()[0] * data_count).astype(int) num_samples_for_workers = { worker.idx: num_samples for worker, num_samples in zip(flock.workers, num_samples_for_workers) } label_probs = {w.idx: label_distr.rvs()[0] for w in flock.workers} - indices: dict[int, list[int]] = defaultdict(list) + indices: dict[FlockNodeID, list[int]] = defaultdict(list) loader = DataLoader(data, batch_size=1) - worker_samples = defaultdict(int) + worker_samples: Counter[FlockNodeID] = Counter() for idx, batch in enumerate(loader): _, y = batch label = y.item() @@ -89,11 +94,11 @@ def federated_split( ) raise err - probs = np.array(probs) - probs = probs / probs.sum() + probs_norm = np.array(probs) + probs_norm = probs_norm / probs_norm.sum() if len(temp_workers) > 0: - chosen_worker = np.random.choice(temp_workers, p=probs) + chosen_worker = np.random.choice(temp_workers, p=probs_norm) indices[chosen_worker].append(idx) worker_samples[chosen_worker] += 1 diff --git a/flox/flock/flock.py b/flox/flock/flock.py index 5f73f0d..ce2931f 100644 --- a/flox/flock/flock.py +++ b/flox/flock/flock.py @@ -2,7 +2,7 @@ import functools import json -from collections.abc import Generator +from collections.abc import Iterator from pathlib import Path from typing import Any from uuid import UUID @@ -55,7 +55,6 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non """ self.node_counter: int = 0 self._src = "interactive" if _src is None else _src - self.leader = None if topo is None: # By default (i.e., `topo is None`), @@ -84,6 +83,8 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non raise ValueError( "A legal Flock cannot have more than one leader." ) + if not found_leader: + raise ValueError("A legal Flock must have a leader.") def add_node( self, @@ -102,7 +103,7 @@ def add_node( proxystore_endpoint_id=proxystore_endpoint_id, ) self.node_counter += 1 - return FlockNodeID(idx) + return idx def add_edge(self, u: FlockNodeID, v: FlockNodeID, **attrs) -> None: """ @@ -218,7 +219,7 @@ def validate_topo(self) -> bool: return True - def children(self, node: FlockNode | FlockNodeID | int) -> Generator[FlockNode]: + def children(self, node: FlockNode | FlockNodeID | int) -> Iterator[FlockNode]: if isinstance(node, FlockNode): idx = node.idx else: @@ -384,7 +385,7 @@ def proxystore_ready(self) -> bool: # return self.nodes(by_kind=FlockNodeKind.LEADER) @property - def aggregators(self) -> Generator[FlockNode]: + def aggregators(self) -> Iterator[FlockNode]: """ The aggregator nodes of the Flock. @@ -394,7 +395,7 @@ def aggregators(self) -> Generator[FlockNode]: return self.nodes(by_kind=FlockNodeKind.AGGREGATOR) @property - def workers(self) -> Generator[FlockNode]: + def workers(self) -> Iterator[FlockNode]: """ The worker nodes of the Flock. @@ -413,7 +414,7 @@ def number_of_workers(self) -> int: """The number of worker nodes in the Flock.""" return len(list(self.workers)) - def nodes(self, by_kind: FlockNodeKind | None = None) -> Generator[FlockNode]: + def nodes(self, by_kind: FlockNodeKind | None = None) -> Iterator[FlockNode]: for idx, data in self.topo.nodes(data=True): if by_kind is not None and data["kind"] != by_kind: continue diff --git a/flox/flock/node.py b/flox/flock/node.py index 42547ab..28f000d 100644 --- a/flox/flock/node.py +++ b/flox/flock/node.py @@ -1,9 +1,9 @@ from dataclasses import dataclass from enum import Enum, auto -from typing import NewType, Union +from typing import Union from uuid import UUID -FlockNodeID = NewType("FlockNodeID", Union[int, str]) +FlockNodeID = Union[int, str] class FlockNodeKind(Enum): diff --git a/flox/nn/logger/base.py b/flox/nn/logger/base.py index d9338d6..a6f211f 100644 --- a/flox/nn/logger/base.py +++ b/flox/nn/logger/base.py @@ -19,7 +19,7 @@ def clear(self) -> None: def to_pandas(self) -> pd.DataFrame: df = pd.DataFrame.from_records(self.records) - for col in df: + for col in df.columns: if "time" in col: df[col] = pd.to_datetime(df[col]) return df diff --git a/flox/nn/model.py b/flox/nn/model.py index 62e5925..cb4f99e 100644 --- a/flox/nn/model.py +++ b/flox/nn/model.py @@ -1,9 +1,9 @@ -from __future__ import annotations +from abc import ABC, abstractmethod import torch -class FloxModule(torch.nn.Module): +class FloxModule(torch.nn.Module, ABC): """ The ``FloxModule`` is a wrapper for the standard ``torch.nn.Module`` class from PyTorch, with a lot of inspiration from the ``lightning.LightningModule`` class from PyTorch Lightning. @@ -12,6 +12,7 @@ class FloxModule(torch.nn.Module): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + @abstractmethod def training_step( self, batch: torch.Tensor | tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> torch.Tensor: @@ -26,6 +27,7 @@ def training_step( Loss from the training step. """ + @abstractmethod def configure_optimizers(self) -> torch.optim.Optimizer: """Configures, initializes, and returns the optimizer used to train the model. diff --git a/flox/nn/trainer.py b/flox/nn/trainer.py index 2d7bb62..a1e679a 100644 --- a/flox/nn/trainer.py +++ b/flox/nn/trainer.py @@ -50,10 +50,13 @@ def fit( loss.backward() try: + assert strategy is not None + assert node_state is not None strategy.wrk_on_after_train_step(node_state, loss) - except NotImplementedError: + except (AttributeError, AssertionError): """ - The current strategy does not override the `wrk_on_after_train_step()` callback. + node_state is None, strategy is None, or the strategy doesn't + implement `wrk_on_after_train_step()`. """ pass diff --git a/flox/run/fit.py b/flox/run/fit.py index 20e5b59..fff7ad1 100644 --- a/flox/run/fit.py +++ b/flox/run/fit.py @@ -61,7 +61,7 @@ def federated_fit( """ launcher_cfg = dict() if launcher_cfg is None else launcher_cfg - launcher = create_launcher(launcher, **launcher_cfg) + # launcher = create_launcher(launcher, **launcher_cfg) # not used strategy = "fedsgd" if strategy is None else strategy diff --git a/flox/run/fit_async.py b/flox/run/fit_async.py index a24c7aa..2eb605b 100644 --- a/flox/run/fit_async.py +++ b/flox/run/fit_async.py @@ -1,12 +1,14 @@ from __future__ import annotations -from collections import defaultdict +from collections import Counter from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait import pandas as pd +from flox.backends.transfer.base import BaseTransfer from flox.data import FloxDataset from flox.flock import Flock +from flox.flock.node import FlockNodeID from flox.nn import FloxModule from flox.run.jobs import local_training_job from flox.strategies import Strategy @@ -18,7 +20,6 @@ def async_federated_fit( datasets: FloxDataset, num_global_rounds: int, strategy: Strategy | str = "fedavg", - executor: str = "thread", max_workers: int = 1, ) -> pd.DataFrame: """ @@ -39,22 +40,24 @@ def async_federated_fit( # assert that the flock is a 2-tier system with no intermediary aggregators. executor = ThreadPoolExecutor(max_workers=max_workers) global_module = module_cls() - hyper_params = {} - futures = [ + if isinstance(strategy, str): + strategy = Strategy.get_strategy(strategy)() + + futures = { executor.submit( local_training_job, node, + BaseTransfer(), parent=flock.leader, strategy=strategy, module_cls=module_cls, module_state_dict=global_module.state_dict(), dataset=datasets[node.idx], - **hyper_params, ) for node in flock.workers - ] - num_local_fitins = defaultdict(int) + } + num_local_fitins: Counter[FlockNodeID] = Counter() while futures: done, futures = wait(futures, return_when=FIRST_COMPLETED) @@ -67,21 +70,20 @@ def async_federated_fit( results = [d.result() for d in done] for res in results: - futures = list(futures) - num_local_fitins[res.idx] += 1 - node = flock[res.idx] + num_local_fitins[res.node_idx] += 1 + node = flock[res.node_idx] - if num_local_fitins[res.idx] < num_global_rounds: + if num_local_fitins[res.node_idx] < num_global_rounds: fut = executor.submit( local_training_job, node, + BaseTransfer(), parent=flock.leader, strategy=strategy, module_cls=module_cls, module_state_dict=global_module.state_dict(), dataset=datasets[node.idx], - **hyper_params, ) - futures.append(fut) + futures.add(fut) return pd.DataFrame.from_dict({}) diff --git a/flox/run/fit_sync.py b/flox/run/fit_sync.py index 3a849b9..cb5437d 100644 --- a/flox/run/fit_sync.py +++ b/flox/run/fit_sync.py @@ -59,11 +59,12 @@ def sync_federated_fit( Results from the FL process. """ transfer: Transfer + launcher_instance: Launcher if launcher == "thread" or launcher == "process": - launcher = LocalLauncher(launcher, max_workers) + launcher_instance = LocalLauncher(launcher, max_workers) elif launcher == "globus_compute": - launcher = GlobusComputeLauncher() + launcher_instance = GlobusComputeLauncher() if where == "local": transfer = BaseTransfer() @@ -81,7 +82,7 @@ def sync_federated_fit( # Launch the tasks recursively starting with the aggregation task on the # leader of the Flock. rnd_future = sync_flock_traverse( - launcher, + launcher_instance, transfer=transfer, flock=flock, node=flock.leader, @@ -143,7 +144,6 @@ def sync_flock_traverse( else: dataset = datasets[node.idx] - hyper_params = {} return launcher.submit( local_training_job, node, @@ -153,7 +153,6 @@ def sync_flock_traverse( module_cls=module_cls, module_state_dict=module_state_dict, dataset=dataset, - **hyper_params, ) if isinstance(transfer, ProxyStoreTransfer): diff --git a/flox/run/jobs.py b/flox/run/jobs.py index 8c39f56..8337842 100644 --- a/flox/run/jobs.py +++ b/flox/run/jobs.py @@ -11,17 +11,15 @@ from flox.strategies import Strategy from flox.typing import StateDict -Transfer: BaseTransfer - def local_training_job( node: FlockNode, - transfer: Transfer, + transfer: BaseTransfer, parent: FlockNode, strategy: Strategy, module_cls: type[FloxModule], module_state_dict: StateDict, - dataset: Dataset | Subset | None = None, + dataset: Dataset | Subset, **train_hyper_params, ) -> Result: """Perform local training on a worker node. @@ -79,7 +77,7 @@ def local_training_job( def aggregation_job( - node: FlockNode, transfer: Transfer, strategy: Strategy, results: list[Result] + node: FlockNode, transfer: BaseTransfer, strategy: Strategy, results: list[Result] ) -> Result: """Aggregate the state dicts from each of the results. diff --git a/flox/strategies/base.py b/flox/strategies/base.py index 74e9f4c..1788a17 100644 --- a/flox/strategies/base.py +++ b/flox/strategies/base.py @@ -1,3 +1,4 @@ +from abc import ABC, abstractmethod from collections.abc import Iterable, Mapping from typing import TypeAlias @@ -10,7 +11,7 @@ Loss: TypeAlias = torch.Tensor -class Strategy: +class Strategy(ABC): """Base class for the logical blocks of a FL process. A ``Strategy`` in FLoX is used to implement the logic of an FL process. A ``Strategy`` provides @@ -22,10 +23,10 @@ class Strategy: they are run in an FL process. """ - registry = {} + registry: dict[str, type["Strategy"]] = {} @classmethod - def get_strategy(cls, name: str): + def get_strategy(cls, name: str) -> type["Strategy"]: """ Args: @@ -38,7 +39,7 @@ def get_strategy(cls, name: str): if name in cls.registry: return cls.registry[name] else: - raise KeyError(f"Strategy name ({name=}) is not in the Strategy registry.") + raise KeyError(f"Strategy name ({name}) is not in the Strategy registry.") def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) @@ -55,8 +56,9 @@ def agg_before_round(self, state: FloxAggregatorState) -> None: Args: state (FloxAggregatorState): The current state of the Aggregator FloxNode. """ + raise NotImplementedError() - # @required + @abstractmethod def agg_param_aggregation( self, state: FloxAggregatorState, @@ -78,7 +80,7 @@ def agg_param_aggregation( StateDict """ - # @required + @abstractmethod def agg_worker_selection( self, state: FloxAggregatorState, children: Iterable[FlockNode], *args, **kwargs ) -> Iterable[FlockNode]: @@ -133,6 +135,7 @@ def agg_after_collect_params( Returns: """ + raise NotImplementedError() #################################################################################### # WORKER CALLBACKS. # @@ -148,7 +151,7 @@ def wrk_on_before_train_step(self, state: FloxWorkerState, *args, **kwargs): Returns: """ - pass + raise NotImplementedError() def wrk_on_after_train_step( self, state: FloxWorkerState, loss: Loss, *args, **kwargs @@ -179,7 +182,7 @@ def wrk_on_before_submit_params( Returns: """ - pass + raise NotImplementedError() def wrk_on_recv_params( self, state: FloxWorkerState, params: StateDict, *args, **kwargs diff --git a/flox/strategies/commons/averaging.py b/flox/strategies/commons/averaging.py index e57993b..d927eaf 100644 --- a/flox/strategies/commons/averaging.py +++ b/flox/strategies/commons/averaging.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + import numpy as np import torch @@ -6,8 +8,8 @@ def average_state_dicts( - state_dicts: dict[FlockNodeID, StateDict], - weights: dict[FlockNodeID, float] | None = None, + state_dicts: Mapping[FlockNodeID, StateDict], + weights: Mapping[FlockNodeID, float] | None = None, ) -> StateDict: """Averages the parameters given by ``module.state_dict()`` from a set of ``FlockNodes``. @@ -25,7 +27,7 @@ def average_state_dicts( with torch.no_grad(): avg_weights = {} for node, state_dict in state_dicts.items(): - w = 1 / num_nodes if weights is None else weights[node] / weight_sum + w = 1 / num_nodes if weights is None else weights[node] / weight_sum # type: ignore for name, value in state_dict.items(): value = w * torch.clone(value) if name not in avg_weights: diff --git a/flox/strategies/commons/worker_selection.py b/flox/strategies/commons/worker_selection.py index 1bfc1c2..0f19df6 100644 --- a/flox/strategies/commons/worker_selection.py +++ b/flox/strategies/commons/worker_selection.py @@ -1,14 +1,18 @@ +from collections.abc import Iterable +from typing import cast + from numpy.random import RandomState +from numpy.typing import NDArray from flox.flock import FlockNode, FlockNodeKind def random_worker_selection( - children: list[FlockNode], + children: Iterable[FlockNode], participation: float = 1.0, probabilistic: bool = False, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ) -> list[FlockNode]: """ @@ -30,9 +34,9 @@ def random_worker_selection( def fixed_random_worker_selection( - children: list[FlockNode], + children: Iterable[FlockNode], participation: float = 1.0, - seed: int = None, + seed: int | None = None, ) -> list[FlockNode]: """ @@ -46,15 +50,17 @@ def fixed_random_worker_selection( """ rand_state = RandomState(seed) num_selected = min(1, int(participation) * len(list(children))) - selected_children = rand_state.choice(children, size=num_selected, replace=False) + # numpy annotates RandomState.choice too narrowly; need this to satisfy mypy + achildren = cast(NDArray, children) + selected_children = rand_state.choice(achildren, size=num_selected, replace=False) return list(selected_children) def prob_random_worker_selection( - children: list[FlockNode], + children: Iterable[FlockNode], participation: float = 1.0, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ) -> list[FlockNode]: """ @@ -76,7 +82,9 @@ def prob_random_worker_selection( selected_children.append(child) if len(selected_children) == 0: - random_child = rand_state.choice(children) + # numpy annotates RandomState.choice too narrowly; need this to satisfy mypy + achildren = cast(NDArray, children) + random_child = rand_state.choice(achildren) selected_children.append(random_child) return selected_children diff --git a/flox/strategies/registry/fedavg.py b/flox/strategies/registry/fedavg.py index 95ea5fe..7787eb8 100644 --- a/flox/strategies/registry/fedavg.py +++ b/flox/strategies/registry/fedavg.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + from flox.flock import FlockNodeID from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState from flox.strategies.commons.averaging import average_state_dicts @@ -22,7 +24,7 @@ def __init__( participation: float = 1.0, probabilistic: bool = True, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ): """ @@ -46,10 +48,10 @@ def wrk_on_before_train_step(self, state: FloxWorkerState, *args, **kwargs): def agg_param_aggregation( self, state: FloxAggregatorState, - children_states: dict[FlockNodeID, NodeState], - children_state_dicts: dict[FlockNodeID, StateDict], - *args, - **kwargs, + children_states: Mapping[FlockNodeID, NodeState], + children_state_dicts: Mapping[FlockNodeID, StateDict], + *_args, + **_kwargs, ): weights = {} for node, child_state in children_states.items(): diff --git a/flox/strategies/registry/fedprox.py b/flox/strategies/registry/fedprox.py index d6e2ae5..e4444b7 100644 --- a/flox/strategies/registry/fedprox.py +++ b/flox/strategies/registry/fedprox.py @@ -27,7 +27,7 @@ def __init__( participation: float = 1.0, probabilistic: bool = False, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ): """ @@ -74,6 +74,8 @@ def wrk_on_after_train_step( """ global_model = state.pre_local_train_model local_model = state.post_local_train_model + assert global_model is not None + assert local_model is not None params = list(local_model.state_dict().values()) params0 = list(global_model.state_dict().values()) diff --git a/flox/strategies/registry/fedsgd.py b/flox/strategies/registry/fedsgd.py index 377e076..1c88d29 100644 --- a/flox/strategies/registry/fedsgd.py +++ b/flox/strategies/registry/fedsgd.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections.abc import Iterable, Mapping + from flox.flock import FlockNode, FlockNodeID from flox.flock.states import FloxAggregatorState, NodeState from flox.strategies.base import Strategy @@ -26,7 +28,7 @@ def __init__( participation: float = 1.0, probabilistic: bool = True, always_include_child_aggregators: bool = True, - seed: int = None, + seed: int | None = None, ): """Initializes the FedSGD strategy with the desired parameters. @@ -49,7 +51,11 @@ def __init__( self.seed = seed def agg_worker_selection( - self, state: FloxAggregatorState, children: list[FlockNode], **kwargs + self, + state: FloxAggregatorState, + children: Iterable[FlockNode], + *_args, + **_kwargs, ) -> list[FlockNode]: """Performs a simple average of the model weights returned by the child nodes. @@ -82,10 +88,10 @@ def agg_worker_selection( def agg_param_aggregation( self, state: FloxAggregatorState, - children_states: dict[FlockNodeID, NodeState], - children_state_dicts: dict[FlockNodeID, StateDict], - *args, - **kwargs, + children_states: Mapping[FlockNodeID, NodeState], + children_state_dicts: Mapping[FlockNodeID, StateDict], + *_args, + **_kwargs, ) -> StateDict: """Runs simple, unweighted averaging of ``StateDict`` objects from each child node. diff --git a/flox/utils/random/flock.py b/flox/utils/random/flock.py index 8d06dad..c80c509 100644 --- a/flox/utils/random/flock.py +++ b/flox/utils/random/flock.py @@ -1,6 +1,7 @@ import networkx as nx from flox.flock import Flock +from flox.flock.flock import REQUIRED_ATTRS def random_flock(num_nodes: int, seed: int | None = None) -> Flock: @@ -16,7 +17,7 @@ def random_flock(num_nodes: int, seed: int | None = None) -> Flock: # TODO: Finish this and create a test. tree = nx.random_tree(n=num_nodes, seed=seed, create_using=nx.DiGraph) for node in tree.nodes(): - for attr in Flock.required_attrs: + for attr in REQUIRED_ATTRS: tree.nodes[node][attr] = None flock = Flock(tree) return flock diff --git a/pyproject.toml b/pyproject.toml index 8607cb9..5c5ebc8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["black", "coverage", "jupyterlab", "matplotlib", "numpy", "pytest", "seaborn", "tensorboard", "torchvision"] +dev = ["black", "coverage", "jupyterlab", "matplotlib", "numpy", "pytest", "seaborn", "tensorboard", "torchvision", "matplotlib-stubs", "pandas-stubs", "networkx-stubs"] monitoring = ["tensorboard"] [tool.pytest.ini_options] diff --git a/tox.ini b/tox.ini index e3628fb..1b5067d 100644 --- a/tox.ini +++ b/tox.ini @@ -16,6 +16,7 @@ commands = [testenv:mypy] deps = mypy>=1.6.1 +extras = dev commands = mypy --install-types --non-interactive -p flox {posargs}