diff --git a/fashion_mnist_demo.py b/fashion_mnist_demo.py
index 5220006..7cf3655 100644
--- a/fashion_mnist_demo.py
+++ b/fashion_mnist_demo.py
@@ -61,7 +61,6 @@ def main():
         # where="local",
         # "globus_compute",
     )
     df.to_feather(Path("out/fashion_mnist_demo.feather"))
-    print(">>> Finished!")
 
 
 if __name__ == "__main__":
diff --git a/flox/flock/states.py b/flox/flock/states.py
index bce0bf5..4fbdc62 100644
--- a/flox/flock/states.py
+++ b/flox/flock/states.py
@@ -76,22 +76,22 @@ def __init__(self, idx: NodeID):
 class WorkerState(NodeState):
     """State of a Worker node in a ``Flock``."""
 
-    pre_local_train_model: FloxModule | None = None
+    global_model: FloxModule | None = None
     """Global model."""
 
-    post_local_train_model: FloxModule | None = None
+    local_model: FloxModule | None = None
     """Local model after local fitting/training."""
 
     def __init__(
         self,
         idx: NodeID,
-        pre_local_train_model: FloxModule | None = None,
-        post_local_train_model: FloxModule | None = None,
+        global_model: FloxModule | None = None,
+        local_model: FloxModule | None = None,
     ):
         super().__init__(idx)
-        self.pre_local_train_model = pre_local_train_model
-        self.post_local_train_model = post_local_train_model
+        self.global_model = global_model
+        self.local_model = local_model
 
     def __repr__(self) -> str:
-        template = "WorkerState(pre_local_train_model={}, post_local_train_model={})"
-        return template.format(self.pre_local_train_model, self.post_local_train_model)
+        template = "WorkerState(global_model={}, local_model={})"
+        return template.format(self.global_model, self.local_model)
diff --git a/flox/jobs/__init__.py b/flox/jobs/__init__.py
index 269ed85..b7e5101 100644
--- a/flox/jobs/__init__.py
+++ b/flox/jobs/__init__.py
@@ -17,7 +17,7 @@
 
 Job: t.TypeAlias = AggregableJob | TrainableJob | NodeCallable
 """
 An umbrella typing that encapsulates both ``AggregableJob`` and ``TrainableJob`` protocols
-for job implementations for both the aggregator and worker nodes (respectively).
+for job implementations for the aggregator and worker nodes, respectively.
 """
 
 
@@ -27,7 +27,7 @@
     "AggregableJob",
     "TrainableJob",
     "NodeCallable",
-    # Job implementations.
+    # Concrete job implementations.
     "AggregateJob",
     "DebugAggregateJob",
     "LocalTrainJob",
diff --git a/flox/jobs/aggregation.py b/flox/jobs/aggregation.py
index 7f4ae56..b9e6017 100644
--- a/flox/jobs/aggregation.py
+++ b/flox/jobs/aggregation.py
@@ -2,7 +2,7 @@
 from flox.jobs.protocols import AggregableJob
 from flox.runtime.result import Result
 from flox.runtime.transfer import BaseTransfer
-from flox.strategies_depr import Strategy
+from flox.strategies import AggregatorStrategy
 
 
 class AggregateJob(AggregableJob):
@@ -10,7 +10,7 @@ class AggregateJob(AggregableJob):
     def __call__(
         node: FlockNode,
         transfer: BaseTransfer,
-        strategy: Strategy,
+        aggr_strategy: AggregatorStrategy,
         results: list[Result],
     ) -> Result:
         """Aggregate the state dicts from each of the results.
@@ -18,7 +18,7 @@ def __call__(
         Args:
             node (FlockNode): The aggregator node.
             transfer (Transfer): ...
-            strategy (Strategy): ...
+            aggr_strategy (AggregatorStrategy): ...
             results (list[JobResult]): Results from children of ``node``.
 
         Returns:
@@ -33,10 +33,10 @@ def __call__(
         for result in results:
             idx: NodeID = result.node_idx
             child_states[idx] = result.node_state
-            child_state_dicts[idx] = result.state_dict
+            child_state_dicts[idx] = result.params
 
         node_state = AggrState(node.idx)
-        avg_state_dict = strategy.agg_param_aggregation(
+        avg_state_dict = aggr_strategy.aggregate_params(
             node_state, child_states, child_state_dicts
         )
@@ -64,7 +64,7 @@ class DebugAggregateJob(AggregableJob):
     def __call__(
         node: FlockNode,
         transfer: BaseTransfer,
-        strategy: Strategy,
+        aggr_strategy: AggregatorStrategy,
         results: list[Result],
     ) -> Result:
         """
@@ -72,7 +72,7 @@
         Args:
             node ():
             transfer ():
-            strategy ():
+            aggr_strategy ():
             results ():
 
         Returns:
@@ -85,7 +85,7 @@ def __call__(
         from flox.runtime import JobResult
 
         result = next(iter(results))
-        state_dict = result.state_dict
+        state_dict = result.params
         state_dict = {} if state_dict is None else state_dict
         node_state = AggrState(node.idx)
         history = {
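For orientation (this note and sketch are not part of the patch): any object satisfying the renamed ``aggregate_params`` hook can be dropped into ``AggregateJob``. A minimal sketch of a custom aggregator using a coordinate-wise median; the class name and loose typing are illustrative assumptions:

```python
import typing as t

import torch


class MedianAggr:
    """Hypothetical aggregator: coordinate-wise median of child parameters."""

    def aggregate_params(
        self,
        state: t.Any,
        children_states: t.Mapping[t.Any, t.Any],
        children_state_dicts: t.Mapping[t.Any, dict[str, torch.Tensor]],
        **kwargs: t.Any,
    ) -> dict[str, torch.Tensor]:
        keys = next(iter(children_state_dicts.values())).keys()
        return {
            # Stack each child's tensor for this key, then take the median.
            key: torch.stack(
                [sd[key].float() for sd in children_state_dicts.values()]
            ).median(dim=0).values
            for key in keys
        }
```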
diff --git a/flox/jobs/local_training.py b/flox/jobs/local_training.py
index 27fc667..957ec9e 100644
--- a/flox/jobs/local_training.py
+++ b/flox/jobs/local_training.py
@@ -8,7 +8,7 @@
     from flox.data import FloxDataset
     from flox.flock import FlockNode
     from flox.nn import FloxModule
-    from flox.nn.typing import StateDict
+    from flox.nn.typing import Params
     from flox.runtime import Result
     from flox.runtime.transfer import BaseTransfer
     from flox.strategies import WorkerStrategy, TrainerStrategy
@@ -19,8 +19,8 @@ class LocalTrainJob(TrainableJob):
     def __call__(
         node: FlockNode,
         parent: FlockNode,
-        module: FloxModule,
-        module_state_dict: StateDict,
+        global_model: FloxModule,
+        module_state_dict: Params,
         dataset: FloxDataset,
         transfer: BaseTransfer,
         worker_strategy: WorkerStrategy,
@@ -35,7 +35,7 @@ def __call__(
         Args:
             node (FlockNode):
             parent (FlockNode):
-            strategy (Strategy):
-            module (FloxModule):
-            module_state_dict (StateDict):
+            worker_strategy (WorkerStrategy):
+            trainer_strategy (TrainerStrategy):
+            global_model (FloxModule):
+            module_state_dict (Params):
             dataset (Dataset | Subset | None):
             **train_hyper_params ():
 
@@ -44,23 +44,22 @@ def __call__(
         """
         from copy import deepcopy
         from flox.flock.states import WorkerState
-        from flox.nn.trainer import Trainer
+        from flox.nn.model_trainer import Trainer
         from torch.utils.data import DataLoader
         from flox.runtime import JobResult
 
-        global_model = module
-        global_state_dict = module.state_dict()
-        local_model = deepcopy(module)
+        # global_state_dict = global_model.state_dict()
+        local_model = deepcopy(global_model)
         global_model.load_state_dict(module_state_dict)
         local_model.load_state_dict(module_state_dict)
 
-        node_state = WorkerState(
+        state = WorkerState(
             node.idx,
-            pre_local_train_model=global_model,
-            post_local_train_model=local_model,
+            global_model=global_model,
+            local_model=local_model,
         )
+        state = worker_strategy.work_start(state)  # NOTE: Double-check.
 
-        worker_strategy.work_start()
         data = dataset.load(node)
         train_dataloader = DataLoader(
             data,
@@ -68,10 +67,10 @@ def __call__(
             shuffle=train_hyper_params.get("shuffle", True),
         )
 
-        # Add optimizer to this strategy.
-        worker_strategy.before_training(node_state, data)
-        trainer = Trainer()
-        optimizer = local_model.configure_optimizer()
+        trainer = Trainer(trainer_strategy)
+        optimizer = local_model.configure_optimizers()
+
+        state, data = worker_strategy.before_training(state, data)
 
         history = trainer.fit(
             local_model,
             optimizer,
@@ -79,18 +78,23 @@ def __call__(
             # TODO: Include `trainer_params` as an argument to
             #       this so users can easily customize Trainer.
             num_epochs=train_hyper_params.get("num_epochs", 2),
-            node_state=node_state,
-            trainer_strategy=trainer_strategy,
+            node_state=state,
         )
 
-        local_params = worker_strategy.after_training(node_state)
+        state = worker_strategy.after_training(state)  # NOTE: Double-check.
 
+        ################################################################################
+        # TRAINING DATA POST-PROCESSING
+        ################################################################################
         history["node/idx"] = node.idx
         history["node/kind"] = node.kind.to_str()
         history["parent/idx"] = parent.idx
         history["parent/kind"] = parent.kind.to_str()
 
-        result = JobResult(node_state, node.idx, node.kind, local_params, history)
+        local_params = state.local_model.state_dict()
+        result = JobResult(state, node.idx, node.kind, local_params, history)
+
+        result = worker_strategy.work_end(result)  # NOTE: Double-check.
         return transfer.report(result)
 
 
@@ -99,8 +103,8 @@ class DebugLocalTrainJob(TrainableJob):
     def __call__(
         node: FlockNode,
         parent: FlockNode,
-        module: FloxModule,
-        module_state_dict: StateDict,
+        global_model: FloxModule,
+        module_state_dict: Params,
         dataset: FloxDataset,
         transfer: BaseTransfer,
         worker_strategy: WorkerStrategy,
@@ -128,8 +132,8 @@ def __call__(
-        local_module = module
+        local_module = global_model
         node_state = WorkerState(
             node.idx,
-            pre_local_train_model=local_module,
-            post_local_train_model=local_module,
+            global_model=local_module,
+            local_model=local_module,
         )
         history = {
             "node/idx": [node.idx],
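The refactored job makes the worker-side hook order explicit: ``work_start`` → ``before_training`` → ``fit`` → ``after_training`` → ``work_end``. A logging-only sketch of a conforming worker strategy (illustrative; the class is not part of this patch):

```python
class LoggingWorkerStrategy:
    """Logs each worker hook in the order that LocalTrainJob invokes them."""

    def work_start(self, state):
        print(f"[{state.idx}] work_start")
        return state

    def before_training(self, state, data):
        print(f"[{state.idx}] before_training on {len(data)} samples")
        return state, data

    def after_training(self, state):
        print(f"[{state.idx}] after_training")
        return state

    def work_end(self, result):
        print(f"[{result.node_idx}] work_end")
        return result
```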
diff --git a/flox/jobs/protocols.py b/flox/jobs/protocols.py
index 078cb38..2e3035b 100644
--- a/flox/jobs/protocols.py
+++ b/flox/jobs/protocols.py
@@ -5,7 +5,7 @@
 1. aggregation jobs (``AggregableJob``)
 2. local training jobs (``TrainableJob``)
 
-These protocols can be used to define custom implementations of aggregation jobs for highly-customized FLoX processes.
+These protocols can be used to define custom implementations of aggregation jobs for highly customized FLoX processes.
 However, this is not necessary for the vast majority of imaginable cases.
 Should users choose to do this, it is up to the user's discretion to do so safely and correctly.
@@ -27,11 +27,14 @@
     from flox.data import FloxDataset
     from flox.flock import FlockNode
     from flox.nn import FloxModule
-    from flox.nn.typing import StateDict
+    from flox.nn.typing import Params
     from flox.runtime import Result
     from flox.runtime.transfer import BaseTransfer
-    from flox.strategies import WorkerStrategy, TrainerStrategy
-    from flox.strategies_depr import Strategy
+    from flox.strategies import (
+        WorkerStrategy,
+        TrainerStrategy,
+        AggregatorStrategy,
+    )
 
 
 class NodeCallable(t.Protocol):
@@ -46,11 +49,11 @@ def __call__(self, node: FlockNode, *args, **kwargs) -> t.Any:
 @t.runtime_checkable
 class AggregableJob(t.Protocol):
     """
-    A protocol that defines functions that are valid implementations to be used for model aggregation in
-    launching FLoX processes.
+    A protocol that defines functions that serve as valid implementations of model aggregation
+    when launching FLoX processes.
 
     Notes:
         FLoX provides default implementations of this protocol via
         [AggregateJob][flox.jobs.aggregation.AggregateJob]
         and
         [DebugAggregateJob][flox.jobs.aggregation.DebugAggregateJob].
     """
 
@@ -59,7 +62,7 @@ class AggregableJob(t.Protocol):
     def __call__(
         node: FlockNode,
         transfer: BaseTransfer,
-        strategy: Strategy,
+        aggr_strategy: AggregatorStrategy,
         results: list[Result],
     ) -> Result:
         """
@@ -68,7 +71,7 @@ def __call__(
         Args:
             node (FlockNode):
             transfer (BaseTransfer):
-            strategy (Strategy):
+            aggr_strategy (AggregatorStrategy):
             results (list[Result]):
 
         Returns:
@@ -79,11 +82,11 @@
 @t.runtime_checkable
 class TrainableJob(t.Protocol):
     """
-    A protocol that defines functions that are valid implementations to be used for local training in
-    launching FLoX processes.
+    A protocol that defines functions that serve as valid implementations of local training
+    when launching FLoX processes.
 
     Notes:
         FLoX provides default implementations of this protocol via
         [LocalTrainJob][flox.jobs.local_training.LocalTrainJob]
         and
         [DebugLocalTrainJob][flox.jobs.local_training.DebugLocalTrainJob].
     """
 
@@ -92,8 +95,8 @@ class TrainableJob(t.Protocol):
     def __call__(
         node: FlockNode,
         parent: FlockNode,
-        module: FloxModule,
-        module_state_dict: StateDict,
+        global_model: FloxModule,
+        module_state_dict: Params,
         dataset: FloxDataset,
         transfer: BaseTransfer,
         worker_strategy: WorkerStrategy,
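Since both protocols are ``@t.runtime_checkable``, conformance of the shipped jobs can be sanity-checked at runtime. A sketch, assuming the imports resolve as laid out in this patch:

```python
from flox.jobs import (
    AggregableJob,
    AggregateJob,
    LocalTrainJob,
    TrainableJob,
)

# Structural checks only: runtime_checkable verifies that __call__ exists,
# not that the signatures match.
assert isinstance(AggregateJob(), AggregableJob)
assert isinstance(LocalTrainJob(), TrainableJob)
```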
""" @@ -59,7 +62,7 @@ class AggregableJob(t.Protocol): def __call__( node: FlockNode, transfer: BaseTransfer, - strategy: Strategy, + aggr_strategy: AggregatorStrategy, results: list[Result], ) -> Result: """ @@ -68,7 +71,7 @@ def __call__( Args: node (FlockNode): transfer (BaseTransfer): - strategy (Strategy): + aggr_strategy (AggregatorStrategy): results (list[Result]): Returns: @@ -79,11 +82,11 @@ def __call__( @t.runtime_checkable class TrainableJob(t.Protocol): """ - A protocol that defines functions that are valid implementations to be used for local training in + A protocol that defines functions that are valid impl to be used for local training in launching FLoX processes. Notes: - FLoX provides default implementations of this protocol via + FLoX provides default impl of this protocol via [LocalTrainJob][flox.jobs.local_training.LocalTrainJob] and [DebugLocalTrainJob][flox.jobs.local_training.DebugLocalTrainJob]. """ @@ -92,8 +95,8 @@ class TrainableJob(t.Protocol): def __call__( node: FlockNode, parent: FlockNode, - module: FloxModule, - module_state_dict: StateDict, + global_model: FloxModule, + module_state_dict: Params, dataset: FloxDataset, transfer: BaseTransfer, worker_strategy: WorkerStrategy, diff --git a/flox/jobs/train_depr.py b/flox/jobs/train_depr.py deleted file mode 100644 index 57c5c92..0000000 --- a/flox/jobs/train_depr.py +++ /dev/null @@ -1,158 +0,0 @@ -from __future__ import annotations - -import typing - -from torch.utils.data import Dataset, Subset - -from flox.flock import FlockNode -from flox.nn import FloxModule -from flox.runtime.result import Result -from flox.runtime.transfer import BaseTransfer -from flox.strategies_depr import Strategy - -if typing.TYPE_CHECKING: - from flox.data import FloxDataset - from flox.nn.typing import StateDict - - -# TODO: Debug training job should have the same signature. -def local_training_job( - node: FlockNode, - transfer: BaseTransfer, - parent: FlockNode, - strategy: Strategy, - module: FloxModule, - module_state_dict: StateDict, - dataset: FloxDataset, # TODO: Cannot be `None`. - **train_hyper_params, -) -> Result: - """Perform local training on a worker node. - - Args: - node (FlockNode): - transfer (BaseTransfer): ... - parent (FlockNode): - strategy (Strategy): - module (FloxModule): - module_state_dict (StateDict): - dataset (Dataset | Subset | None): - **train_hyper_params (): - - Returns: - Local fitting results. - """ - from copy import deepcopy - from flox.flock.states import WorkerState - from flox.nn.trainer import Trainer - from torch.utils.data import DataLoader - from flox.runtime import JobResult - - # if isinstance(dataset, LocalDatasetV2): - # data = dataset.load() - # elif isinstance(dataset, FederatedSubsets): - # data = dataset[node.idx] - - # match dataset: - # case LocalDataset(): - # data = ... - # case FederatedSubsets(): - # data = ... 
diff --git a/flox/nn/__init__.py b/flox/nn/__init__.py
index 530f49a..94e82c8 100644
--- a/flox/nn/__init__.py
+++ b/flox/nn/__init__.py
@@ -1,4 +1,4 @@
 from flox.nn.model import FloxModule
-from flox.nn.trainer import Trainer
+from flox.nn.model_trainer import Trainer
 
 __all__ = ["FloxModule", "Trainer"]
diff --git a/flox/nn/trainer.py b/flox/nn/model_trainer.py
similarity index 81%
rename from flox/nn/trainer.py
rename to flox/nn/model_trainer.py
index ff35ac5..5a735b3 100644
--- a/flox/nn/trainer.py
+++ b/flox/nn/model_trainer.py
@@ -1,6 +1,5 @@
 import datetime
 from pathlib import Path
-from typing import Any
 
 import pandas as pd
 import torch
@@ -15,16 +14,11 @@ class Trainer:
     def __init__(
         self,
-        logger: str = "csv",
-        device="cpu",
-        config: dict[str, Any] | None = None,
+        trainer_strategy: TrainerStrategy,
     ):
-        self.device = device
-        self.config = config  # TODO: Not implemented to do anything at the moment.
-        if logger == "csv":
-            self.logger = CSVLogger()
-        else:
-            raise ValueError("Illegal value for `logger`.")
+        self.trainer_strategy = trainer_strategy
+        self.device = "cpu"
+        self.logger = CSVLogger()
 
     def fit(
         self,
@@ -32,7 +26,6 @@ def fit(
         optimizer: torch.optim.Optimizer,
         train_dataloader: DataLoader,
         num_epochs: int,
-        trainer_strategy: TrainerStrategy,
         node_state: WorkerState,
         valid_dataloader: DataLoader | None = None,
         valid_ckpt_path: Path | str | None = None,
@@ -44,11 +37,11 @@ def fit(
         for epoch in range(num_epochs):
             for batch_idx, batch in enumerate(train_dataloader):
                 loss = model.training_step(batch, batch_idx)
-                loss = trainer_strategy.before_backprop(loss)
+                loss = self.trainer_strategy.before_backprop(node_state, loss)
                 optimizer.zero_grad()
                 loss.backward()
-                loss = trainer_strategy.after_backprop(loss)
+                loss = self.trainer_strategy.after_backprop(node_state, loss)
                 optimizer.step()
 
                 # log data about training
@@ -74,8 +67,8 @@ def test(
         ckpt_path: Path | str | None = None,
     ):
         with torch.no_grad():
-            for i, batch in enumerate(test_dataloader):
-                model.test_step(batch, i)
+            for batch_idx, batch in enumerate(test_dataloader):
+                model.test_step(batch, batch_idx)
 
     def validate(
         self,
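With the strategy now bound at construction, a ``Trainer`` is wired up roughly as follows. This is a sketch, not part of the patch; ``model``, ``loader``, and ``state`` are placeholders the caller must supply, and ``DefaultStrategy`` is the instance added later in this patch:

```python
from flox.nn import Trainer
from flox.strategies import DefaultStrategy

# The trainer strategy is injected once, instead of on every fit() call.
trainer = Trainer(DefaultStrategy.trainer_strategy)
optimizer = model.configure_optimizers()  # `model` is a FloxModule placeholder
history = trainer.fit(
    model,
    optimizer,
    loader,            # a torch DataLoader placeholder
    num_epochs=2,
    node_state=state,  # a WorkerState placeholder
)
```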
- if logger == "csv": - self.logger = CSVLogger() - else: - raise ValueError("Illegal value for `logger`.") + self.trainer_strategy = trainer_strategy + self.device = "cpu" + self.logger = CSVLogger() def fit( self, @@ -32,7 +26,6 @@ def fit( optimizer: torch.optim.Optimizer, train_dataloader: DataLoader, num_epochs: int, - trainer_strategy: TrainerStrategy, node_state: WorkerState, valid_dataloader: DataLoader | None = None, valid_ckpt_path: Path | str | None = None, @@ -44,11 +37,11 @@ def fit( for epoch in range(num_epochs): for batch_idx, batch in enumerate(train_dataloader): loss = model.training_step(batch, batch_idx) - loss = trainer_strategy.before_backprop(loss) + loss = self.trainer_strategy.before_backprop(node_state, loss) optimizer.zero_grad() loss.backward() - loss = trainer_strategy.after_backprop(loss) + loss = self.trainer_strategy.after_backprop(node_state, loss) optimizer.step() # log data about training @@ -74,8 +67,8 @@ def test( ckpt_path: Path | str | None = None, ): with torch.no_grad(): - for i, batch in enumerate(test_dataloader): - model.test_step(batch, i) + for batch_idx, batch in enumerate(test_dataloader): + model.test_step(batch, batch_idx) def validate( self, diff --git a/flox/nn/typing.py b/flox/nn/typing.py index 71389c5..3bcbcde 100644 --- a/flox/nn/typing.py +++ b/flox/nn/typing.py @@ -1,8 +1,14 @@ -from typing import TypeAlias, Literal +from __future__ import annotations + +import typing as t import torch -Kind: TypeAlias = Literal["async", "sync"] -Where: TypeAlias = Literal["local", "globus_compute"] -StateDict = dict[str, torch.Tensor] -"""The state dict of PyTorch ``torch.nn.Module`` (see ``torch.nn.Module.state_dict()``).""" +Kind: t.TypeAlias = t.Literal["async", "sync"] + +Where: t.TypeAlias = t.Literal["local", "globus_compute"] + +Params: t.TypeAlias = dict[str, torch.Tensor] # torch.optim.Optimizer.StateDict +"""The state dict of PyTorch ``torch.nn.Module`` (see ``torch.nn.Module.params()``).""" + +Loss: t.TypeAlias = torch.Tensor diff --git a/flox/runtime/fit.py b/flox/runtime/fit.py index a78d9fa..9f805b0 100644 --- a/flox/runtime/fit.py +++ b/flox/runtime/fit.py @@ -1,8 +1,9 @@ import datetime -import typing +import typing as t from pandas import DataFrame +import flox.strategies as strats from flox.data import FloxDataset from flox.flock import Flock from flox.nn import FloxModule @@ -18,7 +19,6 @@ from flox.runtime.process.proc_sync import SyncProcess from flox.runtime.runtime import Runtime from flox.runtime.transfer import BaseTransfer -from flox.strategies_depr import Strategy def create_launcher(kind: str, **launcher_cfg) -> Launcher: @@ -44,10 +44,16 @@ def federated_fit( module: FloxModule, datasets: FloxDataset, num_global_rounds: int, - strategy: Strategy | str | None = None, + # Strategy arguments. + strategy: strats.Strategy | str | None = None, + client_strategy: strats.ClientStrategy | None = None, + aggr_strategy: strats.AggregatorStrategy | None = None, + worker_strategy: strats.WorkerStrategy | None = None, + trainer_strategy: strats.TrainerStrategy | None = None, + # Process arguments. kind: Kind = "sync", launcher_kind: str = "process", - launcher_cfg: dict[str, typing.Any] | None = None, + launcher_cfg: dict[str, t.Any] | None = None, debug_mode: bool = False, ) -> tuple[FloxModule, DataFrame]: """ @@ -58,9 +64,13 @@ def federated_fit( datasets (FloxDataset): num_global_rounds (int): strategy (Strategy | str | None): + client_strategy (strats.ClientStrategy): ... + aggr_strategy (strats.AggregatorStrategy): ... 
diff --git a/flox/runtime/launcher/local.py b/flox/runtime/launcher/local.py
index a3a1562..ae502fe 100644
--- a/flox/runtime/launcher/local.py
+++ b/flox/runtime/launcher/local.py
@@ -26,7 +26,13 @@ def __init__(self, pool: str, n_workers: int = 1):
         )
 
     def submit(self, fn: NodeCallable, node: FlockNode, /, *args, **kwargs) -> Future:
+        # TODO: Adjust this typing (i.e., a `Future` is not always returned when
+        #       `n_workers == 1`), then clarify the logic behind this.
         return self.pool.submit(fn, node, *args, **kwargs)
+        # if self.n_workers > 1:
+        #     return self.pool.submit(fn, node, *args, **kwargs)
+        # else:
+        #     return fn(node, *args, **kwargs)
 
     def collect(self):
         pass
diff --git a/flox/runtime/launcher/sequential.py b/flox/runtime/launcher/sequential.py
new file mode 100644
index 0000000..afc882d
--- /dev/null
+++ b/flox/runtime/launcher/sequential.py
@@ -0,0 +1,11 @@
+class SequentialLauncher:
+    """
+    A Launcher implementation that does not rely on Futures or any of the concurrent execution frameworks
+    native to Python. It simply runs jobs one at a time, which is useful for debugging and for confirming
+    that a defined FL process can run properly. Error messages raised through the Executors in the
+    ``concurrent.futures`` module can be rather cryptic, so this launcher is meant to make them easier
+    to interpret.
+    """
+
+    def __init__(self):
+        # TODO
+        raise NotImplementedError
diff --git a/flox/runtime/process/future_callbacks.py b/flox/runtime/process/future_callbacks.py
index e269902..11fbb28 100644
--- a/flox/runtime/process/future_callbacks.py
+++ b/flox/runtime/process/future_callbacks.py
@@ -3,34 +3,34 @@
 import functools
 import typing
 
-from flox.jobs import AggregateJob
 from flox.runtime.utils import set_parent_future
 
 if typing.TYPE_CHECKING:
+    from flox.jobs import Job
     from concurrent.futures import Future
     from flox.flock import FlockNode
     from flox.runtime.runtime import Runtime
-    from flox.strategies_depr import Strategy
+    from flox.strategies import AggregatorStrategy
 
 
 def all_child_futures_finished_cbk(
+    job: Job,
     parent_future: Future,
     children_futures: typing.Iterable[Future],
     node: FlockNode,
     runtime: Runtime,
-    strategy: Strategy,
+    aggr_strategy: AggregatorStrategy,
     _: Future,
 ):
     if all([child_future.done() for child_future in children_futures]):
         # TODO: We need to add error-handling for cases when the
         #       `TaskExecutionFailed` error from Globus-Compute is thrown.
         children_results = [child_future.result() for child_future in children_futures]
-        job = AggregateJob()
         future = runtime.submit(
             job,
             node,
-            strategy=strategy,
+            aggr_strategy=aggr_strategy,
             results=children_results,
         )
         aggr_done_callback = functools.partial(set_parent_future, parent_future)
diff --git a/flox/runtime/process/proc_async.py b/flox/runtime/process/proc_async.py
index c085553..6db9f0a 100644
--- a/flox/runtime/process/proc_async.py
+++ b/flox/runtime/process/proc_async.py
@@ -14,10 +14,10 @@
 from flox.nn import FloxModule
 from flox.runtime.process.proc import BaseProcess
 from flox.runtime.runtime import Runtime
-from flox.strategies_depr import Strategy
+from flox.strategies import Strategy
 
 if typing.TYPE_CHECKING:
-    from flox.nn.typing import StateDict
+    from flox.nn.typing import Params
 
 
 class AsyncProcess(BaseProcess):
@@ -67,7 +67,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]:
         histories: list[DataFrame] = []
         worker_rounds: dict[NodeID, int] = {}
         worker_states: dict[NodeID, NodeState] = {}
-        worker_state_dicts: dict[NodeID, StateDict] = {}
+        worker_state_dicts: dict[NodeID, Params] = {}
         for worker in self.flock.workers:
             worker_rounds[worker.idx] = 0
             worker_states[worker.idx] = WorkerState(worker.idx)
@@ -108,7 +108,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]:
 
             worker = self.flock[result.node_idx]
             worker_states[worker.idx] = result.node_state
-            worker_state_dicts[worker.idx] = result.state_dict
+            worker_state_dicts[worker.idx] = result.params
             result.history["round"] = worker_rounds[result.node_idx]
             histories.append(result.history)
-            avg_state_dict = self.strategy.agg_param_aggregation(
+            avg_state_dict = self.strategy.aggr_strategy.aggregate_params(
diff --git a/flox/runtime/process/proc_sync.py b/flox/runtime/process/proc_sync.py
index d986f68..b5b9c19 100644
--- a/flox/runtime/process/proc_sync.py
+++ b/flox/runtime/process/proc_sync.py
@@ -11,16 +11,16 @@
 from flox.data import FloxDataset
 from flox.flock import Flock, FlockNode, NodeKind
 from flox.flock.states import AggrState
-from flox.jobs import LocalTrainJob, DebugLocalTrainJob
+from flox.jobs import LocalTrainJob, DebugLocalTrainJob, AggregateJob
 from flox.nn import FloxModule
 from flox.runtime.process.future_callbacks import all_child_futures_finished_cbk
 from flox.runtime.process.proc import BaseProcess
 from flox.runtime.result import Result
 from flox.runtime.runtime import Runtime
-from flox.strategies_depr import Strategy
+from flox.strategies import Strategy
 
 if typing.TYPE_CHECKING:
-    from flox.nn.typing import StateDict
+    from flox.nn.typing import Params
 
 
 class SyncProcess(BaseProcess):
@@ -34,7 +34,7 @@ class SyncProcess(BaseProcess):
     strategy: Strategy
     dataset: FloxDataset
     aggr_callback: typing.Any  # TODO: Fix
-    state_dict: StateDict | None
+    params: Params | None
     debug_mode: bool
     pbar_desc: str
 
@@ -55,35 +55,31 @@ def __init__(
         self.dataset = dataset
         self.aggr_callback = all_child_futures_finished_cbk
-        self.state_dict = None
+        self.params = None
         self.debug_mode = False
         self.pbar_desc = "federated_fit::sync"
 
         # TODO: Add description option for the progress bar when it's training.
         #       Also, add a configurable stop condition
 
-    def start(
-        self, testing_mode: bool = False
-    ) -> tuple[FloxModule, DataFrame]:  # , global_module: FloxModule):
+    def start(self, testing_mode: bool = False) -> tuple[FloxModule, DataFrame]:
         if testing_mode:
             from flox.runtime.process.debug_utils import DebugModule
 
             self.debug_mode = True
             self.global_module = DebugModule()
 
-        dataframes = []
+        histories = []
         progress_bar = tqdm(total=self.num_global_rounds, desc=self.pbar_desc)
 
         for round_num in range(self.num_global_rounds):
-            self.state_dict = self.global_module.state_dict()
-            future = self.step()
-            update = future.result()
-            history = update.history
-            history["round"] = round_num
-            dataframes.append(history)
-            self.global_module.load_state_dict(update.state_dict)
+            self.params = self.global_module.state_dict()
+            step_result = self.step().result()
+            step_result.history["round"] = round_num
+            histories.append(step_result.history)
+            self.global_module.load_state_dict(step_result.params)
             progress_bar.update()
 
-        history = pd.concat(dataframes)
+        history = pd.concat(histories)
         return self.global_module, history
 
     def step(
@@ -104,28 +100,33 @@ def step(
         match flock.get_kind(node):
             case NodeKind.LEADER | NodeKind.AGGREGATOR:
-                if self.debug_mode:
-                    # return self._debug_aggr_job(node)  # FIXME
-                    return self._aggr_job(node)
-                else:
-                    return self._aggr_job(node)
+                return self.submit_aggr_job(node)
+                # if self.debug_mode:
+                #     return self.submit_aggr_debug_job(node)  # FIXME
+                # else:
+                #     return self.submit_aggr_job(node)
             case NodeKind.WORKER:
                 assert parent is not None
                 # (^^^) avoids mypy issue which won't naturally occur with valid Flock topo
                 if self.debug_mode:
-                    return self._debug_worker_job(node, parent)
+                    return self.submit_worker_debug_job(node, parent)
                 else:
-                    return self._worker_job(node, parent)
+                    return self.submit_worker_job(node, parent)
             case _:
                 kind = flock.get_kind(node)
                 idx = node.idx
                 raise ValueError(value_err_template.format(kind, idx))
 
-    def _aggr_job(self, node: FlockNode) -> Future[Result]:
+    ########################################################################################################
+    ########################################################################################################
+
+    def submit_aggr_job(self, node: FlockNode) -> Future[Result]:
         aggr_state = AggrState(node.idx)
-        self.strategy.cli_worker_selection(aggr_state, list(self.flock.children(node)))
+        self.strategy.client_strategy.select_worker_nodes(
+            aggr_state, list(self.flock.children(node)), None
+        )  # FIXME: This (^^^) shouldn't be run on the aggregator
         children_futures = [
             self.step(node=child, parent=node) for child in self.flock.children(node)
         ]
@@ -137,41 +138,46 @@
         # future. But, it will only perform aggregation once since only the last future
         # to be completed will activate the conditional.
         future: Future[Result] = Future()
+        job = AggregateJob()
         subtree_done_cbk = functools.partial(
             self.aggr_callback,
+            job,
             future,
             children_futures,
             node,
             self.runtime,
-            self.strategy,
+            self.strategy.aggr_strategy,
         )
         for child_fut in children_futures:
             child_fut.add_done_callback(subtree_done_cbk)
 
         return future
 
-    def _debug_aggr_job(self, node: FlockNode) -> Future[Result]:
+    def submit_aggr_debug_job(self, node: FlockNode) -> Future[Result]:
         raise NotImplementedError
 
-    def _worker_job(self, node: FlockNode, parent: FlockNode) -> Future[Result]:
+    def submit_worker_job(self, node: FlockNode, parent: FlockNode) -> Future[Result]:
         job = LocalTrainJob()
         data = self.dataset
         return self.runtime.submit(
             job,
             node,
             parent=parent,
-            module=self.global_module,
-            strategy=self.strategy,
+            global_model=self.global_module,
+            worker_strategy=self.strategy.worker_strategy,
+            trainer_strategy=self.strategy.trainer_strategy,
             dataset=self.runtime.proxy(data),
-            module_state_dict=self.runtime.proxy(self.state_dict),
+            module_state_dict=self.runtime.proxy(self.params),
         )
 
-    def _debug_worker_job(self, node: FlockNode, parent: FlockNode) -> Future[Result]:
+    def submit_worker_debug_job(
+        self, node: FlockNode, parent: FlockNode
+    ) -> Future[Result]:
         job = DebugLocalTrainJob()
         return self.runtime.submit(
             job,
             node,
             parent=parent,
-            module=self.global_module,
-            strategy=self.strategy,
+            global_model=self.global_module,
+            worker_strategy=self.strategy.worker_strategy,
+            trainer_strategy=self.strategy.trainer_strategy,
         )
diff --git a/flox/runtime/result.py b/flox/runtime/result.py
index 61a0d66..1295384 100644
--- a/flox/runtime/result.py
+++ b/flox/runtime/result.py
@@ -10,7 +10,7 @@
 
     from flox.flock import NodeID, NodeKind
     from flox.flock.states import NodeState
-    from flox.nn.typing import StateDict
+    from flox.nn.typing import Params
 
 
 @dataclass
@@ -29,8 +29,8 @@ class JobResult:
     node_kind: NodeKind
     """The kind of the ``Flock`` node."""
 
-    state_dict: StateDict
-    """The ``StateDict`` of the PyTorch global_module (either aggregated or trained locally)."""
+    params: Params
+    """The ``Params`` of the PyTorch global_module (either aggregated or trained locally)."""
 
     history: DataFrame
     """The history of results."""
diff --git a/flox/runtime/runtime.py b/flox/runtime/runtime.py
index 4d5cf3c..5630948 100644
--- a/flox/runtime/runtime.py
+++ b/flox/runtime/runtime.py
@@ -1,16 +1,16 @@
+import typing as t
 from concurrent.futures import Future
-from typing import Any, NewType
 
 from flox.flock import FlockNode
 from flox.jobs import Job
 from flox.runtime.launcher import Launcher
 from flox.runtime.transfer import BaseTransfer
 
-Config = NewType("Config", dict[str, Any])
+Config = t.NewType("Config", dict[str, t.Any])
 
 
 class Borg:
-    _shared_state: dict[str, Any] = {}
+    _shared_state: dict[str, t.Any] = {}
 
     def __init__(self):
         self.__dict__ = self._shared_state
@@ -29,7 +29,7 @@ def __init__(self, launcher: Launcher, transfer: BaseTransfer):
     def submit(self, fn: Job, node: FlockNode, /, *args, **kwargs) -> Future:
         return self.launcher.submit(fn, node, *args, **kwargs, transfer=self.transfer)
 
-    def proxy(self, data: Any):
+    def proxy(self, data: t.Any):
         return self.transfer.proxy(data)
 
     # @classmethod
diff --git a/flox/runtime/transfer/base.py b/flox/runtime/transfer/base.py
index 07e62c6..85b46ff 100644
--- a/flox/runtime/transfer/base.py
+++ b/flox/runtime/transfer/base.py
@@ -7,14 +7,14 @@ class BaseTransfer:
     #     node_state: NodeState | dict[str, Any] | None,
     #     node_idx: NodeID | None,
     #     node_kind: NodeKind | None,
-    #     state_dict: StateDict | None,
+    #     params: Params | None,
     #     history: DataFrame | None,
     # ) -> Result:
     #     return JobResult(
    #         node_state=node_state,
     #         node_idx=node_idx,
     #         node_kind=node_kind,
-    #         state_dict=state_dict,
+    #         params=params,
     #         history=history,
     #     )
diff --git a/flox/strategies/__init__.py b/flox/strategies/__init__.py
index 7158212..ba12bbe 100644
--- a/flox/strategies/__init__.py
+++ b/flox/strategies/__init__.py
@@ -1,13 +1,75 @@
+"""
+```mermaid
+classDiagram
+
+class ClientStrategy {
+    <<interface>>
+    select_worker_nodes(self, state, children, seed) Iterable[Node]
+}
+```
+"""
+
 from flox.strategies.aggregator import AggregatorStrategy
 from flox.strategies.client import ClientStrategy
+from flox.strategies.impl.default import (
+    DefaultClientStrategy,
+    DefaultAggregatorStrategy,
+    DefaultWorkerStrategy,
+    DefaultTrainerStrategy,
+)
 from flox.strategies.strategy import Strategy
 from flox.strategies.trainer import TrainerStrategy
 from flox.strategies.worker import WorkerStrategy
 
+DefaultStrategy = Strategy(
+    DefaultClientStrategy(),
+    DefaultAggregatorStrategy(),
+    DefaultWorkerStrategy(),
+    DefaultTrainerStrategy(),
+)
+
+
+def load_strategy(strategy_name: str, **kwargs) -> Strategy:
+    """
+    Loads the strategy identified by the ``strategy_name`` argument.
+
+    Notes:
+        The argument `strategy_name` is *not* case-sensitive. Any provided value is
+        lower-cased via `strategy_name.lower()`.
+
+    Raises:
+        ValueError: in the event that the provided `strategy_name` is not supported.
+
+    Args:
+        strategy_name (str): The name of the strategy to be loaded.
+        **kwargs: Arguments that are passed into the corresponding ``Strategy`` class.
+
+    Returns:
+        An initialized instance of the specified ``Strategy``.
+    """
+    assert isinstance(strategy_name, str), "`strategy_name` must be a string."
+    match strategy_name.lower():
+        case "default":
+            return DefaultStrategy
+        case "fedavg" | "fed-avg":
+            from flox.strategies.impl.fedavg import FedAvg
+
+            return FedAvg(**kwargs)
+        case "fedsgd" | "fed-sgd":
+            from flox.strategies.impl.fedsgd import FedSGD
+
+            return FedSGD(**kwargs)
+        case _:
+            raise ValueError(f"Strategy '{strategy_name}' is not recognized.")
+
+
 __all__ = [
     "Strategy",
     "ClientStrategy",
     "AggregatorStrategy",
     "WorkerStrategy",
     "TrainerStrategy",
+    "DefaultStrategy",
+    "load_strategy",
 ]
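The new loader in action (a sketch, not part of the patch). Names are case-insensitive, and keyword arguments are forwarded to the strategy's constructor:

```python
from flox.strategies import load_strategy

fedavg = load_strategy("FedAvg", participation=0.5)  # lower-cased internally
fedsgd = load_strategy("fed-sgd")
default = load_strategy("default")  # returns the shared DefaultStrategy instance
```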
diff --git a/flox/strategies_depr/commons/__init__.py b/flox/strategies/commons/__init__.py
similarity index 100%
rename from flox/strategies_depr/commons/__init__.py
rename to flox/strategies/commons/__init__.py
diff --git a/flox/strategies/aggregator.py b/flox/strategies/aggregator.py
index 09b5abe..f23623b 100644
--- a/flox/strategies/aggregator.py
+++ b/flox/strategies/aggregator.py
@@ -1,12 +1,24 @@
+from __future__ import annotations
+
 import typing as t
 
+if t.TYPE_CHECKING:
+    from flox.flock import AggrState, NodeID, NodeState
+    from flox.nn.typing import Params
+
 
 class AggregatorStrategy(t.Protocol):
     def round_start(self):
         pass
 
-    def aggregate_params(self):
+    def aggregate_params(
+        self,
+        state: AggrState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
         pass
 
     def round_end(self):
         pass
diff --git a/flox/strategies/client.py b/flox/strategies/client.py
index 21d66ac..a7d32e8 100644
--- a/flox/strategies/client.py
+++ b/flox/strategies/client.py
@@ -2,11 +2,11 @@
 
 
 class ClientStrategy(t.Protocol):
-    def get_node_statuses(self):
-        pass
+    # def get_node_statuses(self):
+    #     pass
 
-    def select_worker_nodes(self):
-        pass
+    def select_worker_nodes(self, state, children, seed):
+        return children
 
-    def before_share_params(self):
-        pass
+    # def before_share_params(self):
+    #     pass
diff --git a/flox/strategies_depr/commons/averaging.py b/flox/strategies/commons/averaging.py
similarity index 75%
rename from flox/strategies_depr/commons/averaging.py
rename to flox/strategies/commons/averaging.py
index 5e5d369..c0971f1 100644
--- a/flox/strategies_depr/commons/averaging.py
+++ b/flox/strategies/commons/averaging.py
@@ -8,22 +8,22 @@
 if typing.TYPE_CHECKING:
     from collections.abc import Mapping
     from flox.flock import NodeID
-    from flox.nn.typing import StateDict
+    from flox.nn.typing import Params
 
 
 def average_state_dicts(
-    state_dicts: Mapping[NodeID, StateDict],
+    state_dicts: Mapping[NodeID, Params],
     weights: Mapping[NodeID, float] | None = None,
-) -> StateDict:
+) -> Params:
     """Averages the parameters given by ``global_module.state_dict()`` from a set of ``FlockNodes``.
 
     Args:
-        state_dicts (dict[NodeID, StateDict]): The global_module state dicts of each FlockNode to average.
+        state_dicts (dict[NodeID, Params]): The global_module state dicts of each FlockNode to average.
         weights (dict[NodeID, float] | None): The weights for each ``FlockNode`` used to do weighted averaging. If no weights are provided (i.e., `weights=None`), then standard averaging is done.
 
     Returns:
-        Averaged weights as a ``StateDict``.
+        Averaged parameters as ``Params``.
     """
     num_nodes = len(state_dicts)
     weight_sum = None if weights is None else numpy.sum(list(weights.values()))
diff --git a/flox/strategies_depr/commons/worker_selection.py b/flox/strategies/commons/worker_selection.py
similarity index 100%
rename from flox/strategies_depr/commons/worker_selection.py
rename to flox/strategies/commons/worker_selection.py
diff --git a/flox/strategies/impl/__init__.py b/flox/strategies/impl/__init__.py
new file mode 100644
index 0000000..12b2b9d
--- /dev/null
+++ b/flox/strategies/impl/__init__.py
@@ -0,0 +1,12 @@
+"""
+This submodule contains default implementations of well-known federated learning strategies (or algorithms)
+from the academic literature. These include:
+
+- [`FedSGD`][flox.strategies.impl.fedsgd.FedSGD]
+- [`FedAvg`][flox.strategies.impl.fedavg.FedAvg]
+- [`FedProx`][flox.strategies.impl.fedprox.FedProx]
+
+## Any missing strategies you have in mind? :huh:
+Please feel free to implement them and contribute them to the repository with a pull request on
+[GitHub](https://github.com/nathaniel-hudson/FLoX)! :smile:
+"""
diff --git a/flox/strategies/impl/default.py b/flox/strategies/impl/default.py
new file mode 100644
index 0000000..062f985
--- /dev/null
+++ b/flox/strategies/impl/default.py
@@ -0,0 +1,24 @@
+from flox.strategies.aggregator import AggregatorStrategy
+from flox.strategies.client import ClientStrategy
+from flox.strategies.trainer import TrainerStrategy
+from flox.strategies.worker import WorkerStrategy
+
+
+class DefaultClientStrategy(ClientStrategy):
+    def __init__(self):
+        super().__init__()
+
+
+class DefaultAggregatorStrategy(AggregatorStrategy):
+    def __init__(self):
+        super().__init__()
+
+
+class DefaultWorkerStrategy(WorkerStrategy):
+    def __init__(self):
+        super().__init__()
+
+
+class DefaultTrainerStrategy(TrainerStrategy):
+    def __init__(self):
+        super().__init__()
diff --git a/flox/strategies/impl/fedavg.py b/flox/strategies/impl/fedavg.py
new file mode 100644
index 0000000..143c0a6
--- /dev/null
+++ b/flox/strategies/impl/fedavg.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+import typing as t
+
+from flox.strategies import (
+    AggregatorStrategy,
+    Strategy,
+    WorkerStrategy,
+)
+from flox.strategies.commons.averaging import average_state_dicts
+from flox.strategies.impl.fedsgd import FedSGDClient
+
+if t.TYPE_CHECKING:
+    from flox.flock import AggrState, NodeID, NodeState, WorkerState
+    from flox.nn.typing import Params
+
+
+class FedAvgAggr(AggregatorStrategy):
+    def aggregate_params(
+        self,
+        state: AggrState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
+        """Performs a weighted average of the model parameters returned by the child nodes.
+
+        The average is done by:
+
+        $$
+            w^{t} \\triangleq \\sum_{k=1}^{K} \\frac{n_{k}}{n} w_{k}^{t}
+        $$
+
+        where $n_{k}$ is the number of data items at worker $k$ (and $n \\triangleq \\sum_{k} n_{k}$),
+        $w^{t}$ is the aggregated model parameters, $K$ is the number of returned
+        model updates, $t$ is the current round, and $w_{k}^{t}$ is the returned model
+        updates from child $k$ at round $t$.
+
+        Args:
+            state (AggrState): ...
+            children_states (t.Mapping[NodeID, NodeState]): ...
+            children_state_dicts (t.Mapping[NodeID, Params]): ...
+            **kwargs: ...
+
+        Returns:
+            The averaged parameters.
+        """
+        weights = {}
+        for node, child_state in children_states.items():
+            weights[node] = child_state["num_data_samples"]
+        state["num_data_samples"] = sum(weights.values())
+        return average_state_dicts(children_state_dicts, weights=weights)
+
+
+class FedAvgWorker(WorkerStrategy):
+    def before_training(
+        self, state: WorkerState, data: t.Any
+    ) -> tuple[WorkerState, t.Any]:
+        state["num_data_samples"] = len(data)
+        return state, data
+
+
+class FedAvg(Strategy):
+    """Implementation of the Federated Averaging algorithm.
+
+    This algorithm extends ``FedSGD`` and differs from it by performing a weighted
+    average based on the number of data samples each (sibling) worker has. Worker
+    selection is done randomly, same as ``FedSGD``.
+
+    References:
+        McMahan, Brendan, et al. "Communication-efficient learning of deep networks
+        from decentralized data." Artificial intelligence and statistics. PMLR, 2017.
+    """
+
+    def __init__(
+        self,
+        participation: float = 1.0,
+        probabilistic: bool = False,
+        always_include_child_aggregators: bool = False,
+    ):
+        super().__init__(
+            client_strategy=FedSGDClient(
+                participation,
+                probabilistic,
+                always_include_child_aggregators,
+            ),
+            aggr_strategy=FedAvgAggr(),
+            worker_strategy=FedAvgWorker(),
+        )
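A worked example of the FedAvg weighting, $w^{t} = \sum_{k} (n_{k}/n)\, w_{k}^{t}$, exercising the helper directly. This is a sketch with toy tensors, assuming the weighted form described in the docstring of ``average_state_dicts``:

```python
import torch

from flox.strategies.commons.averaging import average_state_dicts

state_dicts = {
    "worker-a": {"w": torch.tensor([0.0])},
    "worker-b": {"w": torch.tensor([3.0])},
}
weights = {"worker-a": 100, "worker-b": 300}  # num_data_samples per worker
avg = average_state_dicts(state_dicts, weights=weights)
# Expected: (100/400) * 0.0 + (300/400) * 3.0 = 2.25
```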
+ """ + + def __init__( + self, + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = False, + ): + super().__init__( + client_strategy=FedSGDClient( + participation, + probabilistic, + always_include_child_aggregators, + ), + aggr_strategy=FedAvgAggr(), + worker_strategy=FedAvgWorker(), + ) diff --git a/flox/strategies_depr/registry/fedprox.py b/flox/strategies/impl/fedprox.py similarity index 60% rename from flox/strategies_depr/registry/fedprox.py rename to flox/strategies/impl/fedprox.py index 8d00d58..0c88c08 100644 --- a/flox/strategies_depr/registry/fedprox.py +++ b/flox/strategies/impl/fedprox.py @@ -1,61 +1,23 @@ -import torch - -from flox.flock.states import WorkerState -from flox.strategies_depr import FedAvg +from __future__ import annotations +import typing as t -class FedProx(FedAvg): - """Implementation of FedAvg with Proximal Term. - - This strategy extends ``FedAvg`` and differs from it by computing a "proximal term" - and adding it to the computed loss during the training step before doing backpropagation. - This proximal term is the norm difference between the parameters of the global model - and the worker's locally-updated model. This proximal term is used to make aggregation - less sensitive to harshly heterogeneous (i.e., non-iid) data distributions across workers. +import torch - More information on the proximal term and its definition can be found in the docstring - for ``FedProx.wrk_after_train_step()`` and in the reference below. +from flox.strategies import Strategy, TrainerStrategy +from flox.strategies.impl.fedavg import FedAvgAggr, FedAvgWorker +from flox.strategies.impl.fedsgd import FedSGDClient - References: - Li, Tian, et al. "Federated optimization in heterogeneous networks." Proceedings of - Machine learning and systems 2 (2020): 429-450. - """ +if t.TYPE_CHECKING: + from flox.flock import WorkerState + from flox.nn.typing import Loss - def __init__( - self, - mu: float = 0.3, - participation: float = 1.0, - probabilistic: bool = False, - always_include_child_aggregators: bool = True, - seed: int | None = None, - ): - """ - Args: - mu (float): Multiplier that weights the importance of the proximal term. If `mu=0` then - ``FedProx`` reduces to ``FedAvg``. - participation (float): Participation rate for random worker selection. - probabilistic (bool): Probabilistically chooses workers if True; otherwise will always - select `max(1, max_workers * participation)` workers. - always_include_child_aggregators (bool): If True, Will always include child nodes that are - aggregators; if False, then they are included at random. - seed (int): Random seed. - """ - super().__init__( - participation, - probabilistic, - always_include_child_aggregators, - seed, - ) +class FedProxTrainer(TrainerStrategy): + def __init__(self, mu: float = 0.3): self.mu = mu - def wrk_after_train_step( - self, - state: WorkerState, - loss: torch.Tensor, - *args, - **kwargs, - ) -> torch.Tensor: + def before_backprop(self, state: WorkerState, loss: Loss) -> Loss: """ Adds the proximal term before the optimization step during local training to minimize the following objective: @@ -67,14 +29,14 @@ def wrk_after_train_step( Args: state (WorkerState): - loss (torch.Tensor): + loss (Loss): **kwargs (): Returns: - + Loss with the proximal term added to it. 
""" - global_model = state.pre_local_train_model - local_model = state.post_local_train_model + global_model = state.global_model + local_model = state.local_model assert global_model is not None assert local_model is not None @@ -90,3 +52,39 @@ def wrk_after_train_step( proximal_term = (self.mu / 2) * norm loss += proximal_term return loss + + +class FedProx(Strategy): + """Implementation of FedAvg with Proximal Term. + + This strategy extends ``FedAvg`` and differs from it by computing a "proximal term" + and adding it to the computed loss during the training step before doing backpropagation. + This proximal term is the norm difference between the parameters of the global model + and the worker's locally-updated model. This proximal term is used to make aggregation + less sensitive to harshly heterogeneous (i.e., non-iid) data distributions across workers. + + More information on the proximal term and its definition can be found in the docstring + for ``FedProx.wrk_after_train_step()`` and in the reference below. + + References: + Li, Tian, et al. "Federated optimization in heterogeneous networks." Proceedings of + Machine learning and systems 2 (2020): 429-450. + """ + + def __init__( + self, + mu: float = 0.3, + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = False, + ): + super().__init__( + client_strategy=FedSGDClient( + participation, + probabilistic, + always_include_child_aggregators, + ), + aggr_strategy=FedAvgAggr(), + worker_strategy=FedAvgWorker(), + trainer_strategy=FedProxTrainer(mu=mu), + ) diff --git a/flox/strategies/impl/fedsgd.py b/flox/strategies/impl/fedsgd.py new file mode 100644 index 0000000..9beeb61 --- /dev/null +++ b/flox/strategies/impl/fedsgd.py @@ -0,0 +1,85 @@ +import typing as t + +from flox.flock import NodeID, NodeState, AggrState +from flox.nn.typing import Params +from flox.strategies import Strategy, AggregatorStrategy, ClientStrategy +from flox.strategies.commons.averaging import average_state_dicts +from flox.strategies.commons.worker_selection import random_worker_selection + + +class FedSGDClient(ClientStrategy): + """ + ... + """ + + def __init__( + self, + participation, + probabilistic, + always_include_child_aggregators: bool, + ): + self.participation = participation + self.probabilistic = probabilistic + self.always_include_child_aggregators = always_include_child_aggregators + + def select_worker_nodes(self, state, children, seed): + return random_worker_selection( + children, + participation=self.participation, + probabilistic=self.probabilistic, + always_include_child_aggregators=self.always_include_child_aggregators, + seed=seed, + ) + + +class FedSGDAggr(AggregatorStrategy): + """ + ... + """ + + def aggregate_params( + self, + state: AggrState, + children_states: t.Mapping[NodeID, NodeState], + children_state_dicts: t.Mapping[NodeID, Params], + **kwargs, + ) -> Params: + """Performs a simple average of the model parameters returned by the child nodes. + + The average is done by: + + $$ + w^{t} \\triangleq \\frac{1}{K} \\sum_{k=1}^{K} w_{k}^{t} + $$ + + where $w^{t}$ is the aggregated model parameters, $K$ is the number of returned + model updates, $t$ is the current round, and $w_{k}^{t}$ is the returned model + updates from child $k$ at round $t$. + + Args: + state (AggrState): ... + children_states (t.Mapping[NodeID, NodeState]): ... + children_state_dicts (t.Mapping[NodeID, Params]): ... + **kwargs: ... + + Returns: + The averaged parameters. 
+ """ + return average_state_dicts(children_state_dicts, weights=None) + + +class FedSGD(Strategy): + def __init__( + self, + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = True, + ): + super().__init__( + client_strategy=FedSGDClient( + participation, + probabilistic, + always_include_child_aggregators, + ), + aggr_strategy=FedSGDAggr(), + ) diff --git a/flox/strategies/strategy.py b/flox/strategies/strategy.py index 522c382..14ec09d 100644 --- a/flox/strategies/strategy.py +++ b/flox/strategies/strategy.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import typing as t +from dataclasses import dataclass, field from flox.strategies.aggregator import AggregatorStrategy from flox.strategies.client import ClientStrategy @@ -6,16 +9,66 @@ from flox.strategies.worker import WorkerStrategy -class Strategy(t.NamedTuple): +class DefaultClientStrategy(ClientStrategy): + def __init__(self): + super().__init__() + + +class DefaultAggregatorStrategy(AggregatorStrategy): + def __init__(self): + super().__init__() + + +class DefaultWorkerStrategy(WorkerStrategy): + def __init__(self): + super().__init__() + + +class DefaultTrainerStrategy(TrainerStrategy): + def __init__(self): + super().__init__() + + +@dataclass(frozen=True, repr=False) +class Strategy: """ - A strategy... + A ``Strategy`` implementation is made up of a set of implementations for strategies on each part of the + topology during execution. """ - client_strategy: ClientStrategy | None = None - """...""" - aggr_strategy: AggregatorStrategy | None = None - """...""" - worker_strategy: WorkerStrategy | None = None - """...""" - trainer_strategy: TrainerStrategy | None = None - """...""" + client_strategy: ClientStrategy = field(default=DefaultClientStrategy) + """Implementation of callbacks specific to the CLIENT node.""" + aggr_strategy: AggregatorStrategy = field(default=DefaultAggregatorStrategy) + """Implementation of callbacks specific to the AGGREGATOR nodes.""" + worker_strategy: WorkerStrategy = field(default_factory=DefaultWorkerStrategy) + """Implementation of callbacks specific to the WORKER nodes.""" + trainer_strategy: TrainerStrategy = field(default_factory=DefaultTrainerStrategy) + """Implementation of callbacks specific to the training loop on the worker nodes.""" + + # def __post_init__(self): + # if self.client_strategy is not None: + # self.client_strategy + + def __repr__(self): + return str(self) + + def __str__(self) -> str: + name = self.__class__.__name__ + inner = ", ".join( + [ + f"{strategy_key}={strategy_value.__class__.__name__}" + for (strategy_key, strategy_value) in iter(self) + if strategy_value is not None + ] + ) + return f"{name}({inner})" + + def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: + strategies = ( + ("client_strategy", self.client_strategy), + ("aggr_strategy", self.aggr_strategy), + ("worker_strategy", self.worker_strategy), + ("trainer_strategy", self.trainer_strategy), + ) + for strategy_key, strategy_value in strategies: + yield strategy_key, strategy_value diff --git a/flox/strategies/trainer.py b/flox/strategies/trainer.py index afedee0..e75a91d 100644 --- a/flox/strategies/trainer.py +++ b/flox/strategies/trainer.py @@ -1,16 +1,18 @@ -import typing as t - -import torch +from __future__ import annotations -Loss: t.TypeAlias = torch.Tensor +import typing as t if t.TYPE_CHECKING: - pass + from flox.flock import WorkerState + from flox.nn.typing import Loss class TrainerStrategy(t.Protocol): - def before_backprop(self, 
diff --git a/flox/strategies/trainer.py b/flox/strategies/trainer.py
index afedee0..e75a91d 100644
--- a/flox/strategies/trainer.py
+++ b/flox/strategies/trainer.py
@@ -1,16 +1,18 @@
-import typing as t
-
-import torch
+from __future__ import annotations
 
-Loss: t.TypeAlias = torch.Tensor
+import typing as t
 
 if t.TYPE_CHECKING:
-    pass
+    from flox.flock import WorkerState
+    from flox.nn.typing import Loss
 
 
 class TrainerStrategy(t.Protocol):
-    def before_backprop(self, loss: Loss) -> Loss:
-        pass
+    def trainer_kwargs(self) -> dict[str, t.Any]:
+        return {}
+
+    def before_backprop(self, state: WorkerState, loss: Loss) -> Loss:
+        return loss
 
-    def after_backprop(self, loss: Loss) -> Loss:
-        pass
+    def after_backprop(self, state: WorkerState, loss: Loss) -> Loss:
+        return loss
diff --git a/flox/strategies/worker.py b/flox/strategies/worker.py
index cb5617e..cd6de6f 100644
--- a/flox/strategies/worker.py
+++ b/flox/strategies/worker.py
@@ -1,20 +1,24 @@
+from __future__ import annotations
+
 import typing as t
 
-from flox.flock.states import NodeState
+from flox.flock.states import WorkerState
 
 if t.TYPE_CHECKING:
-    pass
+    from flox.runtime import JobResult
 
 
 class WorkerStrategy(t.Protocol):
-    def work_start(self):
-        pass
+    def work_start(self, state: WorkerState) -> WorkerState:
+        return state
 
-    def before_training(self, state: NodeState, data: t.Any):
-        pass
+    def before_training(
+        self, state: WorkerState, data: t.Any
+    ) -> tuple[WorkerState, t.Any]:
+        return state, data
 
-    def after_training(self, node_state: NodeState):
-        pass
+    def after_training(self, state: WorkerState) -> WorkerState:
+        return state
 
-    def work_end(self):
-        pass
+    def work_end(self, result: JobResult) -> JobResult:
+        return result
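The trainer hooks now receive the ``WorkerState``, so per-node logic becomes possible. A sketch of a conforming strategy that simply scales the loss (illustrative; the class is not part of this patch):

```python
class ScaledLossTrainer:
    """A TrainerStrategy sketch that down-weights the training loss."""

    def __init__(self, scale: float = 0.5):
        self.scale = scale

    def trainer_kwargs(self):
        return {}

    def before_backprop(self, state, loss):
        # Runs before loss.backward() inside Trainer.fit().
        return loss * self.scale

    def after_backprop(self, state, loss):
        # Runs after loss.backward(), before optimizer.step().
        return loss
```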
- """ - - __metaclass__ = abc.ABCMeta - - registry: MutableMapping[str, type["Strategy"]] = {} - """...""" - - def __new__(cls, *args, **kwargs): - if cls is Strategy: - raise TypeError(f"Abstract class {cls.__name__} cannot be instantiated.") - return super().__new__(cls, *args, **kwargs) - - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - cls.registry[cls.__name__.lower()] = cls - - @classmethod - def get_strategy(cls, name: str) -> type["Strategy"]: - """ - Pulls a strategy class implementation from the registry by its name. - - Notes: - All names are lower-cased (e.g., the name for `FedAvg` is "fedavg"). Thus, any - provided argument for `name` is lower-cased via `name = name.lower()`. - - Args: - name (str): The name of the strategy implementation to pull from the registry. - - Returns: - Strategy class. - """ - name = name.lower() - if name in cls.registry: - return cls.registry[name] - else: - raise KeyError(f"Strategy name ({name}) is not in the Strategy registry.") - - #################################################################################### - # CLIENT CALLBACKS. # - #################################################################################### - - def cli_get_node_statuses(self): - """ - Followup callback upon getting status updates from all of the nodes in the Flock. - """ - - def cli_worker_selection( - self, state: AggrState, children: Iterable[FlockNode], **kwargs - ) -> Iterable[FlockNode]: - """ - - Args: - state (): - children (): - **kwargs (): - - Returns: - List of selected nodes that are children of the aggregator. - """ - return children - - def cli_before_share_params( - self, state: AggrState, state_dict: StateDict, **kwargs - ) -> StateDict: - """Callback before sharing parameters to child nodes. - - This is mostly done is modify the global model's StateDict. This can be done - to encrypt the model parameters, apply noise, personalize, etc. - - Args: - state (AggrState): The current state of the aggregator. - state_dict (StateDict): The global model's current StateDict - (i.e., parameters) before sharing with workers. - - Returns: - The global global_module StateDict. - """ - return state_dict - - #################################################################################### - # AGGREGATOR CALLBACKS. # - #################################################################################### - - def agg_before_round(self, state: AggrState) -> None: - """ - Some process to run at the start of a round. - - Args: - state (AggrState): The current state of the Aggregator FloxNode. - """ - raise NotImplementedError - - def agg_param_aggregation( - self, - state: AggrState, - children_states: MutableMapping[NodeID, NodeState], - children_state_dicts: MutableMapping[NodeID, StateDict], - **kwargs, - ) -> StateDict: - """ - - Args: - state (AggrState): - children_states (Mapping[NodeID, NodeState]): - children_state_dicts (Mapping[NodeID, NodeState]): - **kwargs (): - - Returns: - StateDict - """ - raise NotImplementedError - - #################################################################################### - # WORKER CALLBACKS. 
diff --git a/flox/strategies_depr/registry/__init__.py b/flox/strategies_depr/registry/__init__.py
deleted file mode 100644
index cc482a4..0000000
--- a/flox/strategies_depr/registry/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-"""
-This is a collection of already-implemented Strategies that can be used entirely on their own
-or as starting points to define novel Strategies. The already-implemented strategies_depr include
-some of the standard Federated Learning solutions found in the academic literature.
-"""
diff --git a/flox/strategies_depr/registry/fedadagrad.py b/flox/strategies_depr/registry/fedadagrad.py
deleted file mode 100644
index 76aaf71..0000000
--- a/flox/strategies_depr/registry/fedadagrad.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from flox.strategies_depr.registry.fedopt import FedOpt
-
-
-class FedAdaGrad(FedOpt):
-    def __init__(self):
-        super().__init__()
-        # TODO
diff --git a/flox/strategies_depr/registry/fedadam.py b/flox/strategies_depr/registry/fedadam.py
deleted file mode 100644
index e964a44..0000000
--- a/flox/strategies_depr/registry/fedadam.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from flox.strategies_depr.registry.fedopt import FedOpt
-
-
-class FedAdam(FedOpt):
-    def __init__(self):
-        super().__init__()
-        # TODO
diff --git a/flox/strategies_depr/registry/fedavg.py b/flox/strategies_depr/registry/fedavg.py
deleted file mode 100644
index ab39416..0000000
--- a/flox/strategies_depr/registry/fedavg.py
+++ /dev/null
@@ -1,66 +0,0 @@
-from __future__ import annotations
-
-import typing
-
-from flox.strategies_depr.commons.averaging import average_state_dicts
-from flox.strategies_depr.registry.fedsgd import FedSGD
-
-if typing.TYPE_CHECKING:
-    from collections.abc import Mapping
-
-    from flox.flock import NodeID
-    from flox.flock.states import AggrState, WorkerState, NodeState
-    from flox.nn.typing import StateDict
-
-
-class FedAvg(FedSGD):
-    """Implementation of the Federated Averaging algorithm.
-
-    This algorithm extends ``FedSGD`` and differs from it by performing a weighted
-    average based on the number of data samples each (sibling) worker has. Worker
-    selection is done randomly, same as ``FedSGD``.
-
-    References:
-        McMahan, Brendan, et al. "Communication-efficient learning of deep networks
-        from decentralized data." Artificial intelligence and statistics. PMLR, 2017.
-    """
-
-    def __init__(
-        self,
-        participation: float = 1.0,
-        probabilistic: bool = True,
-        always_include_child_aggregators: bool = True,
-        seed: int | None = None,
-    ):
-        """
-
-        Args:
-            participation (float): Participation rate for random worker selection.
-            probabilistic (bool): Probabilistically chooses workers if True; otherwise will always
-                select `max(1, max_workers * participation)` workers.
-            always_include_child_aggregators (bool): If True, Will always include child nodes that are
-                aggregators; if False, then they are included at random.
-            seed (int): Random seed.
-        """
-        super().__init__(
-            participation, probabilistic, always_include_child_aggregators, seed
-        )
-
-    def wrk_before_train_step(self, state: WorkerState, **kwargs):
-        if "dataset" not in kwargs:
-            print(**kwargs)
-            raise ValueError("`dataset` must be provided")
-        state["num_data_samples"] = len(kwargs["dataset"])
-
-    def agg_param_aggregation(
-        self,
-        state: AggrState,
-        children_states: Mapping[NodeID, NodeState],
-        children_state_dicts: Mapping[NodeID, StateDict],
-        **kwargs,
-    ):
-        weights = {}
-        for node, child_state in children_states.items():
-            weights[node] = child_state["num_data_samples"]
-        state["num_data_samples"] = sum(weights.values())
-        return average_state_dicts(children_state_dicts, weights=weights)
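The weighted average FedAvg computes can be sketched directly: each child's parameters are scaled by its share of the total sample count. Below is a minimal torch sketch of an ``average_state_dicts``-style helper; it is an illustration of the math, not the FLoX implementation from ``flox/strategies_depr/commons/averaging.py``:

    import torch


    def weighted_average(state_dicts: dict, weights: dict | None = None) -> dict:
        """Average tensors key-by-key; weights (e.g., sample counts) are normalized."""
        nodes = list(state_dicts)
        if weights is None:
            # Unweighted case: every child contributes equally (the FedSGD behavior).
            coeffs = {idx: 1.0 / len(nodes) for idx in nodes}
        else:
            total = sum(weights.values())
            coeffs = {idx: weights[idx] / total for idx in nodes}
        return {
            key: sum(coeffs[idx] * state_dicts[idx][key] for idx in nodes)
            for key in state_dicts[nodes[0]]
        }


    # Two workers with 10 and 30 samples: the second contributes 3x as much.
    sd0, sd1 = {"w": torch.zeros(2)}, {"w": torch.ones(2)}
    print(weighted_average({0: sd0, 1: sd1}, weights={0: 10, 1: 30}))
    # {'w': tensor([0.7500, 0.7500])}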
diff --git a/flox/strategies_depr/registry/fedopt.py b/flox/strategies_depr/registry/fedopt.py
deleted file mode 100644
index cc249ef..0000000
--- a/flox/strategies_depr/registry/fedopt.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from flox.strategies_depr.base import Strategy
-
-
-class FedOpt(Strategy):
-    """
-    Implementation of the FedOpt algorithm proposed by Reddi et al. (2020). It is implemented as a base class for the
-    three specializations presented in the referenced paper, namely, ``FedAdaGrad``, ``FedAdam``, and ``FedYogi``.
-
-    References:
-        Reddi, S., Charles, Z., Zaheer, M., Garrett, Z., Rush, K., Konečný, J., ... & McMahan, H. B. (2020).
-        Adaptive federated optimization. arXiv preprint arXiv:2003.00295.
-    """
-
-    def __init__(self):
-        super().__init__()
-        # TODO
diff --git a/flox/strategies_depr/registry/fedsgd.py b/flox/strategies_depr/registry/fedsgd.py
deleted file mode 100644
index b72f46d..0000000
--- a/flox/strategies_depr/registry/fedsgd.py
+++ /dev/null
@@ -1,110 +0,0 @@
-from __future__ import annotations
-
-import typing
-
-from flox.flock import FlockNode, NodeID
-from flox.flock.states import AggrState, NodeState
-from flox.strategies_depr.base import Strategy
-from flox.strategies_depr.commons.averaging import average_state_dicts
-from flox.strategies_depr.commons.worker_selection import random_worker_selection
-
-if typing.TYPE_CHECKING:
-    from collections.abc import Iterable, Mapping
-    from flox.nn.typing import StateDict
-
-
-class FedSGD(Strategy):
-    """Implementation of the Federated Stochastic Gradient Descent algorithm.
-
-    In short, this algorithm randomly selects a subset of worker nodes and will
-    do a simple, unweighted average across the updates to the model parameters
-    (i.e., ``StateDict``).
-
-    > **Reference:**
-    >
-    > McMahan, Brendan, et al. "Communication-efficient learning of deep networks
-    > from decentralized data." Artificial intelligence and statistics. PMLR, 2017.
-    """
-
-    def __init__(
-        self,
-        participation: float = 1.0,
-        probabilistic: bool = True,
-        always_include_child_aggregators: bool = True,
-        seed: int | None = None,
-    ):
-        """Initializes the FedSGD strategy with the desired parameters.
-
-        Args:
-            participation (float): Fraction of the child nodes to be selected.
-            probabilistic (bool): If `True`, nodes are selected entirely probabilistically
-                rather than based on a fraction (`False`). As an example, consider you have
-                10 children nodes to select from and `participation=0.5`. If `probabilistic=True`,
-                then exactly 5 children nodes *will* be selected; otherwise, then each child node
-                will be selected with probability 0.5.
-            always_include_child_aggregators (bool): If `True`, child aggregator nodes will always
-                be included; if `False`, then they will only be included if they are naturally
-                selected (similar to worker child nodes).
-            seed (int): Random seed.  # TODO: Change this to standardized seeding format.
-        """
-        super().__init__()
-        assert 0.0 <= participation <= 1.0
-        self.participation = participation
-        self.probabilistic = probabilistic
-        self.always_include_child_aggregators = always_include_child_aggregators
-        self.seed = seed
-
-    def agg_worker_selection(
-        self,
-        state: AggrState,
-        children: Iterable[FlockNode],
-        **kwargs,
-    ) -> list[FlockNode]:
-        """Performs a simple average of the model weights returned by the child nodes.
-
-        The average is done by:
-
-        $$
-            w^{t} \\triangleq \\frac{1}{K} \\sum_{k=1}^{K} w_{k}^{t}
-        $$
-
-        where $w^{t}$ is the aggregated model weights, $K$ is the number of returned
-        model updates, $t$ is the current round, and $w_{k}^{t}$ is the returned model
-        updates from child $k$ at round $t$.
-
-        Args:
-            state (AggrState): ...
-            children (list[FlockNode]): ...
-            **kwargs: ...
-
-        Returns:
-            The selected children nodes.
-        """
-        return random_worker_selection(
-            children,
-            participation=self.participation,
-            probabilistic=self.probabilistic,
-            always_include_child_aggregators=self.always_include_child_aggregators,
-            seed=self.seed,  # TODO: Change this because it will always do the same thing as is.
-        )
-
-    def agg_param_aggregation(
-        self,
-        state: AggrState,
-        children_states: Mapping[NodeID, NodeState],
-        children_state_dicts: Mapping[NodeID, StateDict],
-        **kwargs,
-    ) -> StateDict:
-        """Runs simple, unweighted averaging of ``StateDict`` objects from each child node.
-
-        Args:
-            state (AggrState): ...
-            children_states (dict[NodeID, NodeState]): ...
-            children_state_dicts (dict[NodeID, StateDict]): ...
-            *args: ...
-            **kwargs: ...
-
-        Returns:
-            The averaged ``StateDict``.
-        """
-        return average_state_dicts(children_state_dicts, weights=None)
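The two selection modes described above can be sketched as follows. Note the illustration follows the FedAvg docstring's description (per-node probability when `probabilistic=True`, fixed fraction otherwise); it is an illustrative reimplementation, not the `random_worker_selection` helper itself:

    import random


    def select_workers(children: list, participation: float = 0.5,
                       probabilistic: bool = True, seed: int | None = None) -> list:
        rng = random.Random(seed)
        if probabilistic:
            # Each child is kept independently with probability `participation`,
            # so the number of selected workers varies from round to round.
            return [child for child in children if rng.random() < participation]
        # Otherwise, select a fixed fraction of the children (at least one node).
        k = max(1, int(len(children) * participation))
        return rng.sample(children, k)

The unweighted average in `agg_param_aggregation` above is simply the `weights=None` case of the averaging sketch shown earlier.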
diff --git a/flox/strategies_depr/registry/fedyogi.py b/flox/strategies_depr/registry/fedyogi.py
deleted file mode 100644
index f46be9a..0000000
--- a/flox/strategies_depr/registry/fedyogi.py
+++ /dev/null
@@ -1,7 +0,0 @@
-from flox.strategies_depr.registry.fedopt import FedOpt
-
-
-class FedYogi(FedOpt):
-    def __init__(self):
-        super().__init__()
-        # TODO
diff --git a/mkdocs.yml b/mkdocs.yml
index 417c550..fd9bd0a 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -40,7 +40,7 @@ extra:
     - icon: fontawesome/brands/slack
      link: https://join.slack.com/t/funcx/shared_invite/zt-gfeclqkz-RuKjkZkvj1t~eWvlnZV0KA
      name: Join the funcX Slack!
-    - icon: fontawesome/brands/github-alt
+    - icon: fontawesome/brands/github
      link: https://github.com/nathaniel-hudson/FLoX
@@ -55,7 +55,7 @@ theme:
   icon: #assets/logos/favicon-dark.svg
     logo: material/bird
     favicon: material/bird
-    repo: fontawesome/brands/github
+    repo: fontawesome/brands/github-alt
  # favicon: ...
  font:
    text: Open Sans
diff --git a/print_strategy.py b/print_strategy.py
new file mode 100644
index 0000000..7a26622
--- /dev/null
+++ b/print_strategy.py
@@ -0,0 +1,4 @@
+from flox.strategies.impl.fedavg import FedAvg
+
+strategy = FedAvg(1.0, True)
+print(strategy)
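Given the ``__str__`` defined on the new ``Strategy`` dataclass, this script should print the subclass name followed by the class names of its four strategy slots, along the lines of the sketch below. The inner names are assumptions (``flox/strategies/impl/fedavg.py`` is not part of this patch), based on FedAvg mirroring the ``FedSGDClient``/``FedSGDAggr`` composition shown above:

    FedAvg(client_strategy=FedAvgClient, aggr_strategy=FedAvgAggr, worker_strategy=DefaultWorkerStrategy, trainer_strategy=DefaultTrainerStrategy)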