diff --git a/flox/flock/__init__.py b/flox/flock/__init__.py index 63a0377..37e7f2f 100644 --- a/flox/flock/__init__.py +++ b/flox/flock/__init__.py @@ -4,14 +4,14 @@ from flox.flock.flock import Flock from flox.flock.node import FlockNode, FlockNodeID, FlockNodeKind -from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState +from flox.flock.states import AggrState, WorkerState, NodeState __all__ = [ "Flock", "FlockNode", "FlockNodeID", "FlockNodeKind", - "FloxAggregatorState", - "FloxWorkerState", + "AggrState", + "WorkerState", "NodeState", ] diff --git a/flox/flock/states.py b/flox/flock/states.py index 9a7bc29..7be7a60 100644 --- a/flox/flock/states.py +++ b/flox/flock/states.py @@ -23,7 +23,7 @@ def __post_init__(self): if type(self) is NodeState: raise NotImplementedError( "Cannot instantiate instance of ``NodeState`` (must instantiate instance of " - "subclasses: ``FloxAggregatorState`` or ``FloxWorkerState``)." + "subclasses: ``AggrState`` or ``WorkerState``)." ) def __iter__(self) -> Iterable[str]: @@ -42,7 +42,7 @@ def __setitem__(self, key: str, value: Any) -> None: value (Any): Data to store in ``self.extra_data``. Examples: - >>> state = FloxWorkerState(...) + >>> state = WorkerState(...) >>> state["foo"] = "bar" """ self.cache[key] = value @@ -54,7 +54,7 @@ def __getitem__(self, key: str) -> Any: key (str): Key to retrieve stored data in ``self.extra_data``. Examples: - >>> state = FloxWorkerState(...) + >>> state = WorkerState(...) >>> state["foo"] = "bar" # Stores the data (see `__setitem__()`). >>> print(state["foo"]) # Gets the item. >>> # "foo" @@ -65,14 +65,15 @@ def __getitem__(self, key: str) -> Any: return self.cache[key] -class FloxAggregatorState(NodeState): +class AggrState(NodeState): """State of an Aggregator node in a ``Flock``.""" + # TODO: If there is no difference between ``AggrState`` and ``NodeState``, then do we need the former at all? def __init__(self, idx: FlockNodeID): super().__init__(idx) -class FloxWorkerState(NodeState): +class WorkerState(NodeState): """State of a Worker node in a ``Flock``.""" pre_local_train_model: FloxModule | None = None @@ -92,11 +93,5 @@ def __init__( self.post_local_train_model = post_local_train_model def __repr__(self) -> str: - template = ( - "FloxWorkerState(pre_local_train_model={}, post_local_train_model={})" - ) + template = "WorkerState(pre_local_train_model={}, post_local_train_model={})" return template.format(self.pre_local_train_model, self.post_local_train_model) - - -# NodeState = NewType("NodeState", Union[FloxAggregatorState, FloxWorkerState]) -# """A `Type` included for convenience. It is equivalent to ``Union[FloxAggregatorState, FloxWorkerState]``.""" diff --git a/flox/nn/trainer.py b/flox/nn/trainer.py index b665d63..1a695fc 100644 --- a/flox/nn/trainer.py +++ b/flox/nn/trainer.py @@ -6,7 +6,7 @@ import torch from torch.utils.data import DataLoader -from flox.flock.states import FloxWorkerState +from flox.flock.states import WorkerState from flox.nn import FloxModule from flox.nn.logger.csv import CSVLogger from flox.strategies import Strategy @@ -32,7 +32,7 @@ def fit( train_dataloader: DataLoader, num_epochs: int, strategy: Strategy, - node_state: FloxWorkerState, + node_state: WorkerState, valid_dataloader: DataLoader | None = None, valid_ckpt_path: Path | str | None = None, ) -> pd.DataFrame: diff --git a/flox/runtime/fit.py b/flox/runtime/fit.py index ed0f2dd..cb098fb 100644 --- a/flox/runtime/fit.py +++ b/flox/runtime/fit.py @@ -100,7 +100,7 @@ def federated_fit( strategy=strategy, ) case _: - raise ValueError + raise ValueError("Illegal value for the strategy `kind` parameter.") start_time = datetime.datetime.now() module, history = process.start(debug_mode) diff --git a/flox/runtime/jobs/aggr.py b/flox/runtime/jobs/aggr.py index 1029a5a..0ded832 100644 --- a/flox/runtime/jobs/aggr.py +++ b/flox/runtime/jobs/aggr.py @@ -19,7 +19,7 @@ def aggregation_job( Aggregation results. """ import pandas - from flox.flock.states import FloxAggregatorState, NodeState + from flox.flock.states import AggrState, NodeState from flox.runtime import JobResult child_states: dict[FlockNodeID, NodeState] = {} @@ -29,7 +29,7 @@ def aggregation_job( child_states[idx] = result.node_state child_state_dicts[idx] = result.state_dict - node_state = FloxAggregatorState(node.idx) + node_state = AggrState(node.idx) avg_state_dict = strategy.agg_param_aggregation( node_state, child_states, child_state_dicts ) @@ -59,13 +59,13 @@ def debug_aggregation_job( import datetime import numpy import pandas - from flox.flock.states import FloxAggregatorState + from flox.flock.states import AggrState from flox.runtime import JobResult result = next(iter(results)) state_dict = result.state_dict state_dict = {} if state_dict is None else state_dict - node_state = FloxAggregatorState(node.idx) + node_state = AggrState(node.idx) history = { "node/idx": [node.idx], "node/kind": [node.kind.to_str()], diff --git a/flox/runtime/jobs/train.py b/flox/runtime/jobs/train.py index b9691a5..0523604 100644 --- a/flox/runtime/jobs/train.py +++ b/flox/runtime/jobs/train.py @@ -42,7 +42,7 @@ def local_training_job( Local fitting results. """ from copy import deepcopy - from flox.flock.states import FloxWorkerState + from flox.flock.states import WorkerState from flox.nn.trainer import Trainer from torch.utils.data import DataLoader from flox.runtime import JobResult @@ -66,7 +66,7 @@ def local_training_job( global_model.load_state_dict(module_state_dict) local_model.load_state_dict(module_state_dict) - node_state = FloxWorkerState( + node_state = WorkerState( node.idx, pre_local_train_model=global_model, post_local_train_model=local_model ) @@ -122,11 +122,11 @@ def debug_training_job( import datetime import numpy as np import pandas - from flox.flock.states import FloxWorkerState + from flox.flock.states import WorkerState from flox.runtime import JobResult local_module = module - node_state = FloxWorkerState( + node_state = WorkerState( node.idx, pre_local_train_model=local_module, post_local_train_model=local_module, diff --git a/flox/runtime/launcher/base.py b/flox/runtime/launcher/base.py index befde0a..cffb7d4 100644 --- a/flox/runtime/launcher/base.py +++ b/flox/runtime/launcher/base.py @@ -24,4 +24,5 @@ def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future: @abstractmethod def collect(self): + # TODO: Check if this is needed at all. raise NotImplementedError() diff --git a/flox/runtime/process/future_callbacks.py b/flox/runtime/process/future_callbacks.py index 6ab8a6e..6dd19ec 100644 --- a/flox/runtime/process/future_callbacks.py +++ b/flox/runtime/process/future_callbacks.py @@ -1,12 +1,16 @@ +from __future__ import annotations + import functools import typing -from concurrent.futures import Future -from flox.flock import FlockNode -from flox.runtime.jobs import aggregation_job -from flox.runtime.runtime import Runtime -from flox.runtime.utils import set_parent_future -from flox.strategies import Strategy +if typing.TYPE_CHECKING: + from concurrent.futures import Future + + from flox.flock import FlockNode + from flox.runtime.jobs import aggregation_job + from flox.runtime.runtime import Runtime + from flox.runtime.utils import set_parent_future + from flox.strategies import Strategy def all_child_futures_finished_cbk( diff --git a/flox/runtime/process/proc_async.py b/flox/runtime/process/proc_async.py index 756c3a2..144f486 100644 --- a/flox/runtime/process/proc_async.py +++ b/flox/runtime/process/proc_async.py @@ -9,7 +9,7 @@ from flox.data import FloxDataset from flox.flock import Flock, FlockNodeID -from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState +from flox.flock.states import AggrState, WorkerState, NodeState from flox.nn import FloxModule from flox.runtime.jobs import local_training_job from flox.runtime.process.proc import BaseProcess @@ -56,7 +56,7 @@ def __init__( self.state_dict = None self.debug_mode = False - self.state = FloxAggregatorState(self.flock.leader.idx) + self.state = AggrState(self.flock.leader.idx) def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]: if not self.flock.two_tier: @@ -68,7 +68,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]: worker_state_dicts: dict[FlockNodeID, StateDict] = {} for worker in self.flock.workers: worker_rounds[worker.idx] = 0 - worker_states[worker.idx] = FloxWorkerState(worker.idx) + worker_states[worker.idx] = WorkerState(worker.idx) worker_state_dicts[worker.idx] = self.global_module.state_dict() futures = set() diff --git a/flox/runtime/process/proc_sync.py b/flox/runtime/process/proc_sync.py index 2fd429e..f57a53e 100644 --- a/flox/runtime/process/proc_sync.py +++ b/flox/runtime/process/proc_sync.py @@ -10,7 +10,7 @@ from flox.data import FloxDataset from flox.flock import Flock, FlockNode, FlockNodeKind -from flox.flock.states import FloxAggregatorState +from flox.flock.states import AggrState from flox.nn import FloxModule from flox.runtime.jobs import local_training_job, debug_training_job from flox.runtime.process.future_callbacks import all_child_futures_finished_cbk @@ -123,7 +123,7 @@ def step( raise ValueError(value_err_template.format(kind, idx)) def _aggr_job(self, node: FlockNode) -> Future[Result]: - aggr_state = FloxAggregatorState(node.idx) + aggr_state = AggrState(node.idx) self.strategy.cli_worker_selection(aggr_state, list(self.flock.children(node))) # FIXME: This (^^^) shouldn't be run on the aggregator children_futures = [ diff --git a/flox/strategies/base.py b/flox/strategies/base.py index da37da7..405a1f8 100644 --- a/flox/strategies/base.py +++ b/flox/strategies/base.py @@ -8,7 +8,7 @@ from typing import Iterable, MutableMapping, TypeAlias from flox.flock import FlockNode, FlockNodeID - from flox.flock.states import FloxWorkerState, FloxAggregatorState, NodeState + from flox.flock.states import WorkerState, AggrState, NodeState from flox.nn.typing import StateDict Loss: TypeAlias = torch.Tensor @@ -65,7 +65,7 @@ def cli_get_node_statuses(self): @abstractmethod def cli_worker_selection( - self, state: FloxAggregatorState, children: Iterable[FlockNode], *args, **kwargs + self, state: AggrState, children: Iterable[FlockNode], *args, **kwargs ) -> Iterable[FlockNode]: """ @@ -81,7 +81,7 @@ def cli_worker_selection( return children def cli_before_share_params( - self, state: FloxAggregatorState, state_dict: StateDict, *args, **kwargs + self, state: AggrState, state_dict: StateDict, *args, **kwargs ) -> StateDict: """Callback before sharing parameters to child nodes. @@ -89,7 +89,7 @@ def cli_before_share_params( model parameters, apply noise, personalize, etc. Args: - state (FloxAggregatorState): The current state of the aggregator. + state (AggrState): The current state of the aggregator. state_dict (StateDict): The global model's current StateDict (i.e., parameters) before sharing with workers. @@ -102,18 +102,18 @@ def cli_before_share_params( # AGGREGATOR CALLBACKS. # #################################################################################### - def agg_before_round(self, state: FloxAggregatorState) -> None: + def agg_before_round(self, state: AggrState) -> None: """ Some process to run at the start of a round. Args: - state (FloxAggregatorState): The current state of the Aggregator FloxNode. + state (AggrState): The current state of the Aggregator FloxNode. """ raise NotImplementedError def agg_param_aggregation( self, - state: FloxAggregatorState, + state: AggrState, children_states: MutableMapping[FlockNodeID, NodeState], children_state_dicts: MutableMapping[FlockNodeID, StateDict], *args, @@ -122,7 +122,7 @@ def agg_param_aggregation( """ Args: - state (FloxAggregatorState): + state (AggrState): children_states (Mapping[FlockNodeID, NodeState]): children_state_dicts (Mapping[FlockNodeID, NodeState]): *args (): @@ -138,7 +138,7 @@ def agg_param_aggregation( #################################################################################### def wrk_on_recv_params( - self, state: FloxWorkerState, params: StateDict, *args, **kwargs + self, state: WorkerState, params: StateDict, *args, **kwargs ): """ @@ -153,7 +153,7 @@ def wrk_on_recv_params( """ return params - def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs): + def wrk_before_train_step(self, state: WorkerState, *args, **kwargs): """ Args: @@ -167,7 +167,7 @@ def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs): raise NotImplementedError() def wrk_after_train_step( - self, state: FloxWorkerState, loss: Loss, *args, **kwargs + self, state: WorkerState, loss: Loss, *args, **kwargs ) -> Loss: """ @@ -183,7 +183,7 @@ def wrk_after_train_step( return loss def wrk_before_submit_params( - self, state: FloxWorkerState, *args, **kwargs + self, state: WorkerState, *args, **kwargs ) -> StateDict: """ diff --git a/flox/strategies/registry/fedavg.py b/flox/strategies/registry/fedavg.py index 2290303..05d471e 100644 --- a/flox/strategies/registry/fedavg.py +++ b/flox/strategies/registry/fedavg.py @@ -9,7 +9,7 @@ from collections.abc import Mapping from flox.flock import FlockNodeID - from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState + from flox.flock.states import AggrState, WorkerState, NodeState from flox.nn.typing import StateDict @@ -46,14 +46,14 @@ def __init__( participation, probabilistic, always_include_child_aggregators, seed ) - def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs): + def wrk_before_train_step(self, state: WorkerState, *args, **kwargs): if "dataset" not in kwargs: raise ValueError("`dataset` must be provided") state["num_data_samples"] = len(kwargs["dataset"]) def agg_param_aggregation( self, - state: FloxAggregatorState, + state: AggrState, children_states: Mapping[FlockNodeID, NodeState], children_state_dicts: Mapping[FlockNodeID, StateDict], *_args, diff --git a/flox/strategies/registry/fedprox.py b/flox/strategies/registry/fedprox.py index b70e144..4934f9d 100644 --- a/flox/strategies/registry/fedprox.py +++ b/flox/strategies/registry/fedprox.py @@ -1,6 +1,6 @@ import torch -from flox.flock.states import FloxWorkerState +from flox.flock.states import WorkerState from flox.strategies import FedAvg @@ -51,7 +51,7 @@ def __init__( def wrk_after_train_step( self, - state: FloxWorkerState, + state: WorkerState, loss: torch.Tensor, *args, **kwargs, @@ -66,7 +66,7 @@ def wrk_after_train_step( $$ Args: - state (FloxWorkerState): + state (WorkerState): loss (torch.Tensor): **kwargs (): diff --git a/flox/strategies/registry/fedsgd.py b/flox/strategies/registry/fedsgd.py index ff7f301..4494112 100644 --- a/flox/strategies/registry/fedsgd.py +++ b/flox/strategies/registry/fedsgd.py @@ -3,7 +3,7 @@ import typing from flox.flock import FlockNode, FlockNodeID -from flox.flock.states import FloxAggregatorState, NodeState +from flox.flock.states import AggrState, NodeState from flox.strategies.base import Strategy from flox.strategies.commons.averaging import average_state_dicts from flox.strategies.commons.worker_selection import random_worker_selection @@ -55,7 +55,7 @@ def __init__( def agg_worker_selection( self, - state: FloxAggregatorState, + state: AggrState, children: Iterable[FlockNode], *_args, **_kwargs, @@ -73,7 +73,7 @@ def agg_worker_selection( updates from child $k$ at round $t$. Args: - state (FloxAggregatorState): ... + state (AggrState): ... children (list[FlockNode]): ... **kwargs: ... @@ -90,7 +90,7 @@ def agg_worker_selection( def agg_param_aggregation( self, - state: FloxAggregatorState, + state: AggrState, children_states: Mapping[FlockNodeID, NodeState], children_state_dicts: Mapping[FlockNodeID, StateDict], *_args, @@ -99,7 +99,7 @@ def agg_param_aggregation( """Runs simple, unweighted averaging of ``StateDict`` objects from each child node. Args: - state (FloxAggregatorState): ... + state (AggrState): ... children_states (dict[FlockNodeID, NodeState]): ... children_state_dicts (dict[FlockNodeID, StateDict]): ... *args: ... diff --git a/tests/data/test_datasets.py b/tests/data/test_datasets.py index 2246cd8..0c4c3ed 100644 --- a/tests/data/test_datasets.py +++ b/tests/data/test_datasets.py @@ -3,7 +3,6 @@ import pandas as pd import torch from sklearn.datasets import make_classification - # TODO: Get rid of `sklearn` as a dependency. from torch.utils.data import Dataset @@ -62,7 +61,7 @@ def test_dir_datasets(tmpdir): print(data.head()) for worker in flock.workers: - state = FloxWorkerState(worker.idx, None, None) + state = WorkerState(worker.idx, None, None) try: worker_data = MyDataDir(state, tmpdir) assert isinstance(worker_data, Dataset) diff --git a/tests/fit/test_fit_process.py b/tests/fit/test_fit_process.py index c582818..bc9117a 100644 --- a/tests/fit/test_fit_process.py +++ b/tests/fit/test_fit_process.py @@ -3,15 +3,14 @@ import pandas as pd import pytest import torch - from torch import nn from torchvision.datasets import MNIST from torchvision.transforms import ToTensor +from flox import federated_fit +from flox.data.utils import federated_split from flox.flock import Flock from flox.nn import FloxModule -from flox.run import federated_fit -from flox.data.utils import federated_split class MyModule(FloxModule):