diff --git a/flight/federation/jobs/result.py b/flight/federation/jobs/result.py
index fb0050e..cdc05b2 100644
--- a/flight/federation/jobs/result.py
+++ b/flight/federation/jobs/result.py
@@ -4,14 +4,14 @@ from pydantic.dataclasses import dataclass
 
 from flight.learning.module import RecordList
 
-from flight.strategies.aggr import Params
+from flight.strategies import Params
 
-if t.TYPE_CHECKING:
-    NodeID: t.TypeAlias = t.Hashable
-    NodeState: t.TypeAlias = tuple
+NodeID: t.TypeAlias = t.Hashable
+NodeState: t.TypeAlias = tuple
 
 
-@dataclass
+# TODO: Remove config when all type definitions have been resolved
+@dataclass(config={"arbitrary_types_allowed": True})
 class Result:
     state: NodeState
     node_idx: NodeID
diff --git a/flight/federation/topologies/io.py b/flight/federation/topologies/io.py
index 7cb432a..935ad64 100644
--- a/flight/federation/topologies/io.py
+++ b/flight/federation/topologies/io.py
@@ -9,6 +9,7 @@ wish to use `from_yaml()` to create a Topology with a YAML file, we encourage
 the use the `Topology.from_yaml()` method instead.
 """
+
 from __future__ import annotations
 
 import json
diff --git a/flight/strategies/__init__.py b/flight/strategies/__init__.py
index e69de29..ef5f39f 100644
--- a/flight/strategies/__init__.py
+++ b/flight/strategies/__init__.py
@@ -0,0 +1,62 @@
+import typing as t
+
+import torch
+
+from flight.strategies.aggr import AggrStrategy
+from flight.strategies.base import DefaultStrategy, Strategy
+from flight.strategies.coord import CoordStrategy
+from flight.strategies.trainer import TrainerStrategy
+from flight.strategies.worker import WorkerStrategy
+
+Loss: t.TypeAlias = torch.Tensor
+Params: t.TypeAlias = dict[str, torch.Tensor]
+NodeState: t.TypeAlias = t.Any
+
+
+def load_strategy(strategy_name: str, **kwargs) -> Strategy:
+    """Loads the user's preferred 'Strategy' by name.
+
+    Args:
+        strategy_name (str): The name of the 'Strategy' to load.
+
+    Raises:
+        ValueError: If an unknown 'Strategy' name is passed in.
+
+    Returns:
+        Strategy: The selected 'Strategy' instance.
+    """
+    assert isinstance(strategy_name, str), "`strategy_name` must be a string."
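+    # Illustrative usage (a sketch): `kwargs` are forwarded unchanged to the
+    # chosen strategy's constructor, e.g. `load_strategy("fedprox", mu=0.3)`
+    # or `load_strategy("fedasync", alpha=0.5)`.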
+    match strategy_name.lower():
+        case "default":
+            return DefaultStrategy()
+
+        case "fedasync" | "fed-async":
+            from flight.strategies.impl.fedasync import FedAsync
+
+            return FedAsync(**kwargs)
+
+        case "fedavg" | "fed-avg":
+            from flight.strategies.impl.fedavg import FedAvg
+
+            return FedAvg(**kwargs)
+
+        case "fedprox" | "fed-prox":
+            from flight.strategies.impl.fedprox import FedProx
+
+            return FedProx(**kwargs)
+
+        case "fedsgd" | "fed-sgd":
+            from flight.strategies.impl.fedsgd import FedSGD
+
+            return FedSGD(**kwargs)
+
+        case _:
+            raise ValueError(f"Strategy '{strategy_name}' is not recognized.")
+
+
+__all__ = [
+    "AggrStrategy",
+    "DefaultStrategy",
+    "Strategy",
+    "CoordStrategy",
+    "TrainerStrategy",
+    "WorkerStrategy",
+    "load_strategy",
+    "Loss",
+    "NodeState",
+    "Params",
+]
diff --git a/flight/strategies/aggr.py b/flight/strategies/aggr.py
index 81c4f28..e59f6fe 100644
--- a/flight/strategies/aggr.py
+++ b/flight/strategies/aggr.py
@@ -1,15 +1,40 @@
+from __future__ import annotations
+
 import typing as t
 
 if t.TYPE_CHECKING:
-    Params: t.TypeAlias = t.Any
+    from flight.federation.topologies.node import NodeID
+    from flight.strategies import NodeState, Params
 
 
+@t.runtime_checkable
 class AggrStrategy(t.Protocol):
+    """Template for all aggregator strategies, including those defined in Flight and those defined by users."""
+
     def start_round(self):
+        """Callback to run at the start of a round."""
         pass
 
-    def aggregate_params(self) -> Params:
+    def aggregate_params(
+        self,
+        state: NodeState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
+        """Callback that handles the model parameter aggregation step.
+
+        Args:
+            state (NodeState): The state of the current aggregator node.
+            children_states (t.Mapping[NodeID, NodeState]): A mapping from each child of the current aggregator node to its respective state.
+            children_state_dicts (t.Mapping[NodeID, Params]): A mapping from each child node to the parameters of its model.
+            **kwargs: Keyword arguments provided by users.
+
+        Returns:
+            Params: The aggregated parameters to update the model at the current aggregator.
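+
+        Example (a sketch; assumes each child's `Params` maps parameter names
+        to `torch.Tensor` values, as with `DefaultAggrStrategy`):
+
+            aggr = DefaultAggrStrategy()
+            avg = aggr.aggregate_params(state, child_states, child_params)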
+        """
         pass
 
     def end_round(self):
+        """Callback to run at the end of a round."""
         pass
diff --git a/flight/strategies/base.py b/flight/strategies/base.py
index f6679ed..e834575 100644
--- a/flight/strategies/base.py
+++ b/flight/strategies/base.py
@@ -1,35 +1,221 @@
-import functools
-import typing as t
-
-import pydantic as pyd
-
-
-@pyd.dataclasses.dataclass(frozen=True, repr=False)
-class Strategy:
-    coord_strategy: str = pyd.field()
-    aggr_strategy: str = pyd.field()
-    worker_strategy: str = pyd.field()
-    trainer_strategy: str = pyd.field()
-
-    def __iter__(self) -> t.Iterator[tuple[str, t.Any]]:
-        yield from (
-            ("coord_strategy", self.coord_strategy),
-            ("aggr_strategy", self.aggr_strategy),
-            ("worker_strategy", self.worker_strategy),
-            ("trainer_strategy", self.trainer_strategy),
-        )
-
-    def __repr__(self) -> str:
-        return str(self)
-
-    @functools.cached_property
-    def __str__(self) -> str:
-        name = self.__class__.__name__
-        inner = ", ".join(
-            [
-                f"{strategy_key}={strategy_value.__class__.__name__}"
-                for (strategy_key, strategy_value) in iter(self)
-                if strategy_value is not None
-            ]
-        )
-        return f"{name}({inner})"
+from __future__ import annotations
+
+import functools
+import typing as t
+
+import pydantic as pyd
+
+from flight.strategies.aggr import AggrStrategy
+from flight.strategies.commons.averaging import average_state_dicts
+from flight.strategies.coord import CoordStrategy
+from flight.strategies.trainer import TrainerStrategy
+from flight.strategies.worker import WorkerStrategy
+
+StrategyType: t.TypeAlias = (
+    WorkerStrategy | AggrStrategy | CoordStrategy | TrainerStrategy
+)
+
+if t.TYPE_CHECKING:
+    import torch
+    from numpy.random import Generator
+
+    from flight.federation.jobs.result import Result
+    from flight.federation.topologies.node import Node, NodeID
+    from flight.strategies import Loss, NodeState, Params
+
+
+class DefaultCoordStrategy:
+    """Default implementation of the strategy for a coordinator."""
+
+    def select_workers(
+        self, state: NodeState, workers: t.Iterable[Node], rng: Generator
+    ) -> t.Sequence[Node]:
+        """Method used for the selection of workers.
+
+        Args:
+            state (NodeState): The state of the coordinator node.
+            workers (t.Iterable[Node]): Iterable object containing all of the worker nodes.
+            rng (Generator): RNG object used for randomness.
+
+        Returns:
+            t.Sequence[Node]: The selected workers.
+        """
+        return list(workers)
+
+
+class DefaultAggrStrategy:
+    """Default implementation of the strategy for an aggregator."""
+
+    def start_round(self):
+        pass
+
+    def aggregate_params(
+        self,
+        state: NodeState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
+        """Callback that handles the model parameter aggregation step.
+
+        Args:
+            state (NodeState): The state of the current aggregator node.
+            children_states (t.Mapping[NodeID, NodeState]): A mapping from each child of the current aggregator node to its respective state.
+            children_state_dicts (t.Mapping[NodeID, Params]): A mapping from each child node to the parameters of its model.
+            **kwargs: Keyword arguments provided by users.
+
+        Returns:
+            Params: The aggregated values.
+        """
+        return average_state_dicts(children_state_dicts, weights=None)
+
+    def end_round(self):
+        pass
+
+
+class DefaultWorkerStrategy:
+    """Default implementation of the strategy for a worker."""
+
+    def start_work(self, state: NodeState) -> NodeState:
+        """Callback to be run at the start of the current worker node's work.
+
+        Args:
+            state (NodeState): The state of the current worker node.
+
+        Returns:
+            NodeState: The state of the current worker node at the end of the callback.
+        """
+        return state
+
+    def before_training(
+        self, state: NodeState, data: Params
+    ) -> tuple[NodeState, Params]:
+        """Callback to be run before training.
+
+        Args:
+            state (NodeState): The state of the current worker node.
+            data (Params): The data associated with the current worker node.
+
+        Returns:
+            tuple[NodeState, Params]: A tuple containing the state and data of the worker node at the end of the callback.
+        """
+        return state, data
+
+    def after_training(
+        self, state: NodeState, optimizer: torch.optim.Optimizer
+    ) -> NodeState:
+        """Callback to be run after training.
+
+        Args:
+            state (NodeState): The state of the current worker node.
+            optimizer (torch.optim.Optimizer): The PyTorch optimizer to be used.
+
+        Returns:
+            NodeState: The state of the worker node at the end of the callback.
+        """
+        return state
+
+    def end_work(self, result: Result) -> Result:
+        """Callback to be run at the end of the work.
+
+        Args:
+            result (Result): A Result object used to represent the result of the local training on the current worker node.
+
+        Returns:
+            Result: The result of the worker node's local training.
+        """
+        return result
+
+
+class DefaultTrainerStrategy:
+    """Default implementation of a strategy for the trainer."""
+
+    def before_backprop(self, state: NodeState, loss: Loss) -> Loss:
+        """Callback to run before backpropagation.
+
+        Args:
+            state (NodeState): State of the current node.
+            loss (Loss): The calculated loss.
+
+        Returns:
+            The loss at the end of the callback.
+        """
+        return loss
+
+    def after_backprop(self, state: NodeState, loss: Loss) -> Loss:
+        """Callback to run after backpropagation.
+
+        Args:
+            state (NodeState): State of the current node.
+            loss (Loss): The calculated loss.
+
+        Returns:
+            The loss at the end of the callback.
+        """
+        return loss
+
+
+# TODO: Remove config when all type definitions have been resolved
+@pyd.dataclasses.dataclass(
+    frozen=True, repr=False, config={"arbitrary_types_allowed": True}
+)
+class Strategy:
+    """
+    A 'Strategy' comprises four strategy implementations, one for each node type,
+    which are used at the respective nodes throughout the training process.
+    """
+
+    coord_strategy: CoordStrategy = pyd.Field()
+    """Implementation of the specific callbacks for the coordinator node."""
+
+    aggr_strategy: AggrStrategy = pyd.Field()
+    """Implementation of the specific callbacks for the aggregator node(s)."""
+
+    worker_strategy: WorkerStrategy = pyd.Field()
+    """Implementation of the specific callbacks for the worker node(s)."""
+
+    trainer_strategy: TrainerStrategy = pyd.Field()
+    """Implementation of callbacks specific to the execution of the training loop on the worker node(s)."""
+
+    def __iter__(self) -> t.Iterator[tuple[str, StrategyType]]:
+        yield from (
+            ("coord_strategy", self.coord_strategy),
+            ("aggr_strategy", self.aggr_strategy),
+            ("worker_strategy", self.worker_strategy),
+            ("trainer_strategy", self.trainer_strategy),
+        )
+
+    def __repr__(self) -> str:
+        return str(self)
+
+    @functools.cached_property
+    def _description(self) -> str:
+        """A utility function for generating the string for `__str__`.
+
+        This is written to avoid the `mypy` issue:
+        "Signature of '__str__' incompatible with supertype 'object'".
+
+        Returns:
+            The string representation of a Strategy instance.
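+
+        For example (a sketch), `str(DefaultStrategy())` yields
+        "DefaultStrategy(coord_strategy=DefaultCoordStrategy,
+        aggr_strategy=DefaultAggrStrategy, worker_strategy=DefaultWorkerStrategy,
+        trainer_strategy=DefaultTrainerStrategy)".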
+        """
+        name = self.__class__.__name__
+        inner = ", ".join(
+            [
+                f"{strategy_key}={strategy_value.__class__.__name__}"
+                for (strategy_key, strategy_value) in iter(self)
+                if strategy_value is not None
+            ]
+        )
+        return f"{name}({inner})"
+
+    def __str__(self) -> str:
+        return self._description
+
+
+class DefaultStrategy(Strategy):
+    """Implementation of a strategy that uses the default strategy types for each node type."""
+
+    def __init__(self) -> None:
+        super().__init__(
+            coord_strategy=DefaultCoordStrategy(),
+            aggr_strategy=DefaultAggrStrategy(),
+            worker_strategy=DefaultWorkerStrategy(),
+            trainer_strategy=DefaultTrainerStrategy(),
+        )
diff --git a/flight/strategies/commons/__init__.py b/flight/strategies/commons/__init__.py
new file mode 100644
index 0000000..6e3e270
--- /dev/null
+++ b/flight/strategies/commons/__init__.py
@@ -0,0 +1,4 @@
+from flight.strategies.commons.averaging import average_state_dicts
+from flight.strategies.commons.worker_selection import random_worker_selection
+
+__all__ = ["average_state_dicts", "random_worker_selection"]
diff --git a/flight/strategies/commons/averaging.py b/flight/strategies/commons/averaging.py
new file mode 100644
index 0000000..eb44c44
--- /dev/null
+++ b/flight/strategies/commons/averaging.py
@@ -0,0 +1,48 @@
+from __future__ import annotations
+
+import typing as t
+
+import numpy
+import torch
+
+if t.TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    NodeID: t.TypeAlias = t.Any
+    Params: t.TypeAlias = t.Any
+
+
+def average_state_dicts(
+    state_dicts: Mapping[NodeID, Params], weights: Mapping[NodeID, float] | None = None
+) -> Params:
+    """Helper function used by aggregator nodes for averaging the given model state dictionaries.
+
+    Args:
+        state_dicts (Mapping[NodeID, Params]): A dictionary that maps each node to its model's state dictionary.
+        weights (Mapping[NodeID, float] | None, optional): Optional dictionary that maps each node to its contribution factor; an unweighted average is computed if None. Defaults to None.
+
+    Returns:
+        Params: The averaged parameters.
+    """
+    num_nodes = len(state_dicts)
+
+    if weights is not None:
+        weight_sum = numpy.sum(list(weights.values()))
+    else:
+        weight_sum = None
+
+    with torch.no_grad():
+        avg_weights = {}
+        for node, state_dict in state_dicts.items():
+            if weights is not None:
+                w = weights[node] / weight_sum
+            else:
+                w = 1 / num_nodes
+            for name, value in state_dict.items():
+                value = w * torch.clone(value)
+                if name not in avg_weights:
+                    avg_weights[name] = value
+                else:
+                    avg_weights[name] += value
+
+    return avg_weights
diff --git a/flight/strategies/commons/worker_selection.py b/flight/strategies/commons/worker_selection.py
new file mode 100644
index 0000000..26027f6
--- /dev/null
+++ b/flight/strategies/commons/worker_selection.py
@@ -0,0 +1,85 @@
+from collections.abc import Iterable
+
+from numpy import array
+from numpy.random import Generator, default_rng
+
+from flight.federation.topologies.node import Node, NodeKind
+
+
+def random_worker_selection(
+    children: Iterable[Node],
+    participation: float = 1.0,
+    probabilistic: bool = False,
+    always_include_child_aggregators: bool = True,
+    rng: Generator | None = None,
+) -> list[Node]:
+    """Entry point for worker selection; dispatches to either probabilistic or fixed selection.
+
+    Args:
+        children (Iterable[Node]): Children to be evaluated for worker selection.
+        participation (float, optional): The fraction of children to select (fixed selection) or the per-child inclusion probability (probabilistic selection). Defaults to 1.0.
+        probabilistic (bool, optional): Whether probabilistic (True) or fixed (False) selection will be used. Defaults to False.
+        always_include_child_aggregators (bool, optional): In probabilistic selection, whether worker nodes are always included. Defaults to True.
+        rng (Generator | None, optional): RNG object used for randomness; numpy.random.default_rng is used if None. Defaults to None.
+
+    Returns:
+        list[Node]: The selected worker nodes.
+    """
+    if rng is None:
+        rng = default_rng()
+    if probabilistic:
+        return prob_random_worker_selection(
+            children, rng, participation, always_include_child_aggregators
+        )
+    return fixed_random_worker_selection(children, rng, participation)
+
+
+def fixed_random_worker_selection(
+    children: Iterable[Node], rng: Generator, participation: float = 1.0
+) -> list[Node]:
+    """The worker selection used when probabilistic is False. This worker selection is entirely random, based on 'rng'.
+
+    Args:
+        children (Iterable[Node]): Children to be evaluated for worker selection.
+        rng (Generator): RNG object used for randomness.
+        participation (float, optional): The fraction of children to select. Defaults to 1.0.
+
+    Returns:
+        list[Node]: The selected worker nodes.
+    """
+    # Materialize `children` so iterator inputs are handled and `len()` is defined.
+    children = array(list(children))
+    num_selected = max(1, int(participation * len(children)))
+    selected_children = rng.choice(children, size=num_selected, replace=False)
+    return list(selected_children)
+
+
+def prob_random_worker_selection(
+    children: Iterable[Node],
+    rng: Generator,
+    participation: float = 1.0,
+    always_include_child_aggregators: bool = True,
+) -> list[Node]:
+    """The worker selection used when probabilistic is True. Each child is included with probability 'participation'.
+
+    Args:
+        children (Iterable[Node]): Children to be evaluated for worker selection.
+        rng (Generator): RNG object used for randomness.
+        participation (float, optional): The probability of including each node. Defaults to 1.0.
+        always_include_child_aggregators (bool, optional): Whether worker nodes are always included. Defaults to True.
+
+    Returns:
+        list[Node]: The selected worker nodes.
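+
+        Example (a sketch): with `participation=0.5`, each non-worker child is
+        kept with probability 0.5, worker children are always kept when
+        `always_include_child_aggregators=True`, and one random child is drawn
+        if nothing was selected.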
+    """
+    children = list(children)  # Materialize so `children` can be traversed twice.
+    selected_children = []
+    for child in children:
+        if child.kind is NodeKind.WORKER and always_include_child_aggregators:
+            selected_children.append(child)
+        elif rng.uniform() <= participation:
+            selected_children.append(child)
+
+    if len(selected_children) == 0:
+        random_child = rng.choice(array(children))
+        selected_children.append(random_child)
+
+    return selected_children
diff --git a/flight/strategies/coord.py b/flight/strategies/coord.py
index 7314e3d..3ff8ddc 100644
--- a/flight/strategies/coord.py
+++ b/flight/strategies/coord.py
@@ -5,13 +5,24 @@
 if t.TYPE_CHECKING:
     from numpy.random import Generator
 
-    from ..federation.topologies.node import Node
-
-    CoordState: t.TypeAlias = t.Any
+    from flight.federation.topologies.node import Node
+    from flight.strategies import NodeState
 
 
+@t.runtime_checkable
 class CoordStrategy(t.Protocol):
+    """Template for all coordinator strategies, including those defined in Flight and those defined by users."""
+
     def select_workers(
-        self, state: CoordState, workers: t.Iterable[Node], rng: Generator
+        self, state: NodeState, workers: t.Iterable[Node], rng: Generator
     ) -> t.Sequence[Node]:
+        """Callback that is responsible for selecting a subset of worker nodes to do local training.
+
+        Args:
+            state (NodeState): The state of the current coordinator node.
+            workers (t.Iterable[Node]): The worker nodes in the topology.
+            rng (Generator): The RNG used for reproducibility.
+
+        Returns:
+            The selected worker nodes."""
         pass
diff --git a/flight/strategies/impl/__init__.py b/flight/strategies/impl/__init__.py
index e69de29..28bbd08 100644
--- a/flight/strategies/impl/__init__.py
+++ b/flight/strategies/impl/__init__.py
@@ -0,0 +1,6 @@
+from flight.strategies.impl.fedasync import FedAsync
+from flight.strategies.impl.fedavg import FedAvg
+from flight.strategies.impl.fedprox import FedProx
+from flight.strategies.impl.fedsgd import FedSGD
+
+__all__ = ["FedAsync", "FedAvg", "FedProx", "FedSGD"]
diff --git a/flight/strategies/impl/fedasync.py b/flight/strategies/impl/fedasync.py
new file mode 100644
index 0000000..ef01ee4
--- /dev/null
+++ b/flight/strategies/impl/fedasync.py
@@ -0,0 +1,81 @@
+from __future__ import annotations
+
+import typing as t
+from collections import OrderedDict
+
+from flight.strategies.base import (
+    DefaultAggrStrategy,
+    DefaultCoordStrategy,
+    DefaultTrainerStrategy,
+    DefaultWorkerStrategy,
+    Strategy,
+)
+
+if t.TYPE_CHECKING:
+    from flight.federation.topologies.node import NodeID
+    from flight.strategies import NodeState, Params
+
+
+class FedAsyncAggr(DefaultAggrStrategy):
+    """The aggregator for 'FedAsync' and its respective methods.
+
+    Extends 'DefaultAggrStrategy', which provides the remaining callbacks.
+    """
+
+    def __init__(self, alpha: float = 0.5):
+        assert 0.0 < alpha <= 1.0
+        self.alpha = alpha
+
+    def aggregate_params(
+        self,
+        state: NodeState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
+        """Method used by aggregator nodes for aggregating the given node state dictionaries.
+
+        Args:
+            state (NodeState): State of the current aggregator node.
+            children_states (t.Mapping[NodeID, NodeState]): Dictionary of the states of the children.
+            children_state_dicts (t.Mapping[NodeID, Params]): Dictionary mapping each child to its values.
+            **kwargs: Keyword arguments provided by the user.
+
+        Returns:
+            Params: The aggregated values.
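+
+        The merge follows the FedAsync rule, applied parameter-wise with the
+        child named by `last_updated_node` in `kwargs`:
+
+            w_new = (1 - alpha) * w_global + alpha * w_child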
+        """
+        last_updated = kwargs.get("last_updated_node", None)
+        assert last_updated is not None
+        assert isinstance(last_updated, int | str)
+
+        global_model_params = state.global_model.state_dict()
+        last_updated_params = children_state_dicts[last_updated]
+
+        aggr_params = []
+        for param in global_model_params:
+            w0, w = (
+                global_model_params[param].detach(),
+                last_updated_params[param].detach(),
+            )
+            aggr_w = w0 * (1 - self.alpha) + w * self.alpha
+            aggr_params.append((param, aggr_w))
+
+        return OrderedDict(aggr_params)
+
+
+class FedAsync(Strategy):
+    """Implementation of the FedAsync strategy, which uses default strategies for the
+    coordinator, workers, and trainer, and 'FedAsyncAggr' for the aggregators.
+    """
+
+    def __init__(self, alpha: float):
+        super().__init__(
+            aggr_strategy=FedAsyncAggr(alpha),
+            coord_strategy=DefaultCoordStrategy(),
+            worker_strategy=DefaultWorkerStrategy(),
+            trainer_strategy=DefaultTrainerStrategy(),
+        )
diff --git a/flight/strategies/impl/fedavg.py b/flight/strategies/impl/fedavg.py
new file mode 100644
index 0000000..037ec27
--- /dev/null
+++ b/flight/strategies/impl/fedavg.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+import typing as t
+
+from flight.strategies.base import (
+    DefaultAggrStrategy,
+    DefaultTrainerStrategy,
+    DefaultWorkerStrategy,
+    Strategy,
+)
+from flight.strategies.commons import average_state_dicts
+
+from .fedsgd import FedSGDCoord
+
+if t.TYPE_CHECKING:
+    from flight.federation.topologies.node import NodeID
+    from flight.strategies import NodeState, Params
+
+
+class FedAvgAggr(DefaultAggrStrategy):
+    """The aggregator for the FedAvg algorithm and its respective methods."""
+
+    def aggregate_params(
+        self,
+        state: NodeState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
+        """Method used by aggregator nodes for aggregating the given node state dictionaries.
+
+        Args:
+            state (NodeState): State of the current aggregator node.
+            children_states (t.Mapping[NodeID, NodeState]): Dictionary of the states of the children.
+            children_state_dicts (t.Mapping[NodeID, Params]): Dictionary mapping each child to its values.
+            **kwargs: Keyword arguments provided by the user.
+
+        Returns:
+            Params: The aggregated values.
+        """
+        weights = {}
+        for node, child_state in children_states.items():
+            weights[node] = child_state["num_data_samples"]
+
+        state["num_data_samples"] = sum(weights.values())
+
+        return average_state_dicts(children_state_dicts, weights=weights)
+
+
+class FedAvgWorker(DefaultWorkerStrategy):
+    """The worker for 'FedAvg' and its respective methods."""
+
+    def before_training(
+        self, state: NodeState, data: Params
+    ) -> tuple[NodeState, Params]:
+        """Callback to run before the current node's training.
+
+        Args:
+            state (NodeState): State of the current worker node.
+            data (Params): The data related to the current worker node.
+
+        Returns:
+            tuple[NodeState, Params]: A tuple containing the updated state of the worker node and the data.
+        """
+        state["num_data_samples"] = len(data)
+        return state, data
+
+
+class FedAvg(Strategy):
+    """
+    Implementation of the FedAvg strategy, which uses default strategies for the trainer,
+    'FedAvg' for the aggregators and workers, and 'FedSGD' for the coordinator.
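+
+    Aggregation is weighted by each child's `num_data_samples`, which the
+    'FedAvg' worker records in its state before training. For example (a
+    sketch):
+
+        strategy = FedAvg(participation=0.5, probabilistic=True)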
+    """
+
+    def __init__(
+        self,
+        participation: float = 1.0,
+        probabilistic: bool = False,
+        always_include_child_aggregators: bool = False,
+    ):
+        super().__init__(
+            coord_strategy=FedSGDCoord(
+                participation, probabilistic, always_include_child_aggregators
+            ),
+            aggr_strategy=FedAvgAggr(),
+            worker_strategy=FedAvgWorker(),
+            trainer_strategy=DefaultTrainerStrategy(),
+        )
diff --git a/flight/strategies/impl/fedprox.py b/flight/strategies/impl/fedprox.py
new file mode 100644
index 0000000..04b1d8b
--- /dev/null
+++ b/flight/strategies/impl/fedprox.py
@@ -0,0 +1,88 @@
+from __future__ import annotations
+
+import typing as t
+
+import torch
+
+from flight.strategies.base import DefaultTrainerStrategy, Strategy
+
+from .fedavg import FedAvgWorker
+from .fedsgd import FedSGDAggr, FedSGDCoord
+
+if t.TYPE_CHECKING:
+    from flight.strategies import Loss, NodeState
+
+DEVICE = "cpu"
+
+
+class FedProxTrainer(DefaultTrainerStrategy):
+    """The trainer and its respective methods for 'FedProx'.
+
+    Extends 'DefaultTrainerStrategy', which provides the remaining callbacks.
+    """
+
+    def __init__(self, mu: float = 0.3):
+        self.mu = mu
+
+    def before_backprop(self, state: NodeState, loss: Loss) -> Loss:
+        """Callback to run before backpropagation.
+
+        Args:
+            state (NodeState): The state of the current node.
+            loss (Loss): The calculated loss associated with the current node.
+
+        Returns:
+            Loss: The updated loss associated with the current node.
+        """
+        global_model = state.global_model
+        local_model = state.local_model
+        assert global_model is not None
+        assert local_model is not None
+
+        global_model = global_model.to(DEVICE)
+        local_model = local_model.to(DEVICE)
+
+        params = list(local_model.state_dict().values())
+        params0 = list(global_model.state_dict().values())
+
+        proximal_diff = torch.tensor(
+            [
+                torch.sum(torch.pow(params[i] - params0[i], 2))
+                for i in range(len(params))
+            ]
+        )
+        proximal_term = torch.sum(proximal_diff)
+        proximal_term = proximal_term * self.mu / 2
+
+        proximal_term = proximal_term.to(DEVICE)
+
+        loss += proximal_term
+        return loss
+
+
+class FedProx(Strategy):
+    """Implementation of the FedProx strategy, which uses 'FedAvg' for the workers,
+    'FedSGD' for the coordinator and aggregators, and 'FedProx' for the trainer.
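+
+    The trainer adds the proximal term to the local loss before backpropagation:
+
+        loss <- loss + (mu / 2) * sum_i ||w_i - w0_i||^2
+
+    where `w` are the local model's parameters and `w0` the global model's.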
+    """
+
+    def __init__(
+        self,
+        mu: float = 0.3,
+        participation: float = 1.0,
+        probabilistic: bool = False,
+        always_include_child_aggregators: bool = False,
+    ):
+        super().__init__(
+            coord_strategy=FedSGDCoord(
+                participation,
+                probabilistic,
+                always_include_child_aggregators,
+            ),
+            aggr_strategy=FedSGDAggr(),
+            worker_strategy=FedAvgWorker(),
+            trainer_strategy=FedProxTrainer(mu=mu),
+        )
diff --git a/flight/strategies/impl/fedsgd.py b/flight/strategies/impl/fedsgd.py
new file mode 100644
index 0000000..6ad184b
--- /dev/null
+++ b/flight/strategies/impl/fedsgd.py
@@ -0,0 +1,106 @@
+from __future__ import annotations
+
+import typing as t
+
+from numpy.random import Generator
+
+from flight.strategies import Strategy
+from flight.strategies.base import (
+    DefaultAggrStrategy,
+    DefaultCoordStrategy,
+    DefaultTrainerStrategy,
+    DefaultWorkerStrategy,
+)
+from flight.strategies.commons import average_state_dicts, random_worker_selection
+
+if t.TYPE_CHECKING:
+    from flight.federation.topologies.node import Node, NodeID
+    from flight.strategies import NodeState, Params
+
+
+class FedSGDCoord(DefaultCoordStrategy):
+    """The coordinator and its respective methods for 'FedSGD'."""
+
+    def __init__(
+        self,
+        participation: float,
+        probabilistic: bool,
+        always_include_child_aggregators: bool,
+    ):
+        self.participation = participation
+        self.probabilistic = probabilistic
+        self.always_include_child_aggregators = always_include_child_aggregators
+
+    def select_workers(
+        self, state: NodeState, workers: t.Iterable[Node], rng: Generator | None = None
+    ) -> t.Sequence[Node]:
+        """Worker selection method for 'FedSGD', overriding the 'CoordStrategy' protocol.
+
+        Args:
+            state (NodeState): State of the coordinator node.
+            workers (t.Iterable[Node]): Iterable containing the worker nodes.
+            rng (Generator | None, optional): RNG object used for randomness. Defaults to None.
+
+        Returns:
+            t.Sequence[Node]: The selected worker nodes.
+        """
+        selected_workers = random_worker_selection(
+            workers,
+            participation=self.participation,
+            probabilistic=self.probabilistic,
+            always_include_child_aggregators=self.always_include_child_aggregators,
+            rng=rng,
+        )
+        return selected_workers
+
+
+class FedSGDAggr(DefaultAggrStrategy):
+    """The aggregator and its respective methods for 'FedSGD'.
+
+    Extends 'DefaultAggrStrategy', which provides the remaining callbacks.
+    """
+
+    def aggregate_params(
+        self,
+        state: NodeState,
+        children_states: t.Mapping[NodeID, NodeState],
+        children_state_dicts: t.Mapping[NodeID, Params],
+        **kwargs,
+    ) -> Params:
+        """Method used by aggregator nodes for aggregating the given node state dictionaries.
+
+        Args:
+            state (NodeState): State of the current aggregator node.
+            children_states (t.Mapping[NodeID, NodeState]): Dictionary of the states of the children.
+            children_state_dicts (t.Mapping[NodeID, Params]): Dictionary mapping each child to its values.
+            **kwargs: Keyword arguments provided by the user.
+
+        Returns:
+            Params: The aggregated values.
+        """
+        return average_state_dicts(children_state_dicts, weights=None)
+
+
+class FedSGD(Strategy):
+    """
+    Implementation of the FedSGD strategy, which uses 'FedSGD' for the coordinator and
+    aggregators, and defaults for the workers and trainer.
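+
+    Unlike 'FedAvg', the aggregator performs an unweighted average of the
+    children's parameters (it calls `average_state_dicts` with `weights=None`).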
+    """
+
+    def __init__(
+        self,
+        participation: float = 1.0,
+        probabilistic: bool = False,
+        always_include_child_aggregators: bool = True,
+    ):
+        super().__init__(
+            coord_strategy=FedSGDCoord(
+                participation,
+                probabilistic,
+                always_include_child_aggregators,
+            ),
+            aggr_strategy=FedSGDAggr(),
+            worker_strategy=DefaultWorkerStrategy(),
+            trainer_strategy=DefaultTrainerStrategy(),
+        )
diff --git a/flight/strategies/trainer.py b/flight/strategies/trainer.py
index c0fd3cf..236b571 100644
--- a/flight/strategies/trainer.py
+++ b/flight/strategies/trainer.py
@@ -3,14 +3,34 @@
 import typing as t
 
 if t.TYPE_CHECKING:
-    import torch
-
-    Loss: t.TypeAlias = torch.Tensor
+    from flight.strategies import Loss, NodeState
 
 
+@t.runtime_checkable
 class TrainerStrategy(t.Protocol):
-    def before_backprop(self, state, loss: Loss) -> Loss:
+    """Template for all trainer strategies, including those defined in Flight and those defined by users."""
+
+    def before_backprop(self, state: NodeState, loss: Loss) -> Loss:
+        """Callback to run before backpropagation.
+
+        Args:
+            state (NodeState): State of the current node.
+            loss (Loss): The calculated loss.
+
+        Returns:
+            The loss at the end of the callback.
+        """
         pass
 
-    def after_backprop(self, state, loss: Loss) -> Loss:
+    def after_backprop(self, state: NodeState, loss: Loss) -> Loss:
+        """Callback to run after backpropagation.
+
+        Args:
+            state (NodeState): State of the current node.
+            loss (Loss): The calculated loss.
+
+        Returns:
+            The loss at the end of the callback.
+        """
         pass
diff --git a/flight/strategies/worker.py b/flight/strategies/worker.py
index c3b6853..4d0f2ac 100644
--- a/flight/strategies/worker.py
+++ b/flight/strategies/worker.py
@@ -3,18 +3,61 @@
 import typing as t
 
 if t.TYPE_CHECKING:
-    pass
+    import torch
+
+    from flight.federation.jobs.result import Result
+    from flight.strategies import NodeState
 
 
+@t.runtime_checkable
 class WorkerStrategy(t.Protocol):
-    def start_work(self):
+    """Template for all worker strategies, including those defined in Flight and those defined by users."""
+
+    def start_work(self, state: NodeState) -> NodeState:
+        """Callback to run at the start of the current node's 'work'.
+
+        Args:
+            state (NodeState): State of the current worker node.
+
+        Returns:
+            The state of the current node at the end of the callback.
+        """
         pass
 
-    def before_training(self):
+    def before_training(self, state: NodeState, data: t.Any) -> tuple[NodeState, t.Any]:
+        """Callback to run before the current node's training.
+
+        Args:
+            state (NodeState): State of the current worker node.
+            data (t.Any): The data related to the current worker node.
+
+        Returns:
+            A tuple containing the state and data of the current worker node after the callback.
+        """
         pass
 
-    def after_training(self):
+    def after_training(
+        self, state: NodeState, optimizer: torch.optim.Optimizer
+    ) -> NodeState:
+        """Callback to run after the current node's training.
+
+        Args:
+            state (NodeState): State of the current worker node.
+            optimizer (torch.optim.Optimizer): The PyTorch optimizer used.
+
+        Returns:
+            The state of the current worker node after the callback.
+        """
         pass
 
-    def end_work(self):
+    def end_work(self, result: Result) -> Result:
+        """Callback to run at the end of the current worker node's 'work'.
+
+        Args:
+            result (Result): A Result object used to represent the result of the local training on the current worker node.
+
+        Returns:
+            The result of the worker node's local training.
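+
+        Note: over a single round, a worker's callbacks are expected to fire in
+        the order `start_work` -> `before_training` -> (local training) ->
+        `after_training` -> `end_work`.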
+
+        """
         pass
diff --git a/tests/strategies/__init__.py b/tests/strategies/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/strategies/environment.py b/tests/strategies/environment.py
new file mode 100644
index 0000000..d91a000
--- /dev/null
+++ b/tests/strategies/environment.py
@@ -0,0 +1,18 @@
+from flight.federation.topologies.node import Node, NodeKind
+
+
+def create_children(num_workers: int, num_aggrs: int = 0) -> list[Node]:
+    """Creates a fabricated list of children used for coordinator/worker-selection test cases.
+
+    Args:
+        num_workers (int): Number of workers to be added.
+        num_aggrs (int, optional): Number of aggregators to be added. Defaults to 0.
+
+    Returns:
+        list[Node]: A list of the created children.
+    """
+    aggr = [Node(idx=i, kind=NodeKind.AGGR) for i in range(1, num_aggrs + 1)]
+    workers = [
+        Node(idx=i + num_aggrs, kind=NodeKind.WORKER) for i in range(1, num_workers + 1)
+    ]
+    return workers + aggr
diff --git a/tests/strategies/impl/__init__.py b/tests/strategies/impl/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/strategies/impl/test_fedasync.py b/tests/strategies/impl/test_fedasync.py
new file mode 100644
index 0000000..2173843
--- /dev/null
+++ b/tests/strategies/impl/test_fedasync.py
@@ -0,0 +1,67 @@
+import typing as t
+
+import pytest
+import torch
+
+from flight.strategies import (
+    AggrStrategy,
+    CoordStrategy,
+    DefaultStrategy,
+    NodeState,
+    TrainerStrategy,
+    WorkerStrategy,
+)
+from flight.strategies.base import (
+    DefaultCoordStrategy,
+    DefaultTrainerStrategy,
+    DefaultWorkerStrategy,
+)
+from flight.strategies.impl.fedasync import FedAsync, FedAsyncAggr
+
+
+class TestValidFedAsync:
+    def test_class_hierarchy(self):
+        """Test that the associated node strategy types follow the correct protocols."""
+        fedasync = FedAsync(0.5)
+
+        assert isinstance(fedasync.aggr_strategy, (AggrStrategy, FedAsyncAggr))
+        assert isinstance(
+            fedasync.coord_strategy, (CoordStrategy, DefaultCoordStrategy)
+        )
+        assert isinstance(
+            fedasync.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy)
+        )
+        assert isinstance(
+            fedasync.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy)
+        )
+
+    def test_fedasync_aggr(self):
+        """Tests implementation of the aggregator within 'FedAsync'."""
+        strategy = FedAsync(0.5)
+        aggr_strategy: AggrStrategy = strategy.aggr_strategy
+
+        nodestate: NodeState = {}
+        childstates = {1: "foo1", 2: "foo2"}
+        children_state_dicts_pt = {
+            1: {
+                "train/loss": torch.tensor(2.3, dtype=torch.float32),
+                "train/acc": torch.tensor(1.2, dtype=torch.float32),
+            },
+            2: {
+                "train/loss": torch.tensor(3.1, dtype=torch.float32),
+                "train/acc": torch.tensor(1.4, dtype=torch.float32),
+            },
+        }
+
+        # TODO: Enable once the test can build a `NodeState` whose
+        #       `global_model` is set, which `FedAsyncAggr` requires:
+        # avg = aggr_strategy.aggregate_params(
+        #     nodestate, childstates, children_state_dicts_pt, last_updated_node=1
+        # )
+
+        pytest.skip("FedAsync aggregation test is not yet implemented.")
+
+
+class TestInvalidFedAsync:
+    def test_invalid_alpha(self):
+        """Test inputting an alpha value that is too large."""
+        with pytest.raises(AssertionError):
+            fedasync = FedAsync(alpha=1.1)
diff --git a/tests/strategies/impl/test_fedavg.py b/tests/strategies/impl/test_fedavg.py
new file mode 100644
index 0000000..b3c41a3
--- /dev/null
+++ b/tests/strategies/impl/test_fedavg.py
@@ -0,0 +1,86 @@
+import typing as t
+
+import torch
+
+from flight.strategies import (
+    AggrStrategy,
+    CoordStrategy,
+    Params,
+    TrainerStrategy,
+    WorkerStrategy,
+)
+from flight.strategies.base import DefaultTrainerStrategy
+from flight.strategies.impl.fedavg import FedAvg, FedAvgAggr, FedAvgWorker
+from flight.strategies.impl.fedsgd import FedSGDCoord
+
+if t.TYPE_CHECKING:
+    NodeState: t.TypeAlias = t.Any
+
+
+class TestValidFedAvg:
+    def test_fedavg_class_hierarchy(self):
+        """Test that the associated node strategy types follow the correct protocols."""
+        fedavg = FedAvg()
+
+        assert isinstance(fedavg.aggr_strategy, (AggrStrategy, FedAvgAggr))
+        assert isinstance(fedavg.coord_strategy, (CoordStrategy, FedSGDCoord))
+        assert isinstance(
+            fedavg.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy)
+        )
+        assert isinstance(fedavg.worker_strategy, (WorkerStrategy, FedAvgWorker))
+
+    def test_fedavg_aggr(self):
+        """Tests the usability of the aggregator strategy for 'FedAvg'."""
+        fedavg = FedAvg()
+        aggregator_strategy: AggrStrategy = fedavg.aggr_strategy
+        nodestate: NodeState = {}
+        childstates = {
+            1: {"num_data_samples": 1, "other_data": "foo"},
+            2: {"num_data_samples": 1, "other_data": "foo"},
+        }
+        children_state_dicts_pt = {
+            1: {
+                "train/loss": torch.tensor(2.3, dtype=torch.float32),
+                "train/acc": torch.tensor(1.2, dtype=torch.float32),
+            },
+            2: {
+                "train/loss": torch.tensor(3.1, dtype=torch.float32),
+                "train/acc": torch.tensor(1.4, dtype=torch.float32),
+            },
+        }
+
+        aggregated = aggregator_strategy.aggregate_params(
+            nodestate, childstates, children_state_dicts_pt
+        )
+
+        assert isinstance(aggregated, dict)
+
+        expected_avg = {
+            "train/loss": 2.7,
+            "train/acc": 1.3,
+        }
+
+        epsilon = 1e-6
+        for key, value in aggregated.items():
+            expected = expected_avg[key]
+            assert abs(expected - value.item()) < epsilon
+
+    def test_fedavg_worker(self):
+        """Tests the usability of the worker strategy for 'FedAvg'."""
+        fedavg = FedAvg()
+
+        worker_strategy: WorkerStrategy = fedavg.worker_strategy
+
+        nodestate_before: NodeState = {"State:": "Training preparation"}
+        data_before: Params = {
+            "train/loss1": torch.tensor(0.35, dtype=torch.float32),
+            "train/loss2": torch.tensor(0.5, dtype=torch.float32),
+            "train/loss3": torch.tensor(0.23, dtype=torch.float32),
+        }
+
+        nodestate_after, data_after = worker_strategy.before_training(
+            nodestate_before, data_before
+        )
+
+        assert nodestate_after["num_data_samples"] == len(data_before)
+        assert data_before == data_after
diff --git a/tests/strategies/impl/test_fedprox.py b/tests/strategies/impl/test_fedprox.py
new file mode 100644
index 0000000..7b92e3d
--- /dev/null
+++ b/tests/strategies/impl/test_fedprox.py
@@ -0,0 +1,19 @@
+from flight.strategies import (
+    AggrStrategy,
+    CoordStrategy,
+    TrainerStrategy,
+    WorkerStrategy,
+)
+from flight.strategies.impl.fedavg import FedAvgWorker
+from flight.strategies.impl.fedprox import FedProx, FedProxTrainer
+from flight.strategies.impl.fedsgd import FedSGDAggr, FedSGDCoord
+
+
+class TestValidFedProx:
+    def test_fedprox_class_hierarchy(self):
+        """Test that the associated node strategy types follow the correct protocols."""
+        fedprox = FedProx(0.3, 1, False, False)
+        assert isinstance(fedprox.aggr_strategy, (AggrStrategy, FedSGDAggr))
+        assert isinstance(fedprox.coord_strategy, (CoordStrategy, FedSGDCoord))
+        assert isinstance(fedprox.trainer_strategy, (TrainerStrategy, FedProxTrainer))
+        assert isinstance(fedprox.worker_strategy, (WorkerStrategy, FedAvgWorker))
diff --git a/tests/strategies/impl/test_fedsgd.py b/tests/strategies/impl/test_fedsgd.py
new file mode 100644
index 0000000..a42e209
--- /dev/null
+++ b/tests/strategies/impl/test_fedsgd.py
@@ -0,0 +1,72 @@
+import torch
+from numpy.random import default_rng
+
+from flight.strategies import (
+    AggrStrategy,
+    CoordStrategy,
+    TrainerStrategy,
+    WorkerStrategy,
+)
+from flight.strategies.base import DefaultTrainerStrategy, DefaultWorkerStrategy
+from flight.strategies.impl.fedsgd import FedSGD, FedSGDAggr, FedSGDCoord
+from tests.strategies.environment import create_children
+
+
+class TestValidFedSGD:
+    def test_fedsgd_class_hierarchy(self):
+        """Test that the associated node strategy types follow the correct protocols."""
+        fedsgd = FedSGD(1, False, True)
+
+        assert isinstance(fedsgd.aggr_strategy, (AggrStrategy, FedSGDAggr))
+        assert isinstance(fedsgd.coord_strategy, (CoordStrategy, FedSGDCoord))
+        assert isinstance(
+            fedsgd.trainer_strategy, (TrainerStrategy, DefaultTrainerStrategy)
+        )
+        assert isinstance(
+            fedsgd.worker_strategy, (WorkerStrategy, DefaultWorkerStrategy)
+        )
+
+    def test_default_fedsgd_coord(self):
+        """Tests the usability of the coordinator strategy for 'FedSGD'."""
+        fedsgd = FedSGD(1, False, True)
+        coord_strategy: CoordStrategy = fedsgd.coord_strategy
+        gen = default_rng()
+        workers = create_children(num_workers=10)
+
+        selected = coord_strategy.select_workers("foo", workers, gen)
+
+        for worker in workers:
+            assert worker in selected
+
+    def test_fedsgd_aggr(self):
+        """Tests the usability of the aggregator strategy for 'FedSGD'."""
+        fedsgd = FedSGD(1, False, True)
+        aggr_strategy: AggrStrategy = fedsgd.aggr_strategy
+
+        state = "foo"
+        children = {1: "foo1", 2: "foo2"}
+
+        children_state_dicts_pt = {
+            1: {
+                "train/loss": torch.tensor(2.3, dtype=torch.float32),
+                "train/acc": torch.tensor(1.2, dtype=torch.float32),
+            },
+            2: {
+                "train/loss": torch.tensor(3.1, dtype=torch.float32),
+                "train/acc": torch.tensor(1.4, dtype=torch.float32),
+            },
+        }
+
+        avg = aggr_strategy.aggregate_params(state, children, children_state_dicts_pt)
+
+        assert isinstance(avg, dict)
+
+        expected_avg = {
+            "train/loss": 2.7,
+            "train/acc": 1.3,
+        }
+
+        epsilon = 1e-6
+        for key, value in avg.items():
+            expected = expected_avg[key]
+            assert abs(expected - value.item()) < epsilon
diff --git a/tests/strategies/test_aggr.py b/tests/strategies/test_aggr.py
new file mode 100644
index 0000000..7627a4a
--- /dev/null
+++ b/tests/strategies/test_aggr.py
@@ -0,0 +1,46 @@
+import typing as t
+
+import torch
+
+from flight.strategies import AggrStrategy, NodeState
+from flight.strategies.base import DefaultAggrStrategy
+
+
+def test_instance():
+    """Test that the associated node strategy type follows the correct protocols."""
+    default_aggr = DefaultAggrStrategy()
+
+    assert isinstance(default_aggr, AggrStrategy)
+
+
+def test_aggr_aggregate_params():
+    """Tests usability of the 'aggregate_params' function on two children."""
+    default_aggr = DefaultAggrStrategy()
+
+    state: NodeState = "foo"
+    children = {1: "foo1", 2: "foo2"}
+
+    children_state_dicts_pt = {
+        1: {
+            "train/loss": torch.tensor(2.3, dtype=torch.float32),
+            "train/acc": torch.tensor(1.2, dtype=torch.float32),
+        },
+        2: {
+            "train/loss": torch.tensor(3.1, dtype=torch.float32),
+            "train/acc": torch.tensor(1.4, dtype=torch.float32),
+        },
+    }
+
+    avg = default_aggr.aggregate_params(state, children, children_state_dicts_pt)
+
+    assert isinstance(avg, dict)
+
+    expected_avg = {
+        "train/loss": 2.7,
+        "train/acc": 1.3,
+    }
+
+    epsilon = 1e-6
+    for key, value in avg.items():
+        expected = expected_avg[key]
+        assert abs(expected - value.item()) < epsilon
diff --git a/tests/strategies/test_coord.py b/tests/strategies/test_coord.py
new file mode 100644
index 0000000..9c02591
--- /dev/null
+++ b/tests/strategies/test_coord.py
@@ -0,0 +1,72 @@
+import pytest
+from numpy.random import default_rng
+
+from flight.strategies import CoordStrategy
+from flight.strategies.base import DefaultCoordStrategy
+from flight.strategies.commons.worker_selection import random_worker_selection
+from tests.strategies.environment import create_children
+
+
+def test_instance():
+    """Test that the associated node strategy type follows the correct protocols."""
+    default_coord = DefaultCoordStrategy()
+
+    assert isinstance(default_coord, CoordStrategy)
+
+
+def test_worker_selection():
+    """Tests both fixed and probabilistic worker selection on five workers."""
+    gen = default_rng()
+    for _ in range(5):
+        children = create_children(num_workers=5)
+        # fixed random
+        fixed_random = random_worker_selection(
+            children,
+            participation=1,
+            probabilistic=False,
+            always_include_child_aggregators=True,
+            rng=gen,
+        )
+        # prob random
+        prob_random = random_worker_selection(
+            children,
+            participation=1,
+            probabilistic=True,
+            always_include_child_aggregators=True,
+            rng=gen,
+        )
+
+        for child in children:
+            assert child in fixed_random and child in prob_random
+
+
+class TestInvalidFixedSelection:
+    def test_fixed_random(self):
+        """Tests that an invalid participation level for fixed selection raises a 'ValueError'."""
+        gen = default_rng()
+
+        children = create_children(num_workers=5)
+
+        with pytest.raises(ValueError):
+            random_worker_selection(
+                children,
+                participation=2,
+                probabilistic=False,
+                always_include_child_aggregators=True,
+                rng=gen,
+            )
+
+    def test_fixed_random_mix(self):
+        """Tests that an invalid participation level for fixed selection raises a 'ValueError' when the children include both workers and aggregators."""
+        gen = default_rng()
+
+        children = create_children(num_workers=1, num_aggrs=2)
+
+        with pytest.raises(ValueError):
+            random_worker_selection(
+                children,
+                participation=2,
+                probabilistic=False,
+                always_include_child_aggregators=True,
+                rng=gen,
+            )
diff --git a/tests/strategies/test_worker.py b/tests/strategies/test_worker.py
new file mode 100644
index 0000000..8c31ce1
--- /dev/null
+++ b/tests/strategies/test_worker.py
@@ -0,0 +1,17 @@
+import typing as t
+
+import torch
+
+from flight.federation.jobs.result import Result
+from flight.strategies import WorkerStrategy
+from flight.strategies.base import DefaultWorkerStrategy
+
+if t.TYPE_CHECKING:
+    NodeState: t.TypeAlias = t.Any
+
+
+def test_instance():
+    """Test that the associated node strategy type follows the correct protocols."""
+    default_worker = DefaultWorkerStrategy()
+
+    assert isinstance(default_worker, WorkerStrategy)