Skip to content

Commit

Permalink
working records from fed_fit
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Nov 12, 2024
1 parent 50468ac commit 5de98db
Show file tree
Hide file tree
Showing 28 changed files with 273 additions and 85 deletions.
40 changes: 31 additions & 9 deletions demo.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

import flight as fl
from flight.learning import federated_split
Expand All @@ -14,15 +19,20 @@ class TrainingModule(TorchModule):
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(1, 10),
nn.Linear(10, 100),
nn.Linear(100, NUM_LABELS),
nn.Flatten(),
nn.Linear(28 * 28, 28 * 28 * 3),
nn.ReLU(),
nn.Linear(28 * 28 * 3, 28 * 28),
nn.ReLU(),
nn.Linear(28 * 28, 28),
nn.ReLU(),
nn.Linear(28, NUM_LABELS),
)

def forward(self, x):
return self.model(x)

def training_step(self, batch) -> TensorLoss:
def training_step(self, batch, batch_idx) -> TensorLoss:
x, y = batch
y_hat = self(x)
return nn.functional.nll_loss(y_hat, y)
Expand All @@ -32,19 +42,31 @@ def configure_optimizers(self) -> torch.optim.Optimizer:


def main():
data = MNIST(
root="~/Research/Data/Torch-Data/",
download=False,
train=False,
transform=ToTensor(),
)
data = Subset(data, indices=list(range(2000)))
topo = fl.flat_topology(10)
# exit(0)
module = TrainingModule()
fed_data = federated_split(
topo=topo,
data=TensorDataset(
torch.randn(100, 1), torch.randint(low=0, high=NUM_LABELS, size=(100, 1))
),
# data=TensorDataset(
# torch.randn(100, 1), torch.randint(low=0, high=NUM_LABELS, size=(100, 1))
# ),
data=data,
num_labels=NUM_LABELS,
label_alpha=100.0,
sample_alpha=100.0,
)
trained_module, records = fl.federated_fit(topo, module, fed_data, rounds=2)
print(records)

df = pd.DataFrame.from_records(records)
sns.lineplot(df, x="round", y="train/loss")
plt.show()


if __name__ == "__main__":
Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
class LocalController(AbstractController):
"""
A local controller (similar to
[`SerialController`][flight.engine.control.serial.SerialController]) that instead
runs multiple functions at once using either threads or processes.
[`SerialController`][flight.engine.controllers.serial.SerialController]) that
instead runs multiple functions at once using either threads or processes.
"""

executor: Executor
Expand Down
File renamed without changes.
File renamed without changes.
3 changes: 0 additions & 3 deletions flight/engine/data/__init__.py

This file was deleted.

20 changes: 0 additions & 20 deletions flight/engine/data/base.py

This file was deleted.

28 changes: 14 additions & 14 deletions flight/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import typing as t
from concurrent.futures import Future

from .control.serial import SerialController
from .data.base import AbstractTransfer, BaseTransfer
from .controllers.serial import SerialController
from .transporters.base import AbstractTransporter, InMemoryTransporter

if t.TYPE_CHECKING:
from ..types import P, T
from .control.base import AbstractController
from .controllers.base import AbstractController


class Engine:
Expand All @@ -23,7 +23,7 @@ class Engine:
compute resources (e.g., compute nodes, threads, processes).
"""

transmitter: AbstractTransfer
transmitter: AbstractTransporter
"""
Object responsible for facilitating data transfer for the execution of jobs.
This abstraction is used in the case of distributed and remote execution
Expand All @@ -33,19 +33,19 @@ class Engine:
def __init__(
self,
controller: AbstractController,
transmitter: AbstractTransfer,
transmitter: AbstractTransporter,
):
"""
Initializes the engine with the given controller and transmitter.
Args:
controller (AbstractController): The controller responsible for submitting
functions to be executed at the appropriate compute resources.
transmitter (AbstractTransfer): The object responsible for facilitating data
transfers for the execution of jobs.
transmitter (AbstractTransporter): The object responsible for facilitating
data transfers for the execution of jobs.
"""
self.controller = SerialController()
self.transmitter = BaseTransfer()
self.transmitter = InMemoryTransporter()

