From 0af3fdd71ff328dd1f3cf0308f7ac34f26217d1c Mon Sep 17 00:00:00 2001 From: Logan Alt Date: Tue, 9 Jul 2024 19:56:28 -0500 Subject: [PATCH 01/11] Strategies refactor --- flight/strategies/__init__.py | 18 +++ flight/strategies/aggr.py | 13 +- flight/strategies/base.py | 145 +++++++++++++----- flight/strategies/commons/__init__.py | 4 + flight/strategies/commons/averaging.py | 32 ++++ flight/strategies/commons/worker_selection.py | 57 +++++++ flight/strategies/coord.py | 8 +- flight/strategies/impl/fedasync.py | 46 ++++++ flight/strategies/impl/fedavg.py | 0 flight/strategies/impl/fedprox.py | 0 flight/strategies/impl/fedsgd.py | 0 flight/strategies/trainer.py | 5 +- flight/strategies/worker.py | 15 +- tests/strategies/__init__.py | 0 tests/strategies/test_aggr.py | 51 ++++++ tests/strategies/test_coord.py | 8 + 16 files changed, 355 insertions(+), 47 deletions(-) create mode 100644 flight/strategies/commons/__init__.py create mode 100644 flight/strategies/commons/averaging.py create mode 100644 flight/strategies/commons/worker_selection.py create mode 100644 flight/strategies/impl/fedasync.py create mode 100644 flight/strategies/impl/fedavg.py create mode 100644 flight/strategies/impl/fedprox.py create mode 100644 flight/strategies/impl/fedsgd.py create mode 100644 tests/strategies/__init__.py create mode 100644 tests/strategies/test_aggr.py create mode 100644 tests/strategies/test_coord.py diff --git a/flight/strategies/__init__.py b/flight/strategies/__init__.py index e69de29..7117715 100644 --- a/flight/strategies/__init__.py +++ b/flight/strategies/__init__.py @@ -0,0 +1,18 @@ +from flight.strategies.aggr import AggrStrategy +from flight.strategies.base import Strategy, DefaultStrategy +from flight.strategies.coord import CoordStrategy +from flight.strategies.trainer import TrainerStrategy +from flight.strategies.worker import WorkerStrategy + + +def load_strategy(strategy_name: str, **kwargs) -> Strategy: + assert NotImplementedError + + +__all__ = [ + "AggrStrategy", + "Strategy", + "CoordStrategy", + "TrainerStrategy", + "WorkerStrategy", +] diff --git a/flight/strategies/aggr.py b/flight/strategies/aggr.py index 81c4f28..8ffa12f 100644 --- a/flight/strategies/aggr.py +++ b/flight/strategies/aggr.py @@ -1,14 +1,25 @@ +from __future__ import annotations + import typing as t if t.TYPE_CHECKING: Params: t.TypeAlias = t.Any + NodeState: t.TypeAlias = t.Any + NodeID: t.TypeAlias = t.Any +@t.runtime_checkable class AggrStrategy(t.Protocol): def start_round(self): pass - def aggregate_params(self) -> Params: + def aggregate_params( + self, + state: NodeState, + children_states: t.Mapping[NodeID, NodeState], + children_state_dicts: t.Mapping[NodeID, Params], + **kwargs, + ) -> Params: pass def end_round(self): diff --git a/flight/strategies/base.py b/flight/strategies/base.py index f6679ed..ddd5679 100644 --- a/flight/strategies/base.py +++ b/flight/strategies/base.py @@ -1,35 +1,110 @@ -import functools -import typing as t - -import pydantic as pyd - - -@pyd.dataclasses.dataclass(frozen=True, repr=False) -class Strategy: - coord_strategy: str = pyd.field() - aggr_strategy: str = pyd.field() - worker_strategy: str = pyd.field() - trainer_strategy: str = pyd.field() - - def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: - yield from ( - ("coord_strategy", self.coord_strategy), - ("aggr_strategy", self.aggr_strategy), - ("worker_strategy", self.worker_strategy), - ("trainer_strategy", self.trainer_strategy), - ) - - def __repr__(self) -> str: - return str(self) - - @functools.cached_property - def __str__(self) -> str: - name = self.__class__.__name__ - inner = ", ".join( - [ - f"{strategy_key}={strategy_value.__class__.__name__}" - for (strategy_key, strategy_value) in iter(self) - if strategy_value is not None - ] - ) - return f"{name}({inner})" +from __future__ import annotations + +import pydantic as pyd +import typing as t +import functools + +from flight.strategies.commons.averaging import average_state_dicts + +if t.TYPE_CHECKING: + import torch + from numpy.random import Generator + + NodeState: t.TypeAlias = t.Any + NodeID: t.TypeAlias = int | str + Params: t.TypeAlias = t.Any + Loss: t.TypeAlias = torch.Tensor + + from flight.federation.topologies.node import Node + from flight.federation.jobs.result import Result + + +class DefaultCoordStrategy: + def select_workers( + self, state: NodeState, children: t.Iterable[Node], rng: Generator + ) -> t.Sequence[Node]: + return children + + +class DefaultAggrStrategy: + def start_round(self): + pass + + def aggregate_params( + self, + state: NodeState, + children: t.Mapping[NodeID, NodeState], + children_state_dicts: t.Mapping[NodeID, Params], + **kwargs, + ) -> Params: + return average_state_dicts(children_state_dicts, weights=None) + + def end_round(self): + pass + + +class DefaultWorkerStrategy: + def start_work(self, state: NodeState) -> NodeState: + return state + + def before_training( + self, state: NodeState, data: Params + ) -> tuple[NodeState, Params]: + return state, data + + def after_training( + self, state: NodeState, optimizer: torch.optim.Optimizer + ) -> NodeState: + return state + + def end_work(self, result: Result) -> Result: + return result + + +class DefaultTrainerStrategy: + def before_backprop(self, state: NodeState, loss: Loss) -> Loss: + return loss + + def after_backprop(self, state: NodeState, loss: Loss) -> Loss: + return loss + + +@pyd.dataclasses.dataclass(frozen=True, repr=False) +class Strategy: + coord_strategy: str = pyd.Field() + aggr_strategy: str = pyd.Field() + worker_strategy: str = pyd.Field() + trainer_strategy: str = pyd.Field() + + def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: + yield from ( + ("coord_strategy", self.coord_strategy), + ("aggr_strategy", self.aggr_strategy), + ("worker_strategy", self.worker_strategy), + ("trainer_strategy", self.trainer_strategy), + ) + + def __repr__(self) -> str: + return str(self) + + @functools.cached_property + def __str__(self) -> str: + name = self.__class__.__name__ + inner = ", ".join( + [ + f"{strategy_key}={strategy_value.__class__.__name__}" + for (strategy_key, strategy_value) in iter(self) + if strategy_value is not None + ] + ) + return f"{name}({inner})" + + +class DefaultStrategy(Strategy): + def __init__(self) -> None: + super().__init__( + coord_strategy=DefaultCoordStrategy(), + aggr_strategy=DefaultAggrStrategy(), + worker_strategy=DefaultWorkerStrategy(), + trainer_strategy=DefaultTrainerStrategy(), + ) diff --git a/flight/strategies/commons/__init__.py b/flight/strategies/commons/__init__.py new file mode 100644 index 0000000..6318484 --- /dev/null +++ b/flight/strategies/commons/__init__.py @@ -0,0 +1,4 @@ +from flight.strategies.commons.averaging import average_state_dicts +from flight.strategies.commons.worker_selection import random_worker_selection + +__all__ = ['average_state_dicts', 'random_worker_selection'] \ No newline at end of file diff --git a/flight/strategies/commons/averaging.py b/flight/strategies/commons/averaging.py new file mode 100644 index 0000000..0d28980 --- /dev/null +++ b/flight/strategies/commons/averaging.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import typing as t + +import numpy +import torch + +if t.TYPE_CHECKING: + from collections.abc import Mapping + + NodeID: t.TypeAlias = t.Any + Params: t.TypeAlias = t.Any + + +def average_state_dicts( + state_dicts: Mapping[NodeID, Params], weights: Mapping[NodeID, float] | None = None +) -> Params: + num_nodes = len(state_dicts) + weight_sum = None if weights is None else numpy.sum(list(weights.values())) + + with torch.no_grad(): + avg_weights = {} + for node, state_dict in state_dicts.items(): + w = 1 / num_nodes if weights is None else weights[node] / weight_sum + for name, value in state_dict.items(): + value = w * torch.clone(value) + if name not in avg_weights: + avg_weights[name] = value + else: + avg_weights[name] += value + + return avg_weights diff --git a/flight/strategies/commons/worker_selection.py b/flight/strategies/commons/worker_selection.py new file mode 100644 index 0000000..056e36a --- /dev/null +++ b/flight/strategies/commons/worker_selection.py @@ -0,0 +1,57 @@ +from collections.abc import Iterable +from typing import cast + +from numpy import array +from numpy.random import RandomState +from numpy.typing import NDArray + +from flight.federation.topologies.node import Node +from flight.federation.topologies.node import NodeKind + + +def random_worker_selection( + children: Iterable[Node], + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = True, + seed: int | None = None, +) -> list[Node]: + if probabilistic: + return prob_random_worker_selection( + children, participation, always_include_child_aggregators, seed + ) + return fixed_random_worker_selection(children, participation, seed) + + +def fixed_random_worker_selection( + children: Iterable[Node], participation: float = 1.0, seed: int | None = None +) -> list[Node]: + children = array(children) + rand_state = RandomState(seed) + num_selected = max(1, int(participation * len(list(children)))) + + achildren = cast(NDArray, children) + selected_children = rand_state.choice(achildren, size=num_selected, replace=False) + return list(selected_children) + + +def prob_random_worker_selection( + children: Iterable[Node], + participation: float = 1.0, + always_include_child_aggregators: bool = True, + seed: int | None = None, +) -> list[Node]: + rand_state = RandomState(seed) + selected_children = [] + for child in children: + if child.kind is NodeKind.WORKER and always_include_child_aggregators: + selected_children.append(child) + elif rand_state.uniform() <= participation: + selected_children.append(child) + + if len(selected_children) == 0: + achildren = cast(NDArray, children) + random_child = rand_state.choice(achildren) + selected_children.append(random_child) + + return selected_children diff --git a/flight/strategies/coord.py b/flight/strategies/coord.py index 7314e3d..0115681 100644 --- a/flight/strategies/coord.py +++ b/flight/strategies/coord.py @@ -5,13 +5,13 @@ if t.TYPE_CHECKING: from numpy.random import Generator - from ..federation.topologies.node import Node - - CoordState: t.TypeAlias = t.Any + from flight.federation.topologies.node import Node + NodeState: t.TypeAlias = t.Any +@t.runtime_checkable class CoordStrategy(t.Protocol): def select_workers( - self, state: CoordState, workers: t.Iterable[Node], rng: Generator + self, state: NodeState, workers: t.Iterable[Node], rng: Generator ) -> t.Sequence[Node]: pass diff --git a/flight/strategies/impl/fedasync.py b/flight/strategies/impl/fedasync.py new file mode 100644 index 0000000..d9b2e76 --- /dev/null +++ b/flight/strategies/impl/fedasync.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +import typing as t +from collections import OrderedDict + +from flight.strategies.base import DefaultAggrStrategy +from flight.strategies import Strategy + +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any + NodeID: t.TypeAlias = t.Any + Params: t.TypeAlias = t.Any + +class FedAsyncAggr(DefaultAggrStrategy): + def __init__(self, alpha: float = 0.5): + assert 0.0 < alpha <= 1.0 + self.alpha = alpha + + def aggregate_params( + self, + state: NodeState, + children_states: t.Mapping[NodeID, NodeState], + children_state_dicts: t.Mapping[NodeID, Params], + **kwargs + ) -> Params: + last_updated = kwargs.get("last_updated_node", None) + assert last_updated is not None + assert isinstance(last_updated, int | str) + + global_model_params = state.global_model.state_dict() + last_updated_params = children_state_dicts[last_updated] + + aggr_params = [] + for param in global_model_params: + w0, w = ( + global_model_params[param].detach(), + last_updated_params[param].detach() + ) + aggr_w = w0 * (1 - self.alpha) + w * self.alpha + aggr_params.append((param, aggr_w)) + + return OrderedDict(aggr_params) + +class FedAsync(Strategy): + def __init__(self, alpha: float): + super().__init__(aggr_strategy=FedAsyncAggr(alpha)) \ No newline at end of file diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py new file mode 100644 index 0000000..e69de29 diff --git a/flight/strategies/impl/fedprox.py b/flight/strategies/impl/fedprox.py new file mode 100644 index 0000000..e69de29 diff --git a/flight/strategies/impl/fedsgd.py b/flight/strategies/impl/fedsgd.py new file mode 100644 index 0000000..e69de29 diff --git a/flight/strategies/trainer.py b/flight/strategies/trainer.py index c0fd3cf..5b8b71b 100644 --- a/flight/strategies/trainer.py +++ b/flight/strategies/trainer.py @@ -5,12 +5,13 @@ if t.TYPE_CHECKING: import torch + NodeState: t.TypeAlias = t.Any Loss: t.TypeAlias = torch.Tensor class TrainerStrategy(t.Protocol): - def before_backprop(self, state, loss: Loss) -> Loss: + def before_backprop(self, state: NodeState, loss: Loss) -> Loss: pass - def after_backprop(self, state, loss: Loss) -> Loss: + def after_backprop(self, state: NodeState, loss: Loss) -> Loss: pass diff --git a/flight/strategies/worker.py b/flight/strategies/worker.py index c3b6853..9a35080 100644 --- a/flight/strategies/worker.py +++ b/flight/strategies/worker.py @@ -3,18 +3,23 @@ import typing as t if t.TYPE_CHECKING: - pass + import torch + from flight.federation.jobs.result import Result + NodeState: t.TypeAlias = t.Any +@t.runtime_checkable class WorkerStrategy(t.Protocol): - def start_work(self): + def start_work(self, state: NodeState) -> NodeState: pass - def before_training(self): + def before_training(self, state: NodeState, data: t.Any) -> tuple[NodeState, t.Any]: pass - def after_training(self): + def after_training( + self, state: NodeState, optimizer: torch.optim.Optimizer + ) -> NodeState: pass - def end_work(self): + def end_work(self, result: Result) -> Result: pass diff --git a/tests/strategies/__init__.py b/tests/strategies/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strategies/test_aggr.py b/tests/strategies/test_aggr.py new file mode 100644 index 0000000..9eb956d --- /dev/null +++ b/tests/strategies/test_aggr.py @@ -0,0 +1,51 @@ +from flight.strategies import AggrStrategy +from flight.strategies.base import DefaultAggrStrategy + +import tensorflow as tf +import torch + + +def test_instance(): + default_aggr = DefaultAggrStrategy() + + assert isinstance(default_aggr, AggrStrategy) + + +def test_aggr_aggregate_params(): + default_aggr = DefaultAggrStrategy() + + state = "foo" + children = {1: "foo1", 2: "foo2"} + + children_state_dicts = { + 1: { + "train/loss": tf.convert_to_tensor(2.3, dtype=tf.float32), + "train/acc": tf.convert_to_tensor(1.2, dtype=tf.float32), + }, + 2: { + "train/loss": tf.convert_to_tensor(3.1, dtype=tf.float32), + "train/acc": tf.convert_to_tensor(1.4, dtype=tf.float32), + }, + } + + children_state_dicts_pt = { + key: { + sub_key: torch.tensor(value.numpy()) for sub_key, value in sub_dict.items() + } + for key, sub_dict in children_state_dicts.items() + } + + avg = default_aggr.aggregate_params(state, children, children_state_dicts_pt) + + assert isinstance(avg, dict) + + expected_avg = { + "train/loss": 2.7, + "train/acc": 1.3, + } + + epsilon = 1e-6 + for key, value in avg.items(): + expected = expected_avg[key] + + assert abs(expected - value) < epsilon diff --git a/tests/strategies/test_coord.py b/tests/strategies/test_coord.py new file mode 100644 index 0000000..0380d19 --- /dev/null +++ b/tests/strategies/test_coord.py @@ -0,0 +1,8 @@ +from flight.strategies import CoordStrategy +from flight.strategies.base import DefaultCoordStrategy + +def test_instance(): + default_coord = DefaultCoordStrategy() + + assert isinstance(default_coord, CoordStrategy) + From 06895d0e601c0ffa463506da9d6079e28c5eecc9 Mon Sep 17 00:00:00 2001 From: loganalt9 Date: Fri, 12 Jul 2024 18:32:43 -0400 Subject: [PATCH 02/11] Strategies refactor and partial tests --- flight/strategies/__init__.py | 36 ++++++++- flight/strategies/aggr.py | 23 +++++- flight/strategies/base.py | 43 +++++++---- flight/strategies/commons/__init__.py | 2 +- flight/strategies/commons/averaging.py | 11 ++- flight/strategies/commons/worker_selection.py | 28 +++---- flight/strategies/coord.py | 12 +++ flight/strategies/impl/__init__.py | 6 ++ flight/strategies/impl/fedasync.py | 32 +++++--- flight/strategies/impl/fedavg.py | 60 +++++++++++++++ flight/strategies/impl/fedprox.py | 67 +++++++++++++++++ flight/strategies/impl/fedsgd.py | 74 +++++++++++++++++++ flight/strategies/trainer.py | 24 +++++- flight/strategies/worker.py | 39 ++++++++++ tests/strategies/impl/__init__.py | 0 tests/strategies/impl/test_fedasync.py | 68 +++++++++++++++++ tests/strategies/impl/test_fedavg.py | 0 tests/strategies/impl/test_fedprox.py | 0 tests/strategies/impl/test_fedsgd.py | 0 tests/strategies/test_aggr.py | 30 ++++---- tests/strategies/test_coord.py | 63 ++++++++++++++++ tests/strategies/test_worker.py | 8 ++ 22 files changed, 560 insertions(+), 66 deletions(-) create mode 100644 tests/strategies/impl/__init__.py create mode 100644 tests/strategies/impl/test_fedasync.py create mode 100644 tests/strategies/impl/test_fedavg.py create mode 100644 tests/strategies/impl/test_fedprox.py create mode 100644 tests/strategies/impl/test_fedsgd.py create mode 100644 tests/strategies/test_worker.py diff --git a/flight/strategies/__init__.py b/flight/strategies/__init__.py index 7117715..8dac982 100644 --- a/flight/strategies/__init__.py +++ b/flight/strategies/__init__.py @@ -1,12 +1,44 @@ +import typing as t + +import torch + +Loss: t.TypeAlias = torch.tensor +Params: t.TypeAlias = dict[str, torch.Tensor] + from flight.strategies.aggr import AggrStrategy -from flight.strategies.base import Strategy, DefaultStrategy +from flight.strategies.base import DefaultStrategy, Strategy from flight.strategies.coord import CoordStrategy from flight.strategies.trainer import TrainerStrategy from flight.strategies.worker import WorkerStrategy def load_strategy(strategy_name: str, **kwargs) -> Strategy: - assert NotImplementedError + assert isinstance(strategy_name, str), "`strategy_name` must be a string." + match strategy_name.lower(): + case "default": + return DefaultStrategy() + + case "fedasync" | "fed-async": + from flight.strategies.impl.fedasync import FedAsync + + return FedAsync(**kwargs) + + case "fedavg" | "fed-avg": + from flight.strategies.impl.fedavg import FedAvg + + return FedAvg(**kwargs) + + case "fedprox" | "fed-prox": + from flight.strategies.impl.fedprox import FedProx + + return FedProx(**kwargs) + + case "fedsgd" | "fed-sgd": + from flight.strategies.impl.fedsgd import FedSGD + + return FedSGD(**kwargs) + case _: + raise ValueError(f"Strategy '{strategy_name}' is not recognized.") __all__ = [ diff --git a/flight/strategies/aggr.py b/flight/strategies/aggr.py index 8ffa12f..51c7217 100644 --- a/flight/strategies/aggr.py +++ b/flight/strategies/aggr.py @@ -3,14 +3,21 @@ import typing as t if t.TYPE_CHECKING: - Params: t.TypeAlias = t.Any + import torch + + from flight.strategies import Params + NodeState: t.TypeAlias = t.Any - NodeID: t.TypeAlias = t.Any + + from flight.federation.topologies.node import NodeID @t.runtime_checkable class AggrStrategy(t.Protocol): + """Template for all aggregator strategies, including those defined in Flight and those defined by Users.""" + def start_round(self): + """Callback to run at the start of a round.""" pass def aggregate_params( @@ -20,7 +27,19 @@ def aggregate_params( children_state_dicts: t.Mapping[NodeID, Params], **kwargs, ) -> Params: + """Callback that handles the model parameter aggregation step. + + Args: + state (NodeState): The state of the current aggregator node. + children_states (t.Mapping[NodeID, NodeState]): A mapping of the current aggregator node's children and their respective states. + children_state_dicts (t.Mapping[NodeID, Parmas]): The model parameters of the models to each respective child node. + **kwargs: Keyword arguments provided by users. + + Returns: + The aggregated parameters to update the model at the current aggregator. + """ pass def end_round(self): + """Callback to run at the end of a round.""" pass diff --git a/flight/strategies/base.py b/flight/strategies/base.py index ddd5679..2cbbe3c 100644 --- a/flight/strategies/base.py +++ b/flight/strategies/base.py @@ -1,29 +1,35 @@ from __future__ import annotations -import pydantic as pyd -import typing as t import functools +import typing as t +import pydantic as pyd + +from flight.strategies.aggr import AggrStrategy from flight.strategies.commons.averaging import average_state_dicts +from flight.strategies.coord import CoordStrategy +from flight.strategies.trainer import TrainerStrategy +from flight.strategies.worker import WorkerStrategy + +StrategyType: t.TypeAlias = ( + t.Any +) # WorkerStrategy | AggrStrategy | CoordStrategy | TrainerStrategy if t.TYPE_CHECKING: import torch from numpy.random import Generator NodeState: t.TypeAlias = t.Any - NodeID: t.TypeAlias = int | str - Params: t.TypeAlias = t.Any - Loss: t.TypeAlias = torch.Tensor - - from flight.federation.topologies.node import Node from flight.federation.jobs.result import Result + from flight.federation.topologies.node import Node, NodeID + from flight.strategies import Loss, Params class DefaultCoordStrategy: def select_workers( self, state: NodeState, children: t.Iterable[Node], rng: Generator ) -> t.Sequence[Node]: - return children + return list(children) class DefaultAggrStrategy: @@ -33,7 +39,7 @@ def start_round(self): def aggregate_params( self, state: NodeState, - children: t.Mapping[NodeID, NodeState], + children_states: t.Mapping[NodeID, NodeState], children_state_dicts: t.Mapping[NodeID, Params], **kwargs, ) -> Params: @@ -71,10 +77,19 @@ def after_backprop(self, state: NodeState, loss: Loss) -> Loss: @pyd.dataclasses.dataclass(frozen=True, repr=False) class Strategy: - coord_strategy: str = pyd.Field() - aggr_strategy: str = pyd.Field() - worker_strategy: str = pyd.Field() - trainer_strategy: str = pyd.Field() + """ + A 'Strategy' implementation is comprised of the four different type of implementations of strategies + to be used on the respective node types throughout the training process. + """ + + """Implementation of the specific callbacks for the coordinator node.""" + coord_strategy: StrategyType = pyd.Field() + """Implementation of the specific callbacks for the aggregator node(s).""" + aggr_strategy: StrategyType = pyd.Field() + """Implementation of the specific callbacks for the worker node(s).""" + worker_strategy: StrategyType = pyd.Field() + """Implementation of callbacks specific to the execution of the training loop on the worker node(s).""" + trainer_strategy: StrategyType = pyd.Field() def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: yield from ( @@ -87,7 +102,7 @@ def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: def __repr__(self) -> str: return str(self) - @functools.cached_property + # @functools.cached_property def __str__(self) -> str: name = self.__class__.__name__ inner = ", ".join( diff --git a/flight/strategies/commons/__init__.py b/flight/strategies/commons/__init__.py index 6318484..6e3e270 100644 --- a/flight/strategies/commons/__init__.py +++ b/flight/strategies/commons/__init__.py @@ -1,4 +1,4 @@ from flight.strategies.commons.averaging import average_state_dicts from flight.strategies.commons.worker_selection import random_worker_selection -__all__ = ['average_state_dicts', 'random_worker_selection'] \ No newline at end of file +__all__ = ["average_state_dicts", "random_worker_selection"] diff --git a/flight/strategies/commons/averaging.py b/flight/strategies/commons/averaging.py index 0d28980..ef1eca0 100644 --- a/flight/strategies/commons/averaging.py +++ b/flight/strategies/commons/averaging.py @@ -16,12 +16,19 @@ def average_state_dicts( state_dicts: Mapping[NodeID, Params], weights: Mapping[NodeID, float] | None = None ) -> Params: num_nodes = len(state_dicts) - weight_sum = None if weights is None else numpy.sum(list(weights.values())) + + if weights is not None: + weight_sum = numpy.sum(list(weights.values())) + else: + weight_sum = None with torch.no_grad(): avg_weights = {} for node, state_dict in state_dicts.items(): - w = 1 / num_nodes if weights is None else weights[node] / weight_sum + if weights is not None: + w = weights[node] / weight_sum + else: + w = 1 / num_nodes for name, value in state_dict.items(): value = w * torch.clone(value) if name not in avg_weights: diff --git a/flight/strategies/commons/worker_selection.py b/flight/strategies/commons/worker_selection.py index 056e36a..eb0f331 100644 --- a/flight/strategies/commons/worker_selection.py +++ b/flight/strategies/commons/worker_selection.py @@ -2,11 +2,10 @@ from typing import cast from numpy import array -from numpy.random import RandomState +from numpy.random import Generator, RandomState, default_rng from numpy.typing import NDArray -from flight.federation.topologies.node import Node -from flight.federation.topologies.node import NodeKind +from flight.federation.topologies.node import Node, NodeKind def random_worker_selection( @@ -14,44 +13,41 @@ def random_worker_selection( participation: float = 1.0, probabilistic: bool = False, always_include_child_aggregators: bool = True, - seed: int | None = None, + rng: Generator | None = None, ) -> list[Node]: + if rng is None: + rng = default_rng() if probabilistic: return prob_random_worker_selection( - children, participation, always_include_child_aggregators, seed + children, rng, participation, always_include_child_aggregators ) - return fixed_random_worker_selection(children, participation, seed) + return fixed_random_worker_selection(children, rng, participation) def fixed_random_worker_selection( - children: Iterable[Node], participation: float = 1.0, seed: int | None = None + children: Iterable[Node], rng: Generator, participation: float = 1.0 ) -> list[Node]: children = array(children) - rand_state = RandomState(seed) num_selected = max(1, int(participation * len(list(children)))) - - achildren = cast(NDArray, children) - selected_children = rand_state.choice(achildren, size=num_selected, replace=False) + selected_children = rng.choice(children, size=num_selected, replace=False) return list(selected_children) def prob_random_worker_selection( children: Iterable[Node], + rng: Generator, participation: float = 1.0, always_include_child_aggregators: bool = True, - seed: int | None = None, ) -> list[Node]: - rand_state = RandomState(seed) selected_children = [] for child in children: if child.kind is NodeKind.WORKER and always_include_child_aggregators: selected_children.append(child) - elif rand_state.uniform() <= participation: + elif rng.uniform() <= participation: selected_children.append(child) if len(selected_children) == 0: - achildren = cast(NDArray, children) - random_child = rand_state.choice(achildren) + random_child = rng.choice(array(children)) selected_children.append(random_child) return selected_children diff --git a/flight/strategies/coord.py b/flight/strategies/coord.py index 0115681..079cebd 100644 --- a/flight/strategies/coord.py +++ b/flight/strategies/coord.py @@ -9,9 +9,21 @@ NodeState: t.TypeAlias = t.Any + @t.runtime_checkable class CoordStrategy(t.Protocol): + """Template for all coordinator strategies, including those defined in Flight and those defined by Users.""" + def select_workers( self, state: NodeState, workers: t.Iterable[Node], rng: Generator ) -> t.Sequence[Node]: + """Callback that is responsible for selecting a subset of worker nodes to do local training. + + Args: + state (NodeState): The state of the current coordinator node. + children (t.Iterable[Node]): The worker nodes in the topology. + rng (Generator): The rng used for reproducibility. + + Returns: + The selected worker nodes.""" pass diff --git a/flight/strategies/impl/__init__.py b/flight/strategies/impl/__init__.py index e69de29..28bbd08 100644 --- a/flight/strategies/impl/__init__.py +++ b/flight/strategies/impl/__init__.py @@ -0,0 +1,6 @@ +from flight.strategies.impl.fedasync import FedAsync +from flight.strategies.impl.fedavg import FedAvg +from flight.strategies.impl.fedprox import FedProx +from flight.strategies.impl.fedsgd import FedSGD + +__all__ = ["FedAsync", "FedAvg", "FedProx", "FedSGD"] diff --git a/flight/strategies/impl/fedasync.py b/flight/strategies/impl/fedasync.py index d9b2e76..9af7304 100644 --- a/flight/strategies/impl/fedasync.py +++ b/flight/strategies/impl/fedasync.py @@ -3,25 +3,31 @@ import typing as t from collections import OrderedDict -from flight.strategies.base import DefaultAggrStrategy -from flight.strategies import Strategy +from flight.strategies.base import ( + DefaultAggrStrategy, + DefaultCoordStrategy, + DefaultTrainerStrategy, + DefaultWorkerStrategy, + Strategy, +) if t.TYPE_CHECKING: NodeState: t.TypeAlias = t.Any - NodeID: t.TypeAlias = t.Any - Params: t.TypeAlias = t.Any + from flight.federation.topologies.node import NodeID + from flight.strategies import Params + class FedAsyncAggr(DefaultAggrStrategy): def __init__(self, alpha: float = 0.5): assert 0.0 < alpha <= 1.0 self.alpha = alpha - + def aggregate_params( self, state: NodeState, children_states: t.Mapping[NodeID, NodeState], children_state_dicts: t.Mapping[NodeID, Params], - **kwargs + **kwargs, ) -> Params: last_updated = kwargs.get("last_updated_node", None) assert last_updated is not None @@ -34,13 +40,19 @@ def aggregate_params( for param in global_model_params: w0, w = ( global_model_params[param].detach(), - last_updated_params[param].detach() + last_updated_params[param].detach(), ) aggr_w = w0 * (1 - self.alpha) + w * self.alpha aggr_params.append((param, aggr_w)) - + return OrderedDict(aggr_params) - + + class FedAsync(Strategy): def __init__(self, alpha: float): - super().__init__(aggr_strategy=FedAsyncAggr(alpha)) \ No newline at end of file + super().__init__( + aggr_strategy=FedAsyncAggr(alpha), + coord_strategy=DefaultCoordStrategy(), + worker_strategy=DefaultWorkerStrategy(), + trainer_strategy=DefaultTrainerStrategy(), + ) diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py index e69de29..283cafa 100644 --- a/flight/strategies/impl/fedavg.py +++ b/flight/strategies/impl/fedavg.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import typing as t + +from flight.strategies.base import ( + DefaultAggrStrategy, + DefaultTrainerStrategy, + DefaultWorkerStrategy, + Strategy, +) +from flight.strategies.commons import average_state_dicts + +from .fedsgd import FedSGDCoord + +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any + from flight.federation.topologies.node import NodeID + from flight.strategies import Params + + +class FedAvgAggr(DefaultAggrStrategy): + def aggregate_params( + self, + state: NodeState, + children_states: t.Mapping[NodeID, NodeState], + children_state_dicts: t.Mapping[NodeID, Params], + **kwargs, + ) -> Params: + weights = {} + for node, child_state in children_states.items(): + weights[node] = child_state["num_data_samples"] + + state["num_data_samples"] = sum(weights.values()) + + return average_state_dicts(children_state_dicts, weights=weights) + + +class FedAvgWorker(DefaultWorkerStrategy): + def before_training( + self, state: NodeState, data: Params + ) -> tuple[NodeState, Params]: + state["num_data_samples"] = len(data) + return state, data + + +class FedAvg(Strategy): + def __init__( + self, + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = False, + ): + super().__init__( + coord_strategy=FedSGDCoord( + participation, probabilistic, always_include_child_aggregators + ), + aggr_strategy=FedAvgAggr(), + worker_strategy=FedAvgWorker(), + trainer_strategy=DefaultTrainerStrategy(), + ) diff --git a/flight/strategies/impl/fedprox.py b/flight/strategies/impl/fedprox.py index e69de29..18cb34b 100644 --- a/flight/strategies/impl/fedprox.py +++ b/flight/strategies/impl/fedprox.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import typing as t + +import torch + +from flight.strategies.base import DefaultTrainerStrategy, Strategy + +from .fedavg import FedAvgWorker +from .fedsgd import FedSGDAggr, FedSGDCoord + +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any + from flight.strategies import Loss + +DEVICE = "cpu" + + +class FedProxTrainer(DefaultTrainerStrategy): + def __init__(self, mu: float = 0.3): + self.mu = mu + + def before_backprop(self, state: NodeState, loss: Loss) -> Loss: + global_model = state.global_model + local_model = state.local_model + assert global_model is not None + assert local_model is not None + + global_model = global_model.to(DEVICE) + local_model = local_model.to(DEVICE) + + params = list(local_model.state_dict().values()) + params0 = list(global_model.state_dict().values()) + + proximal_diff = torch.tensor( + [ + torch.sum(torch.pow(params[i] - params0[i], 2)) + for i in range(len(params)) + ] + ) + proximal_term = torch.sum(proximal_diff) + proximal_term = proximal_term * self.mu / 2 + + proximal_term = proximal_term.to(DEVICE) + + loss += proximal_term + return loss + + +class FedProx(Strategy): + def __init__( + self, + mu: float = 0.3, + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = False, + ): + super().__init__( + coord_strategy=FedSGDCoord( + participation, + probabilistic, + always_include_child_aggregators, + ), + aggr_strategy=FedSGDAggr(), + worker_strategy=FedAvgWorker(), + trainer_strategy=FedProxTrainer(mu=mu), + ) diff --git a/flight/strategies/impl/fedsgd.py b/flight/strategies/impl/fedsgd.py index e69de29..8ef9ecf 100644 --- a/flight/strategies/impl/fedsgd.py +++ b/flight/strategies/impl/fedsgd.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import typing as t + +from numpy.random import Generator + +from flight.strategies import Strategy +from flight.strategies.base import ( + DefaultAggrStrategy, + DefaultCoordStrategy, + DefaultTrainerStrategy, + DefaultWorkerStrategy, +) +from flight.strategies.commons import average_state_dicts, random_worker_selection + +if t.TYPE_CHECKING: + from flight.federation.topologies.node import Node, NodeID + from flight.strategies import Params + + NodeState: t.TypeAlias = t.Any + + +class FedSGDCoord(DefaultCoordStrategy): + def __init__( + self, + participation, + probabilistic, + always_include_child_aggregators: bool, + ): + self.participation = participation + self.probabilistic = probabilistic + self.always_include_child_aggregators = always_include_child_aggregators + + def select_worker_nodes( + self, state: NodeState, workers: t.Iterable[Node], rng: Generator | None = None + ) -> t.Iterable[Node]: + selected_workers = random_worker_selection( + workers, + participation=self.participation, + probabilistic=self.probabilistic, + always_include_child_aggregators=self.always_include_child_aggregators, + rng=rng, + ) + return selected_workers + + +class FedSGDAggr(DefaultAggrStrategy): + def aggregate_params( + self, + state: NodeState, + children_states: t.Mapping[NodeID, NodeState], + children_state_dicts: t.Mapping[NodeID, Params], + **kwargs, + ) -> Params: + return average_state_dicts(children_state_dicts, weights=None) + + +class FedSGD(Strategy): + def __init__( + self, + participation: float = 1.0, + probabilistic: bool = False, + always_include_child_aggregators: bool = True, + ): + super().__init__( + coord_strategy=FedSGDCoord( + participation, + probabilistic, + always_include_child_aggregators, + ), + aggr_strategy=FedSGDAggr(), + worker_strategy=DefaultWorkerStrategy(), + trainer_strategy=DefaultTrainerStrategy(), + ) diff --git a/flight/strategies/trainer.py b/flight/strategies/trainer.py index 5b8b71b..4352434 100644 --- a/flight/strategies/trainer.py +++ b/flight/strategies/trainer.py @@ -3,15 +3,35 @@ import typing as t if t.TYPE_CHECKING: - import torch NodeState: t.TypeAlias = t.Any - Loss: t.TypeAlias = torch.Tensor + from flight.strategies import Loss +@t.runtime_checkable class TrainerStrategy(t.Protocol): + """Template for all trainer strategies, including those defined in Flight and those defined by users.""" + def before_backprop(self, state: NodeState, loss: Loss) -> Loss: + """Callback to run before backpropagation. + + Args: + state (NodeState): State of the current node. + loss (Loss): The calculated loss + + Returns: + The loss at the end of the callback + """ pass def after_backprop(self, state: NodeState, loss: Loss) -> Loss: + """Callback to run after backpropagation. + + Args: + state (NodeState): State of the current node. + loss (Loss): The calculated loss + + Returns: + The loss at the end of the callback + """ pass diff --git a/flight/strategies/worker.py b/flight/strategies/worker.py index 9a35080..206be81 100644 --- a/flight/strategies/worker.py +++ b/flight/strategies/worker.py @@ -4,22 +4,61 @@ if t.TYPE_CHECKING: import torch + from flight.federation.jobs.result import Result NodeState: t.TypeAlias = t.Any + @t.runtime_checkable class WorkerStrategy(t.Protocol): + """Template for all aggregator strategies, including those defined in Flight and those defined by Users.""" + def start_work(self, state: NodeState) -> NodeState: + """Callback to run at the start of the current nodes 'work'. + + Args: + state (NodeState): State of the current worker node. + + Returns: + The state of the current node at the end of the callback. + """ pass def before_training(self, state: NodeState, data: t.Any) -> tuple[NodeState, t.Any]: + """Callback to run before the current nodes training. + + Args: + state (NodeState): State of the current worker node. + data (t.Any): The data related to the current worker node. + + Returns: + A tuple containing the state and data of the current worker node after the callback. + """ pass def after_training( self, state: NodeState, optimizer: torch.optim.Optimizer ) -> NodeState: + """Callback to run after the current nodes training. + + Args: + state (NodeState): State of the current worker node. + optimizer (torch.optim.Optimizer): The PyTorch optimizer used. + + Returns: + The state of the current worker node after the callback. + """ pass def end_work(self, result: Result) -> Result: + """Callback to run at the end of the current worker nodes 'work' + + Args: + result (Result): A Result object used to represent the result of the local training on the current worker node. + + Returns: + The result of the worker nodes local training. + + """ pass diff --git a/tests/strategies/impl/__init__.py b/tests/strategies/impl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strategies/impl/test_fedasync.py b/tests/strategies/impl/test_fedasync.py new file mode 100644 index 0000000..cf37694 --- /dev/null +++ b/tests/strategies/impl/test_fedasync.py @@ -0,0 +1,68 @@ +import typing as t + +import pytest +import torch + +from flight.strategies import ( + AggrStrategy, + CoordStrategy, + DefaultStrategy, + TrainerStrategy, + WorkerStrategy, +) +from flight.strategies.base import ( + DefaultCoordStrategy, + DefaultTrainerStrategy, + DefaultWorkerStrategy, +) +from flight.strategies.impl.fedasync import FedAsync, FedAsyncAggr + +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any + + +class TestValidFedAsync: + def test_class_hierarchy(self): + strategy = FedAsync(0.5) + + assert ( + isinstance(strategy.aggr_strategy, (AggrStrategy, FedAsyncAggr)) + and isinstance( + strategy.coord_strategy, (CoordStrategy, DefaultCoordStrategy) + ) + and isinstance( + strategy.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) + ) + and isinstance( + strategy.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) + ) + ) + + def test_fedasync_aggr(self): + strategy = FedAsync(alpha=0.5) + aggr_strategy: AggrStrategy = strategy.aggr_strategy + + nodestate: NodeState = "foo" + childstates = {1: "foo1", 2: "foo2"} + children_state_dicts_pt = { + 1: { + "train/loss": torch.tensor(2.3, dtype=torch.float32), + "train/acc": torch.tensor(1.2, dtype=torch.float32), + }, + 2: { + "train/loss": torch.tensor(3.1, dtype=torch.float32), + "train/acc": torch.tensor(1.4, dtype=torch.float32), + }, + } + + # avg = aggr_strategy.aggregate_params( + # nodestate, childstates, children_state_dicts_pt, last_updated_node=1 + # ) + + assert NotImplementedError + + +class TestInvalidFedAsync: + def test_invalid_alpha(self): + with pytest.raises(AssertionError): + fedasync = FedAsync(alpha=1.1) diff --git a/tests/strategies/impl/test_fedavg.py b/tests/strategies/impl/test_fedavg.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strategies/impl/test_fedprox.py b/tests/strategies/impl/test_fedprox.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strategies/impl/test_fedsgd.py b/tests/strategies/impl/test_fedsgd.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/strategies/test_aggr.py b/tests/strategies/test_aggr.py index 9eb956d..82f12c2 100644 --- a/tests/strategies/test_aggr.py +++ b/tests/strategies/test_aggr.py @@ -1,8 +1,12 @@ +import typing as t + +import torch + from flight.strategies import AggrStrategy from flight.strategies.base import DefaultAggrStrategy -import tensorflow as tf -import torch +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any def test_instance(): @@ -14,27 +18,20 @@ def test_instance(): def test_aggr_aggregate_params(): default_aggr = DefaultAggrStrategy() - state = "foo" + state: NodeState = "foo" children = {1: "foo1", 2: "foo2"} - children_state_dicts = { + children_state_dicts_pt = { 1: { - "train/loss": tf.convert_to_tensor(2.3, dtype=tf.float32), - "train/acc": tf.convert_to_tensor(1.2, dtype=tf.float32), + "train/loss": torch.tensor(2.3, dtype=torch.float32), + "train/acc": torch.tensor(1.2, dtype=torch.float32), }, 2: { - "train/loss": tf.convert_to_tensor(3.1, dtype=tf.float32), - "train/acc": tf.convert_to_tensor(1.4, dtype=tf.float32), + "train/loss": torch.tensor(3.1, dtype=torch.float32), + "train/acc": torch.tensor(1.4, dtype=torch.float32), }, } - children_state_dicts_pt = { - key: { - sub_key: torch.tensor(value.numpy()) for sub_key, value in sub_dict.items() - } - for key, sub_dict in children_state_dicts.items() - } - avg = default_aggr.aggregate_params(state, children, children_state_dicts_pt) assert isinstance(avg, dict) @@ -47,5 +44,4 @@ def test_aggr_aggregate_params(): epsilon = 1e-6 for key, value in avg.items(): expected = expected_avg[key] - - assert abs(expected - value) < epsilon + assert abs(expected - value.item()) < epsilon diff --git a/tests/strategies/test_coord.py b/tests/strategies/test_coord.py index 0380d19..75ffe8f 100644 --- a/tests/strategies/test_coord.py +++ b/tests/strategies/test_coord.py @@ -1,8 +1,71 @@ +import pytest +from numpy.random import default_rng + +from flight.federation.topologies.node import Node, NodeKind from flight.strategies import CoordStrategy from flight.strategies.base import DefaultCoordStrategy +from flight.strategies.commons.worker_selection import random_worker_selection + def test_instance(): default_coord = DefaultCoordStrategy() assert isinstance(default_coord, CoordStrategy) + +def test_worker_selection(): + gen = default_rng() + + worker1 = Node(idx=1, kind=NodeKind.WORKER) + worker2 = Node(idx=2, kind=NodeKind.WORKER) + workers = [worker1, worker2] + # fixed random + fixed_random = random_worker_selection( + workers, + participation=1, + probabilistic=False, + always_include_child_aggregators=True, + rng=gen, + ) + # prob random + prob_random = random_worker_selection( + workers, + participation=1, + probabilistic=True, + always_include_child_aggregators=True, + rng=gen, + ) + + for worker in workers: + assert worker in fixed_random and worker in prob_random + + +class TestInvalidFixedSelection: + def test_fixed_random(self): + gen = default_rng() + + workers = [Node(idx=i, kind=NodeKind.WORKER) for i in range(1, 6)] + + with pytest.raises(ValueError): + fixed_random = random_worker_selection( + workers, + participation=2, + probabilistic=False, + always_include_child_aggregators=True, + rng=gen, + ) + + def test_fixed_random_mix(self): + gen = default_rng() + + children = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, 3)] + children.append(Node(idx=3, kind=NodeKind.WORKER)) + + with pytest.raises(ValueError): + fixed_random = random_worker_selection( + children, + participation=2, + probabilistic=False, + always_include_child_aggregators=True, + rng=gen, + ) diff --git a/tests/strategies/test_worker.py b/tests/strategies/test_worker.py new file mode 100644 index 0000000..e72ef4f --- /dev/null +++ b/tests/strategies/test_worker.py @@ -0,0 +1,8 @@ +from flight.strategies import WorkerStrategy +from flight.strategies.base import DefaultWorkerStrategy + + +def test_instance(): + default_worker = DefaultWorkerStrategy() + + assert isinstance(default_worker, WorkerStrategy) From 5da98adfe96b3c40094d2634d3509fc248aef470 Mon Sep 17 00:00:00 2001 From: loganalt9 Date: Thu, 18 Jul 2024 12:04:41 -0400 Subject: [PATCH 03/11] Strategies refactor and partial tests --- flight/federation/jobs/result.py | 9 ++- flight/federation/topologies/io.py | 1 + flight/strategies/__init__.py | 7 ++- flight/strategies/aggr.py | 5 +- flight/strategies/base.py | 30 ++++++--- flight/strategies/coord.py | 3 +- flight/strategies/impl/fedasync.py | 3 +- flight/strategies/impl/fedavg.py | 3 +- flight/strategies/impl/fedprox.py | 3 +- flight/strategies/impl/fedsgd.py | 6 +- flight/strategies/trainer.py | 3 +- flight/strategies/worker.py | 3 +- tests/strategies/environment.py | 11 ++++ tests/strategies/impl/test_fedasync.py | 18 +++--- tests/strategies/impl/test_fedavg.py | 85 ++++++++++++++++++++++++++ tests/strategies/impl/test_fedprox.py | 20 ++++++ tests/strategies/impl/test_fedsgd.py | 71 +++++++++++++++++++++ tests/strategies/test_aggr.py | 5 +- tests/strategies/test_coord.py | 21 +++---- tests/strategies/test_worker.py | 24 ++++++++ 20 files changed, 268 insertions(+), 63 deletions(-) create mode 100644 tests/strategies/environment.py diff --git a/flight/federation/jobs/result.py b/flight/federation/jobs/result.py index fb0050e..dda248e 100644 --- a/flight/federation/jobs/result.py +++ b/flight/federation/jobs/result.py @@ -4,14 +4,13 @@ from pydantic.dataclasses import dataclass from flight.learning.module import RecordList -from flight.strategies.aggr import Params +from flight.strategies import Params -if t.TYPE_CHECKING: - NodeID: t.TypeAlias = t.Hashable - NodeState: t.TypeAlias = tuple +NodeID: t.TypeAlias = t.Hashable +NodeState: t.TypeAlias = tuple -@dataclass +@dataclass(config={"arbitrary_types_allowed": True}) class Result: state: NodeState node_idx: NodeID diff --git a/flight/federation/topologies/io.py b/flight/federation/topologies/io.py index 7cb432a..935ad64 100644 --- a/flight/federation/topologies/io.py +++ b/flight/federation/topologies/io.py @@ -9,6 +9,7 @@ wish to use `from_yaml()` to create a Topology with a YAML file, we encourage the use the `Topology.from_yaml()` method instead. """ + from __future__ import annotations import json diff --git a/flight/strategies/__init__.py b/flight/strategies/__init__.py index 8dac982..9825f56 100644 --- a/flight/strategies/__init__.py +++ b/flight/strategies/__init__.py @@ -2,15 +2,16 @@ import torch -Loss: t.TypeAlias = torch.tensor -Params: t.TypeAlias = dict[str, torch.Tensor] - from flight.strategies.aggr import AggrStrategy from flight.strategies.base import DefaultStrategy, Strategy from flight.strategies.coord import CoordStrategy from flight.strategies.trainer import TrainerStrategy from flight.strategies.worker import WorkerStrategy +Loss: t.TypeAlias = torch.Tensor +Params: t.TypeAlias = dict[str, torch.Tensor] +NodeState: t.TypeAlias = t.Any + def load_strategy(strategy_name: str, **kwargs) -> Strategy: assert isinstance(strategy_name, str), "`strategy_name` must be a string." diff --git a/flight/strategies/aggr.py b/flight/strategies/aggr.py index 51c7217..6495295 100644 --- a/flight/strategies/aggr.py +++ b/flight/strategies/aggr.py @@ -5,11 +5,8 @@ if t.TYPE_CHECKING: import torch - from flight.strategies import Params - - NodeState: t.TypeAlias = t.Any - from flight.federation.topologies.node import NodeID + from flight.strategies import NodeState, Params @t.runtime_checkable diff --git a/flight/strategies/base.py b/flight/strategies/base.py index 2cbbe3c..a6207e7 100644 --- a/flight/strategies/base.py +++ b/flight/strategies/base.py @@ -12,24 +12,23 @@ from flight.strategies.worker import WorkerStrategy StrategyType: t.TypeAlias = ( - t.Any -) # WorkerStrategy | AggrStrategy | CoordStrategy | TrainerStrategy + WorkerStrategy | AggrStrategy | CoordStrategy | TrainerStrategy +) if t.TYPE_CHECKING: import torch from numpy.random import Generator - NodeState: t.TypeAlias = t.Any from flight.federation.jobs.result import Result from flight.federation.topologies.node import Node, NodeID - from flight.strategies import Loss, Params + from flight.strategies import Loss, NodeState, Params class DefaultCoordStrategy: def select_workers( - self, state: NodeState, children: t.Iterable[Node], rng: Generator + self, state: NodeState, workers: t.Iterable[Node], rng: Generator ) -> t.Sequence[Node]: - return list(children) + return list(workers) class DefaultAggrStrategy: @@ -75,7 +74,9 @@ def after_backprop(self, state: NodeState, loss: Loss) -> Loss: return loss -@pyd.dataclasses.dataclass(frozen=True, repr=False) +@pyd.dataclasses.dataclass( + frozen=True, repr=False, config={"arbitrary_types_allowed": True} +) class Strategy: """ A 'Strategy' implementation is comprised of the four different type of implementations of strategies @@ -102,8 +103,16 @@ def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: def __repr__(self) -> str: return str(self) - # @functools.cached_property - def __str__(self) -> str: + @functools.cached_property + def _description(self) -> str: + """A utility function for generating the string for `__str__`. + + This is written to avoid the `mypy` issue: + "Signature of '__str__' incompatible with supertype 'object'". + + Returns: + The string representation of the a Strategy instance. + """ name = self.__class__.__name__ inner = ", ".join( [ @@ -114,6 +123,9 @@ def __str__(self) -> str: ) return f"{name}({inner})" + def __str__(self) -> str: + return self._description + class DefaultStrategy(Strategy): def __init__(self) -> None: diff --git a/flight/strategies/coord.py b/flight/strategies/coord.py index 079cebd..3ff8ddc 100644 --- a/flight/strategies/coord.py +++ b/flight/strategies/coord.py @@ -6,8 +6,7 @@ from numpy.random import Generator from flight.federation.topologies.node import Node - - NodeState: t.TypeAlias = t.Any + from flight.strategies import NodeState @t.runtime_checkable diff --git a/flight/strategies/impl/fedasync.py b/flight/strategies/impl/fedasync.py index 9af7304..c8a4c66 100644 --- a/flight/strategies/impl/fedasync.py +++ b/flight/strategies/impl/fedasync.py @@ -12,9 +12,8 @@ ) if t.TYPE_CHECKING: - NodeState: t.TypeAlias = t.Any from flight.federation.topologies.node import NodeID - from flight.strategies import Params + from flight.strategies import NodeState, Params class FedAsyncAggr(DefaultAggrStrategy): diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py index 283cafa..4178530 100644 --- a/flight/strategies/impl/fedavg.py +++ b/flight/strategies/impl/fedavg.py @@ -13,9 +13,8 @@ from .fedsgd import FedSGDCoord if t.TYPE_CHECKING: - NodeState: t.TypeAlias = t.Any from flight.federation.topologies.node import NodeID - from flight.strategies import Params + from flight.strategies import NodeState, Params class FedAvgAggr(DefaultAggrStrategy): diff --git a/flight/strategies/impl/fedprox.py b/flight/strategies/impl/fedprox.py index 18cb34b..cecd927 100644 --- a/flight/strategies/impl/fedprox.py +++ b/flight/strategies/impl/fedprox.py @@ -10,8 +10,7 @@ from .fedsgd import FedSGDAggr, FedSGDCoord if t.TYPE_CHECKING: - NodeState: t.TypeAlias = t.Any - from flight.strategies import Loss + from flight.strategies import Loss, NodeState DEVICE = "cpu" diff --git a/flight/strategies/impl/fedsgd.py b/flight/strategies/impl/fedsgd.py index 8ef9ecf..e4fd565 100644 --- a/flight/strategies/impl/fedsgd.py +++ b/flight/strategies/impl/fedsgd.py @@ -15,9 +15,7 @@ if t.TYPE_CHECKING: from flight.federation.topologies.node import Node, NodeID - from flight.strategies import Params - - NodeState: t.TypeAlias = t.Any + from flight.strategies import NodeState, Params class FedSGDCoord(DefaultCoordStrategy): @@ -33,7 +31,7 @@ def __init__( def select_worker_nodes( self, state: NodeState, workers: t.Iterable[Node], rng: Generator | None = None - ) -> t.Iterable[Node]: + ) -> t.Sequence[Node]: selected_workers = random_worker_selection( workers, participation=self.participation, diff --git a/flight/strategies/trainer.py b/flight/strategies/trainer.py index 4352434..236b571 100644 --- a/flight/strategies/trainer.py +++ b/flight/strategies/trainer.py @@ -4,8 +4,7 @@ if t.TYPE_CHECKING: - NodeState: t.TypeAlias = t.Any - from flight.strategies import Loss + from flight.strategies import Loss, NodeState @t.runtime_checkable diff --git a/flight/strategies/worker.py b/flight/strategies/worker.py index 206be81..4d0f2ac 100644 --- a/flight/strategies/worker.py +++ b/flight/strategies/worker.py @@ -6,8 +6,7 @@ import torch from flight.federation.jobs.result import Result - - NodeState: t.TypeAlias = t.Any + from flight.strategies import NodeState @t.runtime_checkable diff --git a/tests/strategies/environment.py b/tests/strategies/environment.py new file mode 100644 index 0000000..7409e43 --- /dev/null +++ b/tests/strategies/environment.py @@ -0,0 +1,11 @@ +from flight.federation.topologies.node import Node, NodeKind + + +def create_children(numWorkers: int, numAggr: int = 0) -> list[Node]: + aggr = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, numAggr + 1)] + workers = [ + Node(idx=i + numAggr, kind=NodeKind.WORKER) for i in range(1, numWorkers + 1) + ] + print(workers + aggr) + print(aggr) + return workers + aggr diff --git a/tests/strategies/impl/test_fedasync.py b/tests/strategies/impl/test_fedasync.py index cf37694..5069a99 100644 --- a/tests/strategies/impl/test_fedasync.py +++ b/tests/strategies/impl/test_fedasync.py @@ -7,6 +7,7 @@ AggrStrategy, CoordStrategy, DefaultStrategy, + NodeState, TrainerStrategy, WorkerStrategy, ) @@ -17,32 +18,29 @@ ) from flight.strategies.impl.fedasync import FedAsync, FedAsyncAggr -if t.TYPE_CHECKING: - NodeState: t.TypeAlias = t.Any - class TestValidFedAsync: def test_class_hierarchy(self): - strategy = FedAsync(0.5) + fedasync = FedAsync(0.5) assert ( - isinstance(strategy.aggr_strategy, (AggrStrategy, FedAsyncAggr)) + isinstance(fedasync.aggr_strategy, (AggrStrategy, FedAsyncAggr)) and isinstance( - strategy.coord_strategy, (CoordStrategy, DefaultCoordStrategy) + fedasync.coord_strategy, (CoordStrategy, DefaultCoordStrategy) ) and isinstance( - strategy.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) + fedasync.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) ) and isinstance( - strategy.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) + fedasync.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) ) ) def test_fedasync_aggr(self): - strategy = FedAsync(alpha=0.5) + strategy = FedAsync(0.5) aggr_strategy: AggrStrategy = strategy.aggr_strategy - nodestate: NodeState = "foo" + nodestate: NodeState = {} childstates = {1: "foo1", 2: "foo2"} children_state_dicts_pt = { 1: { diff --git a/tests/strategies/impl/test_fedavg.py b/tests/strategies/impl/test_fedavg.py index e69de29..f5bfcd9 100644 --- a/tests/strategies/impl/test_fedavg.py +++ b/tests/strategies/impl/test_fedavg.py @@ -0,0 +1,85 @@ +import typing as t + +import torch + +from flight.strategies import ( + AggrStrategy, + CoordStrategy, + Params, + TrainerStrategy, + WorkerStrategy, +) +from flight.strategies.base import DefaultTrainerStrategy +from flight.strategies.impl.fedavg import FedAvg, FedAvgAggr, FedAvgWorker +from flight.strategies.impl.fedsgd import FedSGDCoord + +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any + + +class TestValidFedAvg: + def test_fedavg_class_hierarchy(self): + fedavg = FedAvg() + + assert ( + isinstance(fedavg.aggr_strategy, (AggrStrategy, FedAvgAggr)) + and isinstance(fedavg.coord_strategy, (CoordStrategy, FedSGDCoord)) + and isinstance( + fedavg.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) + ) + and isinstance(fedavg.worker_strategy, (WorkerStrategy, FedAvgWorker)) + ) + + def test_fedavg_aggr(self): + fedavg = FedAvg() + aggregatorStrat: AggrStrategy = fedavg.aggr_strategy + nodestate: NodeState = {} + childstates = { + 1: {"num_data_samples": 1, "other_data": "foo"}, + 2: {"num_data_samples": 1, "other_data": "foo"}, + } + children_state_dicts_pt = { + 1: { + "train/loss": torch.tensor(2.3, dtype=torch.float32), + "train/acc": torch.tensor(1.2, dtype=torch.float32), + }, + 2: { + "train/loss": torch.tensor(3.1, dtype=torch.float32), + "train/acc": torch.tensor(1.4, dtype=torch.float32), + }, + } + + aggregated = aggregatorStrat.aggregate_params( + nodestate, childstates, children_state_dicts_pt + ) + + assert isinstance(aggregated, dict) + + expected_avg = { + "train/loss": 2.7, + "train/acc": 1.3, + } + + epsilon = 1e-6 + for key, value in aggregated.items(): + expected = expected_avg[key] + assert abs(expected - value.item()) < epsilon + + def test_fedavg_worker(self): + fedavg = FedAvg() + + workerStrat: WorkerStrategy = fedavg.worker_strategy + + nodestate_before: NodeState = {"State:": "Training preperation"} + data_before: Params = { + "train/loss1": torch.tensor(0.35, dtype=torch.float32), + "train/loss2": torch.tensor(0.5, dtype=torch.float32), + "train/loss3": torch.tensor(0.23, dtype=torch.float32), + } + + nodestate_after, data_after = workerStrat.before_training( + nodestate_before, data_before + ) + + assert nodestate_after["num_data_samples"] == len(data_before) + assert data_before == data_after diff --git a/tests/strategies/impl/test_fedprox.py b/tests/strategies/impl/test_fedprox.py index e69de29..ecc080e 100644 --- a/tests/strategies/impl/test_fedprox.py +++ b/tests/strategies/impl/test_fedprox.py @@ -0,0 +1,20 @@ +from flight.strategies import ( + AggrStrategy, + CoordStrategy, + TrainerStrategy, + WorkerStrategy, +) +from flight.strategies.impl.fedavg import FedAvgWorker +from flight.strategies.impl.fedprox import FedProx, FedProxTrainer +from flight.strategies.impl.fedsgd import FedSGDAggr, FedSGDCoord + + +class TestValidFedProx: + def test_fedprox_class_hierarchy(self): + fedprox = FedProx(0.3, 1, False, False) + assert ( + isinstance(fedprox.aggr_strategy, (AggrStrategy, FedSGDAggr)) + and isinstance(fedprox.coord_strategy, (CoordStrategy, FedSGDCoord)) + and isinstance(fedprox.trainer_strategy, (TrainerStrategy, FedProxTrainer)) + and isinstance(fedprox.worker_strategy, (WorkerStrategy, FedAvgWorker)) + ) diff --git a/tests/strategies/impl/test_fedsgd.py b/tests/strategies/impl/test_fedsgd.py index e69de29..7759ba0 100644 --- a/tests/strategies/impl/test_fedsgd.py +++ b/tests/strategies/impl/test_fedsgd.py @@ -0,0 +1,71 @@ +import torch +from numpy.random import default_rng + +from flight.strategies import ( + AggrStrategy, + CoordStrategy, + TrainerStrategy, + WorkerStrategy, +) +from flight.strategies.base import DefaultTrainerStrategy, DefaultWorkerStrategy +from flight.strategies.impl.fedsgd import FedSGD, FedSGDAggr, FedSGDCoord +from tests.strategies.environment import create_children + + +class TestValidFedSGD: + def test_fedsgd_class_hierarchy(self): + fedsgd = FedSGD(1, False, True) + + assert ( + isinstance(fedsgd.aggr_strategy, (AggrStrategy, FedSGDAggr)) + and isinstance(fedsgd.coord_strategy, (CoordStrategy, FedSGDCoord)) + and isinstance( + fedsgd.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) + ) + and isinstance( + fedsgd.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) + ) + ) + + def test_default_fedsgd_coord(self): + fedsgd = FedSGD(1, False, True) + coordStrat: CoordStrategy = fedsgd.coord_strategy + gen = default_rng() + workers = create_children(numWorkers=10) + + selected = coordStrat.select_workers("foo", workers, gen) + + for worker in workers: + assert worker in selected + + def test_fedsgd_aggr(self): + fedsgd = FedSGD(1, False, True) + aggrStrat: AggrStrategy = fedsgd.aggr_strategy + + state = "foo" + children = {1: "foo1", 2: "foo2"} + + children_state_dicts_pt = { + 1: { + "train/loss": torch.tensor(2.3, dtype=torch.float32), + "train/acc": torch.tensor(1.2, dtype=torch.float32), + }, + 2: { + "train/loss": torch.tensor(3.1, dtype=torch.float32), + "train/acc": torch.tensor(1.4, dtype=torch.float32), + }, + } + + avg = aggrStrat.aggregate_params(state, children, children_state_dicts_pt) + + assert isinstance(avg, dict) + + expected_avg = { + "train/loss": 2.7, + "train/acc": 1.3, + } + + epsilon = 1e-6 + for key, value in avg.items(): + expected = expected_avg[key] + assert abs(expected - value.item()) < epsilon diff --git a/tests/strategies/test_aggr.py b/tests/strategies/test_aggr.py index 82f12c2..86ec8d6 100644 --- a/tests/strategies/test_aggr.py +++ b/tests/strategies/test_aggr.py @@ -2,12 +2,9 @@ import torch -from flight.strategies import AggrStrategy +from flight.strategies import AggrStrategy, NodeState from flight.strategies.base import DefaultAggrStrategy -if t.TYPE_CHECKING: - NodeState: t.TypeAlias = t.Any - def test_instance(): default_aggr = DefaultAggrStrategy() diff --git a/tests/strategies/test_coord.py b/tests/strategies/test_coord.py index 75ffe8f..5aec9f1 100644 --- a/tests/strategies/test_coord.py +++ b/tests/strategies/test_coord.py @@ -1,10 +1,10 @@ import pytest from numpy.random import default_rng -from flight.federation.topologies.node import Node, NodeKind from flight.strategies import CoordStrategy from flight.strategies.base import DefaultCoordStrategy from flight.strategies.commons.worker_selection import random_worker_selection +from tests.strategies.environment import create_children def test_instance(): @@ -16,12 +16,10 @@ def test_instance(): def test_worker_selection(): gen = default_rng() - worker1 = Node(idx=1, kind=NodeKind.WORKER) - worker2 = Node(idx=2, kind=NodeKind.WORKER) - workers = [worker1, worker2] + children = create_children(numWorkers=2) # fixed random fixed_random = random_worker_selection( - workers, + children, participation=1, probabilistic=False, always_include_child_aggregators=True, @@ -29,26 +27,26 @@ def test_worker_selection(): ) # prob random prob_random = random_worker_selection( - workers, + children, participation=1, probabilistic=True, always_include_child_aggregators=True, rng=gen, ) - for worker in workers: - assert worker in fixed_random and worker in prob_random + for child in children: + assert child in fixed_random and child in prob_random class TestInvalidFixedSelection: def test_fixed_random(self): gen = default_rng() - workers = [Node(idx=i, kind=NodeKind.WORKER) for i in range(1, 6)] + children = create_children(numWorkers=5) with pytest.raises(ValueError): fixed_random = random_worker_selection( - workers, + children, participation=2, probabilistic=False, always_include_child_aggregators=True, @@ -58,8 +56,7 @@ def test_fixed_random(self): def test_fixed_random_mix(self): gen = default_rng() - children = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, 3)] - children.append(Node(idx=3, kind=NodeKind.WORKER)) + children = create_children(numWorkers=1, numAggr=2) with pytest.raises(ValueError): fixed_random = random_worker_selection( diff --git a/tests/strategies/test_worker.py b/tests/strategies/test_worker.py index e72ef4f..a38e04b 100644 --- a/tests/strategies/test_worker.py +++ b/tests/strategies/test_worker.py @@ -1,8 +1,32 @@ +import typing as t + +import torch + +from flight.federation.jobs.result import Result from flight.strategies import WorkerStrategy from flight.strategies.base import DefaultWorkerStrategy +if t.TYPE_CHECKING: + NodeState: t.TypeAlias = t.Any + def test_instance(): default_worker = DefaultWorkerStrategy() assert isinstance(default_worker, WorkerStrategy) + + +def test_default_methods(): + default_worker = DefaultWorkerStrategy() + workerstate = ("NodeID 1", "num_data_samples 2") + data = { + "train/acc": torch.tensor(0.3, dtype=torch.float32), + "train/loss": torch.tensor(0.7, dtype=torch.float32), + } + # result = Result(workerstate, 1, data, [], {}) + # optimizer = torch.optim.Optimizer(data.values(), {}) + + # assert default_worker.start_work(workerstate) == workerstate + # assert default_worker.before_training(workerstate, data) == (workerstate, data) + # assert default_worker.after_training(workerstate, optimizer) == workerstate + # assert default_worker.end_work(result) == result From 6695eb3070183fd35d3cdccedc48869b4e54602b Mon Sep 17 00:00:00 2001 From: loganalt9 Date: Thu, 18 Jul 2024 17:34:53 -0400 Subject: [PATCH 04/11] Resolved issues + added documentation --- flight/federation/jobs/result.py | 1 + flight/strategies/__init__.py | 11 +++ flight/strategies/aggr.py | 4 +- flight/strategies/base.py | 94 ++++++++++++++++++- flight/strategies/commons/averaging.py | 9 ++ flight/strategies/commons/worker_selection.py | 38 +++++++- flight/strategies/impl/fedasync.py | 24 +++++ flight/strategies/impl/fedavg.py | 39 ++++++++ flight/strategies/impl/fedprox.py | 22 +++++ flight/strategies/impl/fedsgd.py | 40 ++++++++ tests/strategies/environment.py | 11 ++- tests/strategies/impl/test_fedasync.py | 23 ++--- tests/strategies/impl/test_fedavg.py | 15 +-- tests/strategies/impl/test_fedprox.py | 11 +-- tests/strategies/impl/test_fedsgd.py | 19 ++-- tests/strategies/test_aggr.py | 2 + tests/strategies/test_coord.py | 40 ++++---- tests/strategies/test_worker.py | 17 +--- 18 files changed, 340 insertions(+), 80 deletions(-) diff --git a/flight/federation/jobs/result.py b/flight/federation/jobs/result.py index dda248e..cdc05b2 100644 --- a/flight/federation/jobs/result.py +++ b/flight/federation/jobs/result.py @@ -10,6 +10,7 @@ NodeState: t.TypeAlias = tuple +# TODO: Remove config when all type definitions have been resolvedß @dataclass(config={"arbitrary_types_allowed": True}) class Result: state: NodeState diff --git a/flight/strategies/__init__.py b/flight/strategies/__init__.py index 9825f56..ef5f39f 100644 --- a/flight/strategies/__init__.py +++ b/flight/strategies/__init__.py @@ -14,6 +14,17 @@ def load_strategy(strategy_name: str, **kwargs) -> Strategy: + """Function used to grab the users preferred 'Strategy'. + + Args: + strategy_name (str): The name of the 'Strategy' to be grabbed. + + Raises: + ValueError: If an unknown 'Strategy' type is passed through. + + Returns: + Strategy: The selected 'Strategy' type. + """ assert isinstance(strategy_name, str), "`strategy_name` must be a string." match strategy_name.lower(): case "default": diff --git a/flight/strategies/aggr.py b/flight/strategies/aggr.py index 6495295..e59f6fe 100644 --- a/flight/strategies/aggr.py +++ b/flight/strategies/aggr.py @@ -3,8 +3,6 @@ import typing as t if t.TYPE_CHECKING: - import torch - from flight.federation.topologies.node import NodeID from flight.strategies import NodeState, Params @@ -33,7 +31,7 @@ def aggregate_params( **kwargs: Keyword arguments provided by users. Returns: - The aggregated parameters to update the model at the current aggregator. + Params: The aggregated parameters to update the model at the current aggregator. """ pass diff --git a/flight/strategies/base.py b/flight/strategies/base.py index a6207e7..e834575 100644 --- a/flight/strategies/base.py +++ b/flight/strategies/base.py @@ -25,13 +25,27 @@ class DefaultCoordStrategy: + """Default implementation of the strategy for a coordinator.""" + def select_workers( self, state: NodeState, workers: t.Iterable[Node], rng: Generator ) -> t.Sequence[Node]: + """Method used for the selection of workers. + + Args: + state (NodeState): The state of the coordinator node. + workers (t.Iterable[Node]): Iterable object containing all of the worker nodes. + rng (Generator): RNG object used for randomness. + + Returns: + t.Sequence[Node]: The selected workers. + """ return list(workers) class DefaultAggrStrategy: + """Default implementation of the strategy for an aggregator.""" + def start_round(self): pass @@ -42,6 +56,17 @@ def aggregate_params( children_state_dicts: t.Mapping[NodeID, Params], **kwargs, ) -> Params: + """Callback that handles the model parameter aggregation step. + + Args: + state (NodeState): The state of the current aggregator node. + children_states (t.Mapping[NodeID, NodeState]): A mapping of the current aggregator node's children and their respective states. + children_state_dicts (t.Mapping[NodeID, Parmas]): The model parameters of the models to each respective child node. + **kwargs: Keyword arguments provided by users. + + Returns: + Params: The aggregated values. + """ return average_state_dicts(children_state_dicts, weights=None) def end_round(self): @@ -49,31 +74,88 @@ def end_round(self): class DefaultWorkerStrategy: + """Default implementation of the strategy for a worker""" + def start_work(self, state: NodeState) -> NodeState: + """Callback to be ran and the start of the current worker nodes work. + + Args: + state (NodeState): The state of the current worker node. + + Returns: + NodeState: The state of the current worker node at the end of the callback. + """ return state def before_training( self, state: NodeState, data: Params ) -> tuple[NodeState, Params]: + """Callback to be ran before training. + + Args: + state (NodeState): The state of the current worker node. + data (Params): The data associated with the current worker node. + + Returns: + tuple[NodeState, Params]: A tuple containing the state and data of the worker node at the end of the callback. + """ return state, data def after_training( self, state: NodeState, optimizer: torch.optim.Optimizer ) -> NodeState: + """Callback to be ran after training. + + Args: + state (NodeState): The state of the current worker node. + optimizer (torch.optim.Optimizer): The PyTorch optimier to be used. + + Returns: + NodeState: The state of the worker node at the end of the callback. + """ return state def end_work(self, result: Result) -> Result: + """Callback to be ran at the end of the work. + + Args: + result (Result): A Result object used to represent the result of the local training on the current worker node. + + Returns: + Result: The result of the worker nodes local training. + """ return result class DefaultTrainerStrategy: + """Default implementation of a strategy for the trainer.""" + def before_backprop(self, state: NodeState, loss: Loss) -> Loss: + """Callback to run before backpropagation. + + Args: + state (NodeState): State of the current node. + loss (Loss): The calculated loss + + Returns: + The loss at the end of the callback + """ return loss def after_backprop(self, state: NodeState, loss: Loss) -> Loss: + """Callback to run after backpropagation. + + Args: + state (NodeState): State of the current node. + loss (Loss): The calculated loss + + Returns: + The loss at the end of the callback + """ return loss +# TODO: Remove config when all type definitions have been resolved @pyd.dataclasses.dataclass( frozen=True, repr=False, config={"arbitrary_types_allowed": True} ) @@ -84,15 +166,15 @@ class Strategy: """ """Implementation of the specific callbacks for the coordinator node.""" - coord_strategy: StrategyType = pyd.Field() + coord_strategy: CoordStrategy = pyd.Field() """Implementation of the specific callbacks for the aggregator node(s).""" - aggr_strategy: StrategyType = pyd.Field() + aggr_strategy: AggrStrategy = pyd.Field() """Implementation of the specific callbacks for the worker node(s).""" - worker_strategy: StrategyType = pyd.Field() + worker_strategy: WorkerStrategy = pyd.Field() """Implementation of callbacks specific to the execution of the training loop on the worker node(s).""" - trainer_strategy: StrategyType = pyd.Field() + trainer_strategy: TrainerStrategy = pyd.Field() - def __iter__(self) -> t.Iterator[tuple[str, t.Any]]: + def __iter__(self) -> t.Iterator[tuple[str, StrategyType]]: yield from ( ("coord_strategy", self.coord_strategy), ("aggr_strategy", self.aggr_strategy), @@ -128,6 +210,8 @@ def __str__(self) -> str: class DefaultStrategy(Strategy): + """Implementation of a strategy that uses the default strategy types for each node type.""" + def __init__(self) -> None: super().__init__( coord_strategy=DefaultCoordStrategy(), diff --git a/flight/strategies/commons/averaging.py b/flight/strategies/commons/averaging.py index ef1eca0..6a2edf5 100644 --- a/flight/strategies/commons/averaging.py +++ b/flight/strategies/commons/averaging.py @@ -15,6 +15,15 @@ def average_state_dicts( state_dicts: Mapping[NodeID, Params], weights: Mapping[NodeID, float] | None = None ) -> Params: + """Helper function used by aggregator nodes for averaging the passed node state dictionary. + + Args: + state_dicts (Mapping[NodeID, Params]): A dictionary object mapping nodes to their respective states. + weights (Mapping[NodeID, float] | None, optional): Optional dictionary that maps each node to its contribution factor. Defaults to None. + + Returns: + Params: The averaged states. + """ num_nodes = len(state_dicts) if weights is not None: diff --git a/flight/strategies/commons/worker_selection.py b/flight/strategies/commons/worker_selection.py index eb0f331..26027f6 100644 --- a/flight/strategies/commons/worker_selection.py +++ b/flight/strategies/commons/worker_selection.py @@ -1,9 +1,7 @@ from collections.abc import Iterable -from typing import cast from numpy import array -from numpy.random import Generator, RandomState, default_rng -from numpy.typing import NDArray +from numpy.random import Generator, default_rng from flight.federation.topologies.node import Node, NodeKind @@ -15,6 +13,18 @@ def random_worker_selection( always_include_child_aggregators: bool = True, rng: Generator | None = None, ) -> list[Node]: + """General call for worker selection that will then choose from probabilistic or fixed selection. + + Args: + children (Iterable[Node]): Children to be evaluated for worker selection. + participation (float, optional): Controls the level of participation each node contributes. Defaults to 1.0. + probabilistic (bool, optional): Decider for whether probabilistic (True), or fixed (False) selection will be used. Defaults to False. + always_include_child_aggregators (bool, optional): In probabilistic selection, ensures whether or not all worker nodes are included. Defaults to True. + rng (Generator | None, optional): RNG object used for randomness, numpy.random.default_rng will be used if None. Defaults to None. + + Returns: + list[Node]: The selected worker nodes. + """ if rng is None: rng = default_rng() if probabilistic: @@ -27,6 +37,16 @@ def random_worker_selection( def fixed_random_worker_selection( children: Iterable[Node], rng: Generator, participation: float = 1.0 ) -> list[Node]: + """The worker selection used when probalistic is false. This worker selection is entirely random based on 'rng' + + Args: + children (Iterable[Node]): Children to be evaluated for worker selection. + rng (Generator): RNG object used for randomness. Defaults to numpy.random.default_rand. + participation (float, optional): Controls the level of participation each node contributes. Defaults to 1.0. + + Returns: + list[Node]: The selected worker nodes. + """ children = array(children) num_selected = max(1, int(participation * len(list(children)))) selected_children = rng.choice(children, size=num_selected, replace=False) @@ -39,6 +59,18 @@ def prob_random_worker_selection( participation: float = 1.0, always_include_child_aggregators: bool = True, ) -> list[Node]: + """The worker selection used when probalistic is true. This worker selection is probalistic and therefore in most cases + will select workers in order. + + Args: + children (Iterable[Node]): Children to be evaluated for worker selection. + rng (Generator): RNG object used for randmoness. Defaults to numpy.random.default_rand. + participation (float, optional): Acts as a probability marker for whether or no to include a node. Defaults to 1.0. + always_include_child_aggregators (bool, optional): Ensures whether or not worker nodes are included. Defaults to True. + + Returns: + list[Node]: The selected worker nodes. + """ selected_children = [] for child in children: if child.kind is NodeKind.WORKER and always_include_child_aggregators: diff --git a/flight/strategies/impl/fedasync.py b/flight/strategies/impl/fedasync.py index c8a4c66..ef01ee4 100644 --- a/flight/strategies/impl/fedasync.py +++ b/flight/strategies/impl/fedasync.py @@ -17,6 +17,12 @@ class FedAsyncAggr(DefaultAggrStrategy): + """The aggregator for 'FedAsync' and its respective methods. + + Args: + DefaultAggrStrategy: The base class providing necessary methods for FedAsyncAggr. + """ + def __init__(self, alpha: float = 0.5): assert 0.0 < alpha <= 1.0 self.alpha = alpha @@ -28,6 +34,17 @@ def aggregate_params( children_state_dicts: t.Mapping[NodeID, Params], **kwargs, ) -> Params: + """Method used by aggregator nodes for aggregating the passed node state dictionary. + + Args: + state (NodeState): State of the current aggregator node. + children_states (t.Mapping[NodeID, NodeState]): Dictionary of the states of the children. + children_state_dicts (t.Mapping[NodeID, Params]): Dictionary mapping each child to its values. + **kwargs: Key Word arguments provided by the user. + + Returns: + Params: The aggregated values. + """ last_updated = kwargs.get("last_updated_node", None) assert last_updated is not None assert isinstance(last_updated, int | str) @@ -48,6 +65,13 @@ def aggregate_params( class FedAsync(Strategy): + """Implementation of the FedAsync strategy, which uses default strategies for coordinator, workers, and trainer + and the 'FedAsyncAggr'. + + Args: + Strategy: The base class providing the necessary attributes for 'FedAsync'. + """ + def __init__(self, alpha: float): super().__init__( aggr_strategy=FedAsyncAggr(alpha), diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py index 4178530..1b8357c 100644 --- a/flight/strategies/impl/fedavg.py +++ b/flight/strategies/impl/fedavg.py @@ -18,6 +18,12 @@ class FedAvgAggr(DefaultAggrStrategy): + """The aggregator for 'FedAvg' and its respective methods. + + Args: + DefaultAggrStrategy: The base class providing necessary methods for 'FedAvgAggr'. + """ + def aggregate_params( self, state: NodeState, @@ -25,6 +31,17 @@ def aggregate_params( children_state_dicts: t.Mapping[NodeID, Params], **kwargs, ) -> Params: + """Method used by aggregator nodes for aggregating the passed node state dictionary. + + Args: + state (NodeState): State of the current aggregator node. + children_states (t.Mapping[NodeID, NodeState]): Dictionary of the states of the children. + children_state_dicts (t.Mapping[NodeID, Params]): Dictionary mapping each child to its values. + **kwargs: Key word arguments provided by the user. + + Returns: + Params: The aggregated values. + """ weights = {} for node, child_state in children_states.items(): weights[node] = child_state["num_data_samples"] @@ -35,14 +52,36 @@ def aggregate_params( class FedAvgWorker(DefaultWorkerStrategy): + """The worker for 'FedAvg' and its respective methods. + + Args: + DefaultWorkerStrategy: The base class providing necessary methods for 'FedAvgWorker' + """ + def before_training( self, state: NodeState, data: Params ) -> tuple[NodeState, Params]: + """Callback to run before the current nodes training. + + Args: + state (NodeState): State of the current worker node. + data (Params): The data related to the current worker node. + + Returns: + tuple[NodeState, Params]: A tuple containing the updated state of the worker node and the data. + """ state["num_data_samples"] = len(data) return state, data class FedAvg(Strategy): + """Implementation of the FedAvg strategy, which uses default strategies for the trainer, + 'FedAvg' for aggregator and workers, and 'FedSGD' for the coordinator. + + Args: + Strategy: The base class providing the necessary attributes for 'FedAvg'. + """ + def __init__( self, participation: float = 1.0, diff --git a/flight/strategies/impl/fedprox.py b/flight/strategies/impl/fedprox.py index cecd927..04b1d8b 100644 --- a/flight/strategies/impl/fedprox.py +++ b/flight/strategies/impl/fedprox.py @@ -16,10 +16,25 @@ class FedProxTrainer(DefaultTrainerStrategy): + """The coordinator and its respective methods for 'FedProx'. + + Args: + DefaultTrainerStrategy: The base class providing necessary methods for 'FedProxTrainer'. + """ + def __init__(self, mu: float = 0.3): self.mu = mu def before_backprop(self, state: NodeState, loss: Loss) -> Loss: + """Callback to run before backpropagation. + + Args: + state (NodeState): The state of the current node. + loss (Loss): The calculated loss associated with the current node. + + Returns: + Loss: The updated loss associated with the current node. + """ global_model = state.global_model local_model = state.local_model assert global_model is not None @@ -47,6 +62,13 @@ def before_backprop(self, state: NodeState, loss: Loss) -> Loss: class FedProx(Strategy): + """Implementation of the FedProx strategy, which uses + 'FedAvg' for the workers, 'FedSGD' for the coordinator and aggregators, and 'FedProx' for the trainer. + + Args: + Strategy: The base class providing the necessary attributes for 'FedProx'. + """ + def __init__( self, mu: float = 0.3, diff --git a/flight/strategies/impl/fedsgd.py b/flight/strategies/impl/fedsgd.py index e4fd565..7e88511 100644 --- a/flight/strategies/impl/fedsgd.py +++ b/flight/strategies/impl/fedsgd.py @@ -19,6 +19,12 @@ class FedSGDCoord(DefaultCoordStrategy): + """The coordinator and its respective methods for 'FedSGD'. + + Args: + DefaultCoordStrategy: The base class providing the necessary methods for 'FedSGDCoord'. + """ + def __init__( self, participation, @@ -32,6 +38,16 @@ def __init__( def select_worker_nodes( self, state: NodeState, workers: t.Iterable[Node], rng: Generator | None = None ) -> t.Sequence[Node]: + """Method containing the method for worker selection for 'FedSGD'. + + Args: + state (NodeState): State of the coordinator node. + workers (t.Iterable[Node]): Iterable containing the worker nodes. + rng (Generator | None, optional): RNG object used for randomness. Defaults to None. + + Returns: + t.Sequence[Node]: The selected worker nodes. + """ selected_workers = random_worker_selection( workers, participation=self.participation, @@ -43,6 +59,12 @@ def select_worker_nodes( class FedSGDAggr(DefaultAggrStrategy): + """The aggregator and its respective methods for 'FedSGD'. + + Args: + DefaultAggrStrategy: The base class providing the necessary methods for 'FedSGDAggr'. + """ + def aggregate_params( self, state: NodeState, @@ -50,10 +72,28 @@ def aggregate_params( children_state_dicts: t.Mapping[NodeID, Params], **kwargs, ) -> Params: + """Method used by aggregator nodes for aggregating the passed node state dictionary. + + Args: + state (NodeState): State of the current aggregator node. + children_states (t.Mapping[NodeID, NodeState]): Dictionary of the states of the children. + children_state_dicts (t.Mapping[NodeID, Params]): Dictionary mapping each child to its values. + **kwargs: Key word arguments provided by the user. + + Returns: + Params: The aggregated values. + """ return average_state_dicts(children_state_dicts, weights=None) class FedSGD(Strategy): + """Implementation of the FedSGD strategy, which uses 'FedSGD' for the coordinator and aggregators, and defaults + for the workers and trainer. + + Args: + Strategy: The base class providing the necessary attributes for 'FedSGD'. + """ + def __init__( self, participation: float = 1.0, diff --git a/tests/strategies/environment.py b/tests/strategies/environment.py index 7409e43..15836ff 100644 --- a/tests/strategies/environment.py +++ b/tests/strategies/environment.py @@ -2,10 +2,17 @@ def create_children(numWorkers: int, numAggr: int = 0) -> list[Node]: + """Creates a fabricated list of children used for coordinator/selecting workers test cases. + + Args: + numWorkers (int): Number of workers to be added. + numAggr (int, optional): Number of aggregators to be added. Defaults to 0. + + Returns: + list[Node]: A list of the created children. + """ aggr = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, numAggr + 1)] workers = [ Node(idx=i + numAggr, kind=NodeKind.WORKER) for i in range(1, numWorkers + 1) ] - print(workers + aggr) - print(aggr) return workers + aggr diff --git a/tests/strategies/impl/test_fedasync.py b/tests/strategies/impl/test_fedasync.py index 5069a99..2173843 100644 --- a/tests/strategies/impl/test_fedasync.py +++ b/tests/strategies/impl/test_fedasync.py @@ -21,22 +21,22 @@ class TestValidFedAsync: def test_class_hierarchy(self): + """Test that the associated node strategy types follow the correct protocols.""" fedasync = FedAsync(0.5) - assert ( - isinstance(fedasync.aggr_strategy, (AggrStrategy, FedAsyncAggr)) - and isinstance( - fedasync.coord_strategy, (CoordStrategy, DefaultCoordStrategy) - ) - and isinstance( - fedasync.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) - ) - and isinstance( - fedasync.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) - ) + assert isinstance(fedasync.aggr_strategy, (AggrStrategy, FedAsyncAggr)) + assert isinstance( + fedasync.coord_strategy, (CoordStrategy, DefaultCoordStrategy) + ) + assert isinstance( + fedasync.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) + ) + assert isinstance( + fedasync.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) ) def test_fedasync_aggr(self): + """Tests implementation of the aggregator within 'FedAsync'""" strategy = FedAsync(0.5) aggr_strategy: AggrStrategy = strategy.aggr_strategy @@ -62,5 +62,6 @@ def test_fedasync_aggr(self): class TestInvalidFedAsync: def test_invalid_alpha(self): + """Test inputing a value for alpha which is too large.""" with pytest.raises(AssertionError): fedasync = FedAsync(alpha=1.1) diff --git a/tests/strategies/impl/test_fedavg.py b/tests/strategies/impl/test_fedavg.py index f5bfcd9..b3c41a3 100644 --- a/tests/strategies/impl/test_fedavg.py +++ b/tests/strategies/impl/test_fedavg.py @@ -19,18 +19,18 @@ class TestValidFedAvg: def test_fedavg_class_hierarchy(self): + """Test that the associated node strategy types follow the correct protocols.""" fedavg = FedAvg() - assert ( - isinstance(fedavg.aggr_strategy, (AggrStrategy, FedAvgAggr)) - and isinstance(fedavg.coord_strategy, (CoordStrategy, FedSGDCoord)) - and isinstance( - fedavg.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) - ) - and isinstance(fedavg.worker_strategy, (WorkerStrategy, FedAvgWorker)) + assert isinstance(fedavg.aggr_strategy, (AggrStrategy, FedAvgAggr)) + assert isinstance(fedavg.coord_strategy, (CoordStrategy, FedSGDCoord)) + assert isinstance( + fedavg.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) ) + assert isinstance(fedavg.worker_strategy, (WorkerStrategy, FedAvgWorker)) def test_fedavg_aggr(self): + """Tests the usability of the aggregator strategy for 'FedAvg'""" fedavg = FedAvg() aggregatorStrat: AggrStrategy = fedavg.aggr_strategy nodestate: NodeState = {} @@ -66,6 +66,7 @@ def test_fedavg_aggr(self): assert abs(expected - value.item()) < epsilon def test_fedavg_worker(self): + """Tests the usability of the worker strategy for 'FedAvg'""" fedavg = FedAvg() workerStrat: WorkerStrategy = fedavg.worker_strategy diff --git a/tests/strategies/impl/test_fedprox.py b/tests/strategies/impl/test_fedprox.py index ecc080e..7b92e3d 100644 --- a/tests/strategies/impl/test_fedprox.py +++ b/tests/strategies/impl/test_fedprox.py @@ -11,10 +11,9 @@ class TestValidFedProx: def test_fedprox_class_hierarchy(self): + """Test that the associated node strategy types follow the correct protocols.""" fedprox = FedProx(0.3, 1, False, False) - assert ( - isinstance(fedprox.aggr_strategy, (AggrStrategy, FedSGDAggr)) - and isinstance(fedprox.coord_strategy, (CoordStrategy, FedSGDCoord)) - and isinstance(fedprox.trainer_strategy, (TrainerStrategy, FedProxTrainer)) - and isinstance(fedprox.worker_strategy, (WorkerStrategy, FedAvgWorker)) - ) + assert isinstance(fedprox.aggr_strategy, (AggrStrategy, FedSGDAggr)) + assert isinstance(fedprox.coord_strategy, (CoordStrategy, FedSGDCoord)) + assert isinstance(fedprox.trainer_strategy, (TrainerStrategy, FedProxTrainer)) + assert isinstance(fedprox.worker_strategy, (WorkerStrategy, FedAvgWorker)) diff --git a/tests/strategies/impl/test_fedsgd.py b/tests/strategies/impl/test_fedsgd.py index 7759ba0..a42e209 100644 --- a/tests/strategies/impl/test_fedsgd.py +++ b/tests/strategies/impl/test_fedsgd.py @@ -14,20 +14,20 @@ class TestValidFedSGD: def test_fedsgd_class_hierarchy(self): + """Test that the associated node strategy types follow the correct protocols.""" fedsgd = FedSGD(1, False, True) - assert ( - isinstance(fedsgd.aggr_strategy, (AggrStrategy, FedSGDAggr)) - and isinstance(fedsgd.coord_strategy, (CoordStrategy, FedSGDCoord)) - and isinstance( - fedsgd.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) - ) - and isinstance( - fedsgd.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) - ) + assert isinstance(fedsgd.aggr_strategy, (AggrStrategy, FedSGDAggr)) + assert isinstance(fedsgd.coord_strategy, (CoordStrategy, FedSGDCoord)) + assert isinstance( + fedsgd.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy) + ) + assert isinstance( + fedsgd.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy) ) def test_default_fedsgd_coord(self): + """Tests the usability of the coordinator strategy for 'FedSGD'""" fedsgd = FedSGD(1, False, True) coordStrat: CoordStrategy = fedsgd.coord_strategy gen = default_rng() @@ -39,6 +39,7 @@ def test_default_fedsgd_coord(self): assert worker in selected def test_fedsgd_aggr(self): + """Tests the usability of the aggregator strategy for 'FedSGD'""" fedsgd = FedSGD(1, False, True) aggrStrat: AggrStrategy = fedsgd.aggr_strategy diff --git a/tests/strategies/test_aggr.py b/tests/strategies/test_aggr.py index 86ec8d6..7627a4a 100644 --- a/tests/strategies/test_aggr.py +++ b/tests/strategies/test_aggr.py @@ -7,12 +7,14 @@ def test_instance(): + """Test that the associated node strategy type follows the correct protocols.""" default_aggr = DefaultAggrStrategy() assert isinstance(default_aggr, AggrStrategy) def test_aggr_aggregate_params(): + """Tests usability for the 'aggregate_params' function on two children.""" default_aggr = DefaultAggrStrategy() state: NodeState = "foo" diff --git a/tests/strategies/test_coord.py b/tests/strategies/test_coord.py index 5aec9f1..6d7508d 100644 --- a/tests/strategies/test_coord.py +++ b/tests/strategies/test_coord.py @@ -8,31 +8,33 @@ def test_instance(): + """Test that the associated node strategy type follows the correct protocols.""" default_coord = DefaultCoordStrategy() assert isinstance(default_coord, CoordStrategy) def test_worker_selection(): + """Tests both fix and probabilistic worker selection on five workers.""" gen = default_rng() - - children = create_children(numWorkers=2) - # fixed random - fixed_random = random_worker_selection( - children, - participation=1, - probabilistic=False, - always_include_child_aggregators=True, - rng=gen, - ) - # prob random - prob_random = random_worker_selection( - children, - participation=1, - probabilistic=True, - always_include_child_aggregators=True, - rng=gen, - ) + for _ in range(5): + children = create_children(numWorkers=5) + # fixed random + fixed_random = random_worker_selection( + children, + participation=1, + probabilistic=False, + always_include_child_aggregators=True, + rng=gen, + ) + # prob random + prob_random = random_worker_selection( + children, + participation=1, + probabilistic=True, + always_include_child_aggregators=True, + rng=gen, + ) for child in children: assert child in fixed_random and child in prob_random @@ -40,6 +42,7 @@ def test_worker_selection(): class TestInvalidFixedSelection: def test_fixed_random(self): + """Tests an invalid level of participation on fixed selection raises a 'ValueError'""" gen = default_rng() children = create_children(numWorkers=5) @@ -54,6 +57,7 @@ def test_fixed_random(self): ) def test_fixed_random_mix(self): + """Tests an invalid level of participation on prob selection raises a 'ValueError'""" gen = default_rng() children = create_children(numWorkers=1, numAggr=2) diff --git a/tests/strategies/test_worker.py b/tests/strategies/test_worker.py index a38e04b..8c31ce1 100644 --- a/tests/strategies/test_worker.py +++ b/tests/strategies/test_worker.py @@ -11,22 +11,7 @@ def test_instance(): + """Test that the associated node strategy type follows the correct protocols.""" default_worker = DefaultWorkerStrategy() assert isinstance(default_worker, WorkerStrategy) - - -def test_default_methods(): - default_worker = DefaultWorkerStrategy() - workerstate = ("NodeID 1", "num_data_samples 2") - data = { - "train/acc": torch.tensor(0.3, dtype=torch.float32), - "train/loss": torch.tensor(0.7, dtype=torch.float32), - } - # result = Result(workerstate, 1, data, [], {}) - # optimizer = torch.optim.Optimizer(data.values(), {}) - - # assert default_worker.start_work(workerstate) == workerstate - # assert default_worker.before_training(workerstate, data) == (workerstate, data) - # assert default_worker.after_training(workerstate, optimizer) == workerstate - # assert default_worker.end_work(result) == result From 868a963652bae4335db65b9acb3b55a0575195f7 Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:07:02 -0500 Subject: [PATCH 05/11] Update averaging.py --- flight/strategies/commons/averaging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flight/strategies/commons/averaging.py b/flight/strategies/commons/averaging.py index 6a2edf5..eb44c44 100644 --- a/flight/strategies/commons/averaging.py +++ b/flight/strategies/commons/averaging.py @@ -22,7 +22,7 @@ def average_state_dicts( weights (Mapping[NodeID, float] | None, optional): Optional dictionary that maps each node to its contribution factor. Defaults to None. Returns: - Params: The averaged states. + Params: The averaged parameters. """ num_nodes = len(state_dicts) From 38efa5403492e6d5a69a44658fc627a74f1039e9 Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:08:51 -0500 Subject: [PATCH 06/11] Update fedavg.py --- flight/strategies/impl/fedavg.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py index 1b8357c..3bd3b2b 100644 --- a/flight/strategies/impl/fedavg.py +++ b/flight/strategies/impl/fedavg.py @@ -18,11 +18,7 @@ class FedAvgAggr(DefaultAggrStrategy): - """The aggregator for 'FedAvg' and its respective methods. - - Args: - DefaultAggrStrategy: The base class providing necessary methods for 'FedAvgAggr'. - """ + """The aggregator for the FedAvg algorithm and its respective methods.""" def aggregate_params( self, From b162e6908cd9ef2b4974eff06de4c4a853e5f2aa Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:09:29 -0500 Subject: [PATCH 07/11] Update fedavg.py --- flight/strategies/impl/fedavg.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py index 3bd3b2b..037ec27 100644 --- a/flight/strategies/impl/fedavg.py +++ b/flight/strategies/impl/fedavg.py @@ -48,11 +48,7 @@ def aggregate_params( class FedAvgWorker(DefaultWorkerStrategy): - """The worker for 'FedAvg' and its respective methods. - - Args: - DefaultWorkerStrategy: The base class providing necessary methods for 'FedAvgWorker' - """ + """The worker for 'FedAvg' and its respective methods.""" def before_training( self, state: NodeState, data: Params @@ -71,11 +67,9 @@ def before_training( class FedAvg(Strategy): - """Implementation of the FedAvg strategy, which uses default strategies for the trainer, - 'FedAvg' for aggregator and workers, and 'FedSGD' for the coordinator. - - Args: - Strategy: The base class providing the necessary attributes for 'FedAvg'. + """ + Implementation of the FedAvg strategy, which uses default strategies for the trainer, + 'FedAvg' for aggregator and workers, and 'FedSGD' for the coordinator. """ def __init__( From cddc845d8c8df21e970c11b8bbac767ed104379c Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:10:38 -0500 Subject: [PATCH 08/11] Update environment.py --- tests/strategies/environment.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/strategies/environment.py b/tests/strategies/environment.py index 15836ff..f6a9b1a 100644 --- a/tests/strategies/environment.py +++ b/tests/strategies/environment.py @@ -1,18 +1,18 @@ from flight.federation.topologies.node import Node, NodeKind -def create_children(numWorkers: int, numAggr: int = 0) -> list[Node]: +def create_children(num_workers: int, num_aggrs: int = 0) -> list[Node]: """Creates a fabricated list of children used for coordinator/selecting workers test cases. Args: - numWorkers (int): Number of workers to be added. - numAggr (int, optional): Number of aggregators to be added. Defaults to 0. + num_workers (int): Number of workers to be added. + num_aggrs (int, optional): Number of aggregators to be added. Defaults to 0. Returns: list[Node]: A list of the created children. """ - aggr = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, numAggr + 1)] + aggr = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, num_aggrs + 1)] workers = [ - Node(idx=i + numAggr, kind=NodeKind.WORKER) for i in range(1, numWorkers + 1) + Node(idx=i + numAggr, kind=NodeKind.WORKER) for i in range(1, num_workers + 1) ] return workers + aggr From 57c05e2e2c8a7a57c4736b140881d65cf5eabad8 Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:11:19 -0500 Subject: [PATCH 09/11] Update environment.py --- tests/strategies/environment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strategies/environment.py b/tests/strategies/environment.py index f6a9b1a..d91a000 100644 --- a/tests/strategies/environment.py +++ b/tests/strategies/environment.py @@ -13,6 +13,6 @@ def create_children(num_workers: int, num_aggrs: int = 0) -> list[Node]: """ aggr = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, num_aggrs + 1)] workers = [ - Node(idx=i + numAggr, kind=NodeKind.WORKER) for i in range(1, num_workers + 1) + Node(idx=i + num_aggrs, kind=NodeKind.WORKER) for i in range(1, num_workers + 1) ] return workers + aggr From 9cadb2f63fa8a4eebf7b57cf9b1557945c4ecc59 Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:11:45 -0500 Subject: [PATCH 10/11] Update test_coord.py --- tests/strategies/test_coord.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/strategies/test_coord.py b/tests/strategies/test_coord.py index 6d7508d..9c02591 100644 --- a/tests/strategies/test_coord.py +++ b/tests/strategies/test_coord.py @@ -60,7 +60,7 @@ def test_fixed_random_mix(self): """Tests an invalid level of participation on prob selection raises a 'ValueError'""" gen = default_rng() - children = create_children(numWorkers=1, numAggr=2) + children = create_children(num_workers=1, num_aggrs=2) with pytest.raises(ValueError): fixed_random = random_worker_selection( From 84641d1e8fe68da622978ab3a42ea386c7ecc412 Mon Sep 17 00:00:00 2001 From: Nathaniel Hudson Date: Thu, 18 Jul 2024 22:12:50 -0500 Subject: [PATCH 11/11] Update fedsgd.py --- flight/strategies/impl/fedsgd.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/flight/strategies/impl/fedsgd.py b/flight/strategies/impl/fedsgd.py index 7e88511..6ad184b 100644 --- a/flight/strategies/impl/fedsgd.py +++ b/flight/strategies/impl/fedsgd.py @@ -19,11 +19,7 @@ class FedSGDCoord(DefaultCoordStrategy): - """The coordinator and its respective methods for 'FedSGD'. - - Args: - DefaultCoordStrategy: The base class providing the necessary methods for 'FedSGDCoord'. - """ + """The coordinator and its respective methods for 'FedSGD'.""" def __init__( self, @@ -87,11 +83,9 @@ def aggregate_params( class FedSGD(Strategy): - """Implementation of the FedSGD strategy, which uses 'FedSGD' for the coordinator and aggregators, and defaults - for the workers and trainer. - - Args: - Strategy: The base class providing the necessary attributes for 'FedSGD'. + """ + Implementation of the FedSGD strategy, which uses 'FedSGD' for the coordinator and aggregators, and defaults + for the workers and trainer. """ def __init__(