Skip to content

Commit

Permalink
Start of data modules and trainer implementations
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Jul 31, 2024
1 parent a147332 commit e0b8ba1
Show file tree
Hide file tree
Showing 47 changed files with 1,065 additions and 275 deletions.
7 changes: 5 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ per-file-ignores =

# TODO: Change to 88 later for black
max-line-length = 120
exclude:
quickstart/
exclude =
quickstart
flox
flox_examples
flox_tests


6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,9 @@ repos:
hooks:
- id: flake8
additional_dependencies: ['flake8-bugbear==22.10.27']

exclude:
flox/,
flox_examples/,
flox_tests/,
quickstart/
7 changes: 5 additions & 2 deletions flight/federation/fed_abs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import typing as t
from concurrent.futures import Future
Expand All @@ -6,18 +8,19 @@
from flight.strategies.coord import CoordStrategy
from flight.strategies.trainer import TrainerStrategy
from flight.strategies.worker import WorkerStrategy

from ..learning.datasets.loadable import DataLoadable
from ..types import Record
from .jobs.types import Result, TrainJob, TrainJobArgs
from .jobs.work import default_training_job
from .topologies.node import Node
from .topologies.topo import Topology
from ..learning.datasets import DataLoadable

if t.TYPE_CHECKING:
from .fed_sync import Engine

Strategy: t.TypeAlias = t.Any
Module: t.TypeAlias = t.Any
Record: t.TypeAlias = dict[str, t.Any]


class Federation(abc.ABC):
Expand Down
16 changes: 15 additions & 1 deletion flight/federation/fed_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
import typing as t

from .fed_abs import Federation
from .topologies.node import Node, NodeKind


class AsyncFederation(Federation):
pass
def __init__(self):
pass

def start_aggregator_task(
self,
node: Node,
selected_children: t.Sequence[Node],
) -> Future[Result]:
raise NotImplementedError(
"This method is not implemented. Async federations only support 2-tier topologies "
"(i.e., there are no intermediate aggregators)."
)
11 changes: 6 additions & 5 deletions flight/federation/fed_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,16 @@
import typing as t
from concurrent.futures import Future

from flox.federation.topologies import Node, NodeKind
from .fed_abs import Federation
from .topologies.topo import Topology
from flight.learning.modules.base.module import Trainable

from ..learning.datasets import DataLoadable
from ..learning.module import Trainable
from ..strategies.base import Strategy
from .fed_abs import Federation
from .topologies.node import Node, NodeKind
from .topologies.topo import Topology

if t.TYPE_CHECKING:
from .jobs.types import Result, AggrJobArgs
from .jobs.types import AggrJobArgs, Result

Engine: t.TypeAlias = t.Any

Expand Down
16 changes: 8 additions & 8 deletions flight/federation/jobs/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,22 @@
from proxystore.proxy import Proxy
from pydantic.dataclasses import dataclass

from flight.learning.module import RecordList
from flight.strategies import Params
from flight.federation.topologies.node import NodeID, NodeState
from flight.learning.types import Params
from flight.types import Record

NodeID: t.TypeAlias = t.Hashable
NodeState: t.TypeAlias = tuple


# TODO: Remove config when all type definitions have been resolvedß
# TODO: Remove config when all type definitions have been resolved
@dataclass(config={"arbitrary_types_allowed": True})
class Result:
state: NodeState
node_idx: NodeID
params: Params
records: RecordList
records: list[Record]
cache: dict[str, t.Any]


AbstractResult: t.TypeAlias = Result | Proxy[Result]
"""Helper type alias for a `Result` or a proxy to a `Result`."""
"""
Helper type alias for a `Result` or a proxy to a `Result`.
"""
44 changes: 24 additions & 20 deletions flight/federation/jobs/types.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from __future__ import annotations

import typing as t
from concurrent.futures import Future
from dataclasses import dataclass

import pydantic as pyd
from dataclasses import dataclass, field

from flight.federation.topologies.node import Node
from flight.learning.module import RecordList
from flight.federation.topologies.node import Node, NodeState, WorkerState
from flight.learning.datasets.loadable import DataLoadable
from flight.learning.modules.base import Record
from flight.learning.modules.torch import FlightModule
from flight.learning.types import Params

if t.TYPE_CHECKING:
from flight.learning.datasets import DataLoadable
from flight.learning.module import FlightModule
from flight.strategies.trainer import TrainerStrategy
from flight.strategies.worker import WorkerStrategy

NodeState: t.TypeAlias = t.Any
Params: t.TypeAlias = t.Any


@pyd.dataclasses.dataclass
class Result(pyd.BaseModel):
node: Node = pyd.Field()
node_state: NodeState = pyd.Field()
params: Params = pyd.Field()
records: RecordList = pyd.Field()
cache: dict[str, t.Any] = pyd.Field(default_factory=dict, init=False)


AggrJob: t.TypeAlias = t.Callable[[Node, Node], Result]
@dataclass
class Result:
node: Node
"""The node that produced this result during a federation."""
node_state: NodeState
"""The current state of the node that returned a given result during a federation."""
params: Params
"""Parameters returned as part of a result from a single Node in a federation."""
records: list[Record] = field(default_factory=list)
"""List of records for model training/aggregation metrics."""
extra: dict[str, t.Any] = field(default_factory=dict)
"""Extra data recorded by a node during the runtime of its job."""