def submit(self, fn: t.Callable, **kwargs: dict[str, t.Any]) -> Future:
"""
Expand Down Expand Up @@ -77,24 +77,24 @@ def transfer(self, data: P) -> T:
@classmethod
def setup(
cls,
controller_kind: ...,
transmitter_kind: ...,
controller_kind: AbstractController,
transmitter_kind: AbstractTransporter,
controller_cfg: dict[str, t.Any] | None = None,
transmitter_cfg: dict[str, t.Any] | None = None,
) -> Engine:
"""
This helper method prepares a new `Engine` instance.
Args:
controller_kind: ...
transmitter_kind: ...
controller_kind (AbstractController): ...
transmitter_kind (AbstractTransporter): ...
controller_cfg (dict[str, t.Any]): ...
transmitter_cfg (dict[str, t.Any]): ...
Returns:
An `Engine` instance based on the provided configurations.
"""
# TODO
controller: AbstractController = None
transmitter: AbstractTransfer = None
controller: AbstractController = None # noqa
transmitter: AbstractTransporter = None # noqa
return cls(controller, transmitter)
9 changes: 9 additions & 0 deletions flight/engine/transporters/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""
This module contains implementations of _Data **Transporters**_, which handle how
data are "transported" (e.g., locally, across nodes in a distributed cluster, or
across remote resources).
"""

from .base import AbstractTransporter, InMemoryTransporter

__all__ = ["AbstractTransporter", "InMemoryTransporter"]
24 changes: 24 additions & 0 deletions flight/engine/transporters/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import abc
import typing as t


class AbstractTransporter(abc.ABC):
    """
    Interface that every data transporter must implement.

    A transporter exposes a single operation, `transfer`. This class cannot be
    instantiated directly; subclasses must override `transfer`.
    """

    @abc.abstractmethod
    def transfer(self, data: t.Any) -> t.Any:
        """
        Facilitates the transfer of the given data.

        Args:
            data (t.Any): The data to transfer.

        Returns:
            The outcome of the transfer. Its form is defined by the concrete
            transporter.
        """


class InMemoryTransporter(AbstractTransporter):
    """
    A pass-through transporter that hands data back untouched.

    Nothing is copied, serialized, or moved. The class exists so that code written
    against the `AbstractTransporter` interface has a conforming implementation to
    use when no actual transfer is needed.
    """

    def transfer(self, data: t.Any) -> t.Any:
        """
        Returns the given data as-is.

        Args:
            data (t.Any): The object to "transfer".

        Returns:
            The very same object that was passed in.
        """
        return data
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from proxystore.store import Store

from ...federation.topologies import Topology
from .base import AbstractTransfer
from .base import AbstractTransporter

if t.TYPE_CHECKING:
from proxystore.proxy import Proxy


