Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flight refactor strategies #40

Merged
merged 14 commits into from
Jul 19, 2024
10 changes: 5 additions & 5 deletions flight/federation/jobs/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from pydantic.dataclasses import dataclass

from flight.learning.module import RecordList
from flight.strategies.aggr import Params
from flight.strategies import Params

if t.TYPE_CHECKING:
NodeID: t.TypeAlias = t.Hashable
NodeState: t.TypeAlias = tuple
NodeID: t.TypeAlias = t.Hashable
NodeState: t.TypeAlias = tuple


@dataclass
# TODO: Remove config when all type definitions have been resolved
@dataclass(config={"arbitrary_types_allowed": True})
class Result:
state: NodeState
node_idx: NodeID
Expand Down
1 change: 1 addition & 0 deletions flight/federation/topologies/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
wish to use `from_yaml()` to create a Topology with a YAML file, we encourage the use of the `Topology.from_yaml()` method
instead.
"""

from __future__ import annotations

import json
Expand Down
62 changes: 62 additions & 0 deletions flight/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import typing as t

import torch

from flight.strategies.aggr import AggrStrategy
from flight.strategies.base import DefaultStrategy, Strategy
from flight.strategies.coord import CoordStrategy
from flight.strategies.trainer import TrainerStrategy
from flight.strategies.worker import WorkerStrategy

Loss: t.TypeAlias = torch.Tensor
Params: t.TypeAlias = dict[str, torch.Tensor]
NodeState: t.TypeAlias = t.Any


def load_strategy(strategy_name: str, **kwargs) -> Strategy:
"""Function used to grab the users preferred 'Strategy'.

Args:
strategy_name (str): The name of the 'Strategy' to be grabbed.

Raises:
ValueError: If an unknown 'Strategy' type is passed through.