# class TrainJob(t.Protocol):
Expand Down Expand Up @@ -57,11 +57,15 @@ class TrainJobArgs:

node: Node
parent: Node
node_state: WorkerState
model: FlightModule
data: DataLoadable
worker_strategy: WorkerStrategy
trainer_strategy: TrainerStrategy


AggrJob: t.TypeAlias = t.Callable[[AggrJobArgs], Result]
"""Function signature for aggregation jobs."""

TrainJob: t.TypeAlias = t.Callable[[TrainJobArgs], Result]
"""Function signature for local training jobs."""
37 changes: 25 additions & 12 deletions flight/federation/jobs/work.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import typing as t
from __future__ import annotations

import typing as t

if t.TYPE_CHECKING:
from flight.federation.jobs.types import Result, TrainJobArgs
Expand All @@ -11,40 +12,52 @@ def default_training_job(args: TrainJobArgs) -> Result:

from torch.utils.data import DataLoader

hparams = trainer_strategy.trainer_hparams()
from flight.learning.trainers.torch import TorchTrainer

hparams = args.trainer_strategy.trainer_hparams()

training_start = datetime.now()

state = worker_strategy.start_work()
state = args.worker_strategy.start_work()

data = {
"train": data.load(node, "train"),
"valid": data.load(node, "valid"),
"train": args.data.load(args.node, "train"),
"valid": args.data.load(args.node, "valid"),
}

train_dataloader = DataLoader(
data["train"],
**{key: val for (key, val) in hparams if key.startswith("dataloader.train.")},
)

trainer = Trainer(trainer_strategy)
trainer = TorchTrainer(args.trainer_strategy)
local_model = args.model.copy()
optimizer = args.model.configure_optimizers()
trainer.fit(
args.node_state,
local_model,
optimizer,
train_dataloader,
node_state,
**{key: val for (key, val) in hparams if key.startswith("trainer.")},
)

state = worker_strategy.end_work()
state = args.worker_strategy.end_work()

training_end = datetime.now()

history = {
"node_idx": node.idx,
"node_kind": node.kind,
"parent_idx": parent.idx,
"parent_kind": parent.kind,
"node_idx": args.node.idx,
"node_kind": args.node.kind,
"parent_idx": args.parent.idx,
"parent_kind": args.parent.kind,
"training_start": training_start,
"training_end": training_end,
}

return Result(
node=...,
node_state=...,
params=...,
records=...,
extra=...,
)
3 changes: 3 additions & 0 deletions flight/federation/topologies/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .node import Node

__all__ = ["Node"]
89 changes: 81 additions & 8 deletions flight/federation/topologies/node.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import typing as t
from dataclasses import dataclass, field
from enum import Enum
from uuid import UUID

import pydantic as pyd

NodeID: t.TypeAlias = int | str
"""ID of nodes in Flight topologies; can either be of type `int` or `str`."""
from flight.learning.modules.base import Trainable

NodeID: t.TypeAlias = t.Union[int, str]
"""
ID of nodes in Flight topologies; can either be of type `int` or `str`.
"""


class NodeKind(str, Enum):
Expand All @@ -27,12 +32,80 @@ class Node(pyd.BaseModel):
[`Topology`][flight.federation.topologies.topo.Topology] class."""

idx: NodeID
"""The ID of the node."""
"""
The ID of the node.
"""

kind: NodeKind
"""The kind of Node---indicates its *role* in a federation."""
"""
The kind of Node---indicates its *role* in a federation.
"""

globus_comp_id: UUID | None = pyd.Field(default=None)
"""Globus Compute UUID for remote execution."""
"""
Globus Compute UUID for remote execution.
"""

proxystore_id: UUID | None = pyd.Field(default=None)
"""ProxyStore UUID for data transfer for remote execution with Globus Compute."""
extra: dict[str, t.Any] | None = pyd.Field(default=None)
"""Any extra parameters users wish to give to Nodes (e.g., parameters or settings around system resource use)."""
"""
ProxyStore UUID for data transfer for remote execution with Globus Compute.
"""

extra: dict[str, t.Any] = pyd.Field(default_factory=dict)
"""
Any extra parameters users wish to give to Nodes (e.g., parameters or settings around
system resource use).
"""


@dataclass
class NodeState:
"""
Dataclass that wraps the state of a node during a federation.
Args:
idx (NodeID): The ID of the node.
Raises:
- TypeError: This class cannot be directly instantiated. Only its child classes can be instantiated.
"""

idx: NodeID
cache: dict[str, t.Any] = field(
init=False, default_factory=dict, repr=False, hash=False
)

def __post_init__(self):
if type(self) is NodeState:
raise TypeError(
"Cannot instantiate an instance of `NodeState`. "
"Instead, you must instantiate instances of `WorkerState` or `AggrState`."
)


@dataclass
class AggrState(NodeState):
"""
The state of an Aggregator node.
Args:
children (t.Iterable[Node]): Child nodes in the topology.
aggr_model (t.Optional[Trainable]): Aggregated model.
"""

children: t.Iterable[Node]
aggr_model: t.Optional[Trainable] = None


@dataclass
class WorkerState(NodeState):
"""
The state of a Worker node.
Args:
global_model (t.Optional[Trainable]): ...
local_model (t.Optional[Trainable]): ...
"""

global_model: t.Optional[Trainable] = None
local_model: t.Optional[Trainable] = None
Loading

0 comments on commit e0b8ba1

Please sign in to comment.