class ProxystoreTransfer(AbstractTransfer):
class ProxystoreTransfer(AbstractTransporter):
def __init__(self, topo: Topology, name: str = "default") -> None:
if not topo.proxystore_ready:
raise ValueError(
Expand Down
File renamed without changes.
20 changes: 20 additions & 0 deletions flight/federation/commons.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from flight.learning import AbstractModule
from flight.learning.scikit import ScikitModule
from flight.learning.torch import TorchModule


def _test_scikit_global_module():
pass


def _test_torch_global_module():
pass


def test_global_module(module: AbstractModule):
    """
    Runs the global-module test routine that matches the module's framework.

    The module is only inspected to choose a routine; it is not passed on to the
    framework-specific helper.

    Args:
        module (AbstractModule): The global module to test. Must be an instance of
            `TorchModule` or `ScikitModule`.

    Raises:
        ValueError: If `module` is neither a `TorchModule` nor a `ScikitModule`.
    """
    # Torch is checked first, mirroring the original branch order.
    if isinstance(module, TorchModule):
        _test_torch_global_module()
        return

    if isinstance(module, ScikitModule):
        _test_scikit_global_module()
        return

    raise ValueError(f"Unsupported module type: {type(module)}")
2 changes: 1 addition & 1 deletion flight/federation/fed_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def trainer_strategy(self) -> TrainerStrategy:
def worker_task(self, node: Node, parent: Node) -> Future[Result]:
"""
Prepares the arguments for the worker function and submits the function using
the provided control plane via the given `Engine`.
the provided control plane via the given `Engine`.
Args:
node (Node): The worker node.
Expand Down
10 changes: 10 additions & 0 deletions flight/federation/fed_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .future_callbacks import all_futures_finished
from .jobs.aggr import default_aggr_job
from .jobs.types import AggrJobArgs
from .records import broadcast_records
from .topologies.node import Node, NodeKind, AggrState
from .topologies.topo import Topology
from ..engine import Engine
Expand Down Expand Up @@ -80,6 +81,15 @@ def federation_round(self, round_no: int) -> Result:
self.engine.controller.shutdown()
raise err

# TEST THE GLOBAL MODEL.
coord = self.topology.coordinator
test_data = self.data.test_data(coord)
if test_data:
_ = self.global_model.test_step(test_data) # TODO
test_results = {"test/acc": -1, "test/loss": -1}
broadcast_records(step_result.records, **test_results)

# UPDATE PROGRESS BAR.
self.global_model.set_params(step_result.params)
if self._pbar:
self._pbar.update()
Expand Down
4 changes: 2 additions & 2 deletions flight/federation/jobs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

if t.TYPE_CHECKING:
from flight.types import Record
from flight.engine.data import AbstractTransfer
from flight.engine.transporters import AbstractTransporter
from flight.strategies import AggrStrategy, TrainerStrategy, WorkerStrategy


Expand Down Expand Up @@ -58,7 +58,7 @@ class AggrJobArgs:
children: t.Sequence[Node]
child_results: t.Sequence[Result]
aggr_strategy: AggrStrategy
transfer: AbstractTransfer
transfer: AbstractTransporter


@dataclass(slots=True, frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion flight/federation/jobs/work.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def default_training_job(args: TrainJobArgs) -> Result:
# TODO: Add this as an attr. of TrainArgJobs.
trainer_init_params = dict(progress_bar=False)
trainer_fit_params = dict()
trainer = TorchTrainer(**trainer_init_params)
trainer = TorchTrainer(node=args.node, **trainer_init_params)
records = trainer.fit(node_state, local_model, data, **trainer_fit_params)

case _:
Expand Down
2 changes: 1 addition & 1 deletion flight/federation/topologies/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __getitem__(self, key: str) -> t.Any:

def __setitem__(self, key: str, value: t.Any) -> None:
"""
Setter function that stores a data item into the state's cache by key.
Setter function that stores a datum into the state's cache by key.
Args:
key (str): The key to store the data in cache for lookup.
Expand Down
14 changes: 7 additions & 7 deletions flight/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

import numpy as np

from .engine.control.base import AbstractController
from .engine.control.local import LocalController
from .engine.control.serial import SerialController
from .engine.controllers.base import AbstractController
from .engine.controllers.local import LocalController
from .engine.controllers.serial import SerialController
from .federation import SyncFederation, Topology
from .federation.jobs.types import Result
from .learning.base import AbstractDataModule, AbstractModule
from .strategies import Strategy
from .strategies.impl import FedSGD
from .types import Record


def load_topology(raw_data: Topology | pathlib.Path | str | dict):
Expand Down Expand Up @@ -49,7 +49,7 @@ def federated_fit(
strategy: Strategy | str = "fedsgd",
mode: str = "sync",
fast_dev_run: bool = False,
) -> tuple[AbstractModule, list[Result]]:
) -> tuple[AbstractModule, list[Record]]:
if strategy == "fedsgd":
strategy = FedSGD()
else:
Expand All @@ -76,5 +76,5 @@ def federated_fit(
case _:
raise ValueError("Illegal value for argument `mode`.")

results = federation.start(rounds)
return module, results
records = federation.start(rounds)
return module, records
Loading

0 comments on commit 5de98db

Please sign in to comment.