Returns:
Strategy: The selected 'Strategy' type.
"""
assert isinstance(strategy_name, str), "`strategy_name` must be a string."
match strategy_name.lower():
case "default":
return DefaultStrategy()

case "fedasync" | "fed-async":
from flight.strategies.impl.fedasync import FedAsync

return FedAsync(**kwargs)

case "fedavg" | "fed-avg":
from flight.strategies.impl.fedavg import FedAvg

return FedAvg(**kwargs)

case "fedprox" | "fed-prox":
from flight.strategies.impl.fedprox import FedProx

return FedProx(**kwargs)

case "fedsgd" | "fed-sgd":
from flight.strategies.impl.fedsgd import FedSGD

return FedSGD(**kwargs)
case _:
raise ValueError(f"Strategy '{strategy_name}' is not recognized.")


# Public API of `flight.strategies`: the strategy protocols, the composite `Strategy`
# (plus its default implementation and loader), and the shared type aliases.
__all__ = [
    "AggrStrategy",
    "Strategy",
    "CoordStrategy",
    "TrainerStrategy",
    "WorkerStrategy",
    "DefaultStrategy",
    "load_strategy",
    "Loss",
    "NodeState",
    "Params",
]
29 changes: 27 additions & 2 deletions flight/strategies/aggr.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
Params: t.TypeAlias = t.Any
from flight.federation.topologies.node import NodeID
from flight.strategies import NodeState, Params


@t.runtime_checkable
class AggrStrategy(t.Protocol):
    """Template for all aggregator strategies, including those defined in Flight and those defined by Users."""

    def start_round(self):
        """Callback to run at the start of a round."""
        pass

    def aggregate_params(
        self,
        state: NodeState,
        children_states: t.Mapping[NodeID, NodeState],
        children_state_dicts: t.Mapping[NodeID, Params],
        **kwargs,
    ) -> Params:
        """Callback that handles the model parameter aggregation step.

        Args:
            state (NodeState): The state of the current aggregator node.
            children_states (t.Mapping[NodeID, NodeState]): A mapping of the current aggregator node's
                children and their respective states.
            children_state_dicts (t.Mapping[NodeID, Params]): The model parameters of the models of each
                respective child node.
            **kwargs: Keyword arguments provided by users.

        Returns:
            Params: The aggregated parameters to update the model at the current aggregator.
        """
        pass

    def end_round(self):
        """Callback to run at the end of a round."""
        pass
256 changes: 221 additions & 35 deletions flight/strategies/base.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,221 @@
import functools
import typing as t

import pydantic as pyd


@pyd.dataclasses.dataclass(frozen=True, repr=False)
class Strategy:
    """Bundle of the four per-role strategies (coordinator, aggregator, worker, trainer)."""

    # `pydantic` exposes `Field` (capitalised); there is no `pydantic.field`.
    coord_strategy: str = pyd.Field()
    aggr_strategy: str = pyd.Field()
    worker_strategy: str = pyd.Field()
    trainer_strategy: str = pyd.Field()

    def __iter__(self) -> t.Iterator[tuple[str, t.Any]]:
        """Yield `(field_name, value)` pairs for the four strategies in a fixed order."""
        yield from (
            ("coord_strategy", self.coord_strategy),
            ("aggr_strategy", self.aggr_strategy),
            ("worker_strategy", self.worker_strategy),
            ("trainer_strategy", self.trainer_strategy),
        )

    def __repr__(self) -> str:
        return str(self)

    @functools.cached_property
    def _description(self) -> str:
        """Build (once) the string returned by `__str__`.

        The caching lives on this helper instead of on `__str__` itself: a `cached_property`
        is not callable, so decorating `__str__` with it makes `str(instance)` fail.
        """
        name = self.__class__.__name__
        inner = ", ".join(
            f"{strategy_key}={strategy_value.__class__.__name__}"
            for (strategy_key, strategy_value) in self
            if strategy_value is not None
        )
        return f"{name}({inner})"

    def __str__(self) -> str:
        return self._description
from __future__ import annotations

import functools
import typing as t

import pydantic as pyd

from flight.strategies.aggr import AggrStrategy
from flight.strategies.commons.averaging import average_state_dicts
from flight.strategies.coord import CoordStrategy
from flight.strategies.trainer import TrainerStrategy
from flight.strategies.worker import WorkerStrategy

StrategyType: t.TypeAlias = (
WorkerStrategy | AggrStrategy | CoordStrategy | TrainerStrategy
)

if t.TYPE_CHECKING:
import torch
from numpy.random import Generator

from flight.federation.jobs.result import Result
from flight.federation.topologies.node import Node, NodeID
from flight.strategies import Loss, NodeState, Params


class DefaultCoordStrategy:
    """Default implementation of the strategy for a coordinator."""

    def select_workers(
        self, state: NodeState, workers: t.Iterable[Node], rng: Generator
    ) -> t.Sequence[Node]:
        """Method used for the selection of workers.

        The default behavior selects every worker; `state` and `rng` are accepted to satisfy
        the coordinator interface but are not used.

        Args:
            state (NodeState): The state of the coordinator node.
            workers (t.Iterable[Node]): Iterable object containing all of the worker nodes.
            rng (Generator): RNG object used for randomness.

        Returns:
            t.Sequence[Node]: The selected workers, as a new list in iteration order.
        """
        return list(workers)


class DefaultAggrStrategy:
    """Default implementation of the strategy for an aggregator."""

    def start_round(self):
        """Callback to run at the start of a round; does nothing by default."""
        pass

    def aggregate_params(
        self,
        state: NodeState,
        children_states: t.Mapping[NodeID, NodeState],
        children_state_dicts: t.Mapping[NodeID, Params],
        **kwargs,
    ) -> Params:
        """Callback that handles the model parameter aggregation step.

        The default delegates to `average_state_dicts` with `weights=None`; `state`,
        `children_states`, and `**kwargs` are not used here.

        Args:
            state (NodeState): The state of the current aggregator node.
            children_states (t.Mapping[NodeID, NodeState]): A mapping of the current aggregator node's
                children and their respective states.
            children_state_dicts (t.Mapping[NodeID, Params]): The model parameters of the models of each
                respective child node.
            **kwargs: Keyword arguments provided by users.

        Returns:
            Params: The aggregated values.
        """
        return average_state_dicts(children_state_dicts, weights=None)

    def end_round(self):
        """Callback to run at the end of a round; does nothing by default."""
        pass


class DefaultWorkerStrategy:
    """Default implementation of the strategy for a worker.

    Every callback returns its input(s) unchanged.
    """

    def start_work(self, state: NodeState) -> NodeState:
        """Callback to be run at the start of the current worker node's work.

        Args:
            state (NodeState): The state of the current worker node.

        Returns:
            NodeState: The state of the current worker node at the end of the callback.
        """
        return state

    def before_training(
        self, state: NodeState, data: Params
    ) -> tuple[NodeState, Params]:
        """Callback to be run before training.

        Args:
            state (NodeState): The state of the current worker node.
            data (Params): The data associated with the current worker node.

        Returns:
            tuple[NodeState, Params]: A tuple containing the state and data of the worker node at the
                end of the callback.
        """
        return state, data

    def after_training(
        self, state: NodeState, optimizer: torch.optim.Optimizer
    ) -> NodeState:
        """Callback to be run after training.

        Args:
            state (NodeState): The state of the current worker node.
            optimizer (torch.optim.Optimizer): The PyTorch optimizer to be used.

        Returns:
            NodeState: The state of the worker node at the end of the callback.
        """
        return state

    def end_work(self, result: Result) -> Result:
        """Callback to be run at the end of the work.

        Args:
            result (Result): A Result object used to represent the result of the local training on the
                current worker node.

        Returns:
            Result: The result of the worker node's local training.
        """
        return result


class DefaultTrainerStrategy:
    """Default implementation of a strategy for the trainer.

    Both callbacks return the given loss unchanged.
    """

    def before_backprop(self, state: NodeState, loss: Loss) -> Loss:
        """Callback to run before backpropagation.

        Args:
            state (NodeState): State of the current node.
            loss (Loss): The calculated loss.

        Returns:
            Loss: The loss at the end of the callback.
        """
        return loss

    def after_backprop(self, state: NodeState, loss: Loss) -> Loss:
        """Callback to run after backpropagation.

        Args:
            state (NodeState): State of the current node.
            loss (Loss): The calculated loss.

        Returns:
            Loss: The loss at the end of the callback.
        """
        return loss


# TODO: Remove config when all type definitions have been resolved
@pyd.dataclasses.dataclass(
    frozen=True, repr=False, config={"arbitrary_types_allowed": True}
)
class Strategy:
    """
    A 'Strategy' implementation is comprised of the four different types of strategy implementations
    to be used on the respective node types throughout the training process.
    """

    # NOTE: Each attribute docstring is placed *after* its field so that documentation tooling
    # attaches the description to the correct attribute.
    coord_strategy: CoordStrategy = pyd.Field()
    """Implementation of the specific callbacks for the coordinator node."""
    aggr_strategy: AggrStrategy = pyd.Field()
    """Implementation of the specific callbacks for the aggregator node(s)."""
    worker_strategy: WorkerStrategy = pyd.Field()
    """Implementation of the specific callbacks for the worker node(s)."""
    trainer_strategy: TrainerStrategy = pyd.Field()
    """Implementation of callbacks specific to the execution of the training loop on the worker node(s)."""

    def __iter__(self) -> t.Iterator[tuple[str, StrategyType]]:
        """Yield `(field_name, strategy)` pairs for the four strategies in a fixed order."""
        yield from (
            ("coord_strategy", self.coord_strategy),
            ("aggr_strategy", self.aggr_strategy),
            ("worker_strategy", self.worker_strategy),
            ("trainer_strategy", self.trainer_strategy),
        )

    def __repr__(self) -> str:
        return str(self)

    @functools.cached_property
    def _description(self) -> str:
        """A utility function for generating the string for `__str__`.

        This is written to avoid the `mypy` issue:
        "Signature of '__str__' incompatible with supertype 'object'".

        Returns:
            The string representation of a Strategy instance, listing the class name of each
            strategy that is not `None`.
        """
        name = self.__class__.__name__
        inner = ", ".join(
            f"{strategy_key}={strategy_value.__class__.__name__}"
            for (strategy_key, strategy_value) in self
            if strategy_value is not None
        )
        return f"{name}({inner})"

    def __str__(self) -> str:
        return self._description


class DefaultStrategy(Strategy):
    """A `Strategy` whose four per-node-type strategies are all the default implementations."""

    def __init__(self) -> None:
        # Instantiated in the same order the fields are declared on `Strategy`.
        default_parts = {
            "coord_strategy": DefaultCoordStrategy(),
            "aggr_strategy": DefaultAggrStrategy(),
            "worker_strategy": DefaultWorkerStrategy(),
            "trainer_strategy": DefaultTrainerStrategy(),
        }
        super().__init__(**default_parts)
4 changes: 4 additions & 0 deletions flight/strategies/commons/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from flight.strategies.commons.averaging import average_state_dicts
from flight.strategies.commons.worker_selection import random_worker_selection

__all__ = ["average_state_dicts", "random_worker_selection"]
Loading
Loading