diff --git a/demo.py b/demo.py index b56468f..5a3f4d3 100644 --- a/demo.py +++ b/demo.py @@ -1,6 +1,11 @@ +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns import torch import torch.nn as nn -from torch.utils.data import TensorDataset +from torch.utils.data import Subset +from torchvision.datasets import MNIST +from torchvision.transforms import ToTensor import flight as fl from flight.learning import federated_split @@ -14,15 +19,20 @@ class TrainingModule(TorchModule): def __init__(self): super().__init__() self.model = nn.Sequential( - nn.Linear(1, 10), - nn.Linear(10, 100), - nn.Linear(100, NUM_LABELS), + nn.Flatten(), + nn.Linear(28 * 28, 28 * 28 * 3), + nn.ReLU(), + nn.Linear(28 * 28 * 3, 28 * 28), + nn.ReLU(), + nn.Linear(28 * 28, 28), + nn.ReLU(), + nn.Linear(28, NUM_LABELS), ) def forward(self, x): return self.model(x) - def training_step(self, batch) -> TensorLoss: + def training_step(self, batch, batch_idx) -> TensorLoss: x, y = batch y_hat = self(x) return nn.functional.nll_loss(y_hat, y) @@ -32,19 +42,31 @@ def configure_optimizers(self) -> torch.optim.Optimizer: def main(): + data = MNIST( + root="~/Research/Data/Torch-Data/", + download=False, + train=False, + transform=ToTensor(), + ) + data = Subset(data, indices=list(range(2000))) topo = fl.flat_topology(10) + # exit(0) module = TrainingModule() fed_data = federated_split( topo=topo, - data=TensorDataset( - torch.randn(100, 1), torch.randint(low=0, high=NUM_LABELS, size=(100, 1)) - ), + # data=TensorDataset( + # torch.randn(100, 1), torch.randint(low=0, high=NUM_LABELS, size=(100, 1)) + # ), + data=data, num_labels=NUM_LABELS, label_alpha=100.0, sample_alpha=100.0, ) trained_module, records = fl.federated_fit(topo, module, fed_data, rounds=2) - print(records) + + df = pd.DataFrame.from_records(records) + sns.lineplot(df, x="round", y="train/loss") + plt.show() if __name__ == "__main__": diff --git a/flight/engine/control/__init__.py b/flight/engine/controllers/__init__.py similarity index 100% rename from flight/engine/control/__init__.py rename to flight/engine/controllers/__init__.py diff --git a/flight/engine/control/base.py b/flight/engine/controllers/base.py similarity index 100% rename from flight/engine/control/base.py rename to flight/engine/controllers/base.py diff --git a/flight/engine/control/local.py b/flight/engine/controllers/local.py similarity index 91% rename from flight/engine/control/local.py rename to flight/engine/controllers/local.py index 7c6d700..7374207 100644 --- a/flight/engine/control/local.py +++ b/flight/engine/controllers/local.py @@ -14,8 +14,8 @@ class LocalController(AbstractController): """ A local controller (similar to - [`SerialController`][flight.engine.control.serial.SerialController]) that instead - runs multiple functions at once using either threads or processes. + [`SerialController`][flight.engine.controllers.serial.SerialController]) that + instead runs multiple functions at once using either threads or processes. 
""" executor: Executor diff --git a/flight/engine/control/parsl.py b/flight/engine/controllers/parsl.py similarity index 100% rename from flight/engine/control/parsl.py rename to flight/engine/controllers/parsl.py diff --git a/flight/engine/control/serial.py b/flight/engine/controllers/serial.py similarity index 100% rename from flight/engine/control/serial.py rename to flight/engine/controllers/serial.py diff --git a/flight/engine/data/__init__.py b/flight/engine/data/__init__.py deleted file mode 100644 index fa21ba7..0000000 --- a/flight/engine/data/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import AbstractTransfer - -__all__ = ["AbstractTransfer"] diff --git a/flight/engine/data/base.py b/flight/engine/data/base.py deleted file mode 100644 index 4412df6..0000000 --- a/flight/engine/data/base.py +++ /dev/null @@ -1,20 +0,0 @@ -from __future__ import annotations - -import abc -import typing as t - -if t.TYPE_CHECKING: - pass - - -class AbstractTransfer(abc.ABC): - @abc.abstractmethod - def transfer(self, data: t.Any) -> t.Any: - """ - Abstract method to facilitate data transfer. - """ - - -class BaseTransfer(AbstractTransfer): - def transfer(self, data: t.Any) -> t.Any: - return data diff --git a/flight/engine/engine.py b/flight/engine/engine.py index 0eb1563..97ac4e3 100644 --- a/flight/engine/engine.py +++ b/flight/engine/engine.py @@ -3,12 +3,12 @@ import typing as t from concurrent.futures import Future -from .control.serial import SerialController -from .data.base import AbstractTransfer, BaseTransfer +from .controllers.serial import SerialController +from .transporters.base import AbstractTransporter, InMemoryTransporter if t.TYPE_CHECKING: from ..types import P, T - from .control.base import AbstractController + from .controllers.base import AbstractController class Engine: @@ -23,7 +23,7 @@ class Engine: compute resources (e.g., compute nodes, threads, processes). """ - transmitter: AbstractTransfer + transmitter: AbstractTransporter """ Object responsible for facilitating data transfer for the execution of jobs. This abstraction is used in the case of distributed and remote execution @@ -33,7 +33,7 @@ class Engine: def __init__( self, controller: AbstractController, - transmitter: AbstractTransfer, + transmitter: AbstractTransporter, ): """ Initializes the engine with the given controller and transmitter. @@ -41,11 +41,11 @@ def __init__( Args: controller (AbstractController): The controller responsible for submitting functions to be executed at the appropriate compute resources. - transmitter (AbstractTransfer): The object responsible for facilitating data - transfers for the execution of jobs. + transmitter (AbstractTransporter): The object responsible for facilitating + data transfers for the execution of jobs. """ self.controller = SerialController() - self.transmitter = BaseTransfer() + self.transmitter = InMemoryTransporter() def submit(self, fn: t.Callable, **kwargs: dict[str, t.Any]) -> Future: """ @@ -77,8 +77,8 @@ def transfer(self, data: P) -> T: @classmethod def setup( cls, - controller_kind: ..., - transmitter_kind: ..., + controller_kind: AbstractController, + transmitter_kind: AbstractTransporter, controller_cfg: dict[str, t.Any] | None = None, transmitter_cfg: dict[str, t.Any] | None = None, ) -> Engine: @@ -86,8 +86,8 @@ def setup( This helper method prepares a new `Engine` instance. Args: - controller_kind: ... - transmitter_kind: ... + controller_kind (AbstractController): ... + transmitter_kind (AbstractTransporter): ... 
controller_cfg (dict[str, t.Any]): ... transmitter_cfg (dict[str, t.Any]): ... @@ -95,6 +95,6 @@ def setup( An `Engine` instance based on the provided configurations. """ # TODO - controller: AbstractController = None - transmitter: AbstractTransfer = None + controller: AbstractController = None # noqa + transmitter: AbstractTransporter = None # noqa return cls(controller, transmitter) diff --git a/flight/engine/transporters/__init__.py b/flight/engine/transporters/__init__.py new file mode 100644 index 0000000..a7485b8 --- /dev/null +++ b/flight/engine/transporters/__init__.py @@ -0,0 +1,9 @@ +""" +This module contains implementations of _Data **Transporters**_ which are used to handle +how to "transport" the data (e.g., locally, across nodes at a distributed cluster, or +across remote resources). +""" + +from .base import AbstractTransporter, InMemoryTransporter + +__all__ = ["AbstractTransporter", "InMemoryTransporter"] diff --git a/flight/engine/transporters/base.py b/flight/engine/transporters/base.py new file mode 100644 index 0000000..e0f3f62 --- /dev/null +++ b/flight/engine/transporters/base.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import abc +import typing as t + + +class AbstractTransporter(abc.ABC): + @abc.abstractmethod + def transfer(self, data: t.Any) -> t.Any: + """ + Abstract method to facilitate data transfer. + """ + + +class InMemoryTransporter(AbstractTransporter): + """ + An in-memory transporter that simply returns the data as-is. + + This class does nothing fancy, it simply returns the data as-is. The need + for this class is that it adheres to the `AbstractTransporter` standard. + """ + + def transfer(self, data: t.Any) -> t.Any: + return data diff --git a/flight/engine/data/proxy.py b/flight/engine/transporters/proxy.py similarity index 94% rename from flight/engine/data/proxy.py rename to flight/engine/transporters/proxy.py index 03e47ba..40423cd 100644 --- a/flight/engine/data/proxy.py +++ b/flight/engine/transporters/proxy.py @@ -8,13 +8,13 @@ from proxystore.store import Store from ...federation.topologies import Topology -from .base import AbstractTransfer +from .base import AbstractTransporter if t.TYPE_CHECKING: from proxystore.proxy import Proxy -class ProxystoreTransfer(AbstractTransfer): +class ProxystoreTransfer(AbstractTransporter): def __init__(self, topo: Topology, name: str = "default") -> None: if not topo.proxystore_ready: raise ValueError( diff --git a/flight/engine/data/redis.py b/flight/engine/transporters/redis.py similarity index 100% rename from flight/engine/data/redis.py rename to flight/engine/transporters/redis.py diff --git a/flight/federation/commons.py b/flight/federation/commons.py new file mode 100644 index 0000000..83e6414 --- /dev/null +++ b/flight/federation/commons.py @@ -0,0 +1,20 @@ +from flight.learning import AbstractModule +from flight.learning.scikit import ScikitModule +from flight.learning.torch import TorchModule + + +def _test_scikit_global_module(): + pass + + +def _test_torch_global_module(): + pass + + +def test_global_module(module: AbstractModule): + if isinstance(module, TorchModule): + _test_torch_global_module() + elif isinstance(module, ScikitModule): + _test_scikit_global_module() + else: + raise ValueError(f"Unsupported module type: {type(module)}") diff --git a/flight/federation/fed_abs.py b/flight/federation/fed_abs.py index 9941022..69fff2a 100644 --- a/flight/federation/fed_abs.py +++ b/flight/federation/fed_abs.py @@ -140,7 +140,7 @@ def trainer_strategy(self) -> TrainerStrategy: def 
worker_task(self, node: Node, parent: Node) -> Future[Result]: """ Prepares the arguments for the worker function and submits the function using - the provided control plane via the given `Engine`. + the provided controller via the given `Engine`. Args: node (Node): The worker node. diff --git a/flight/federation/fed_sync.py b/flight/federation/fed_sync.py index 73f9e89..2df918a 100644 --- a/flight/federation/fed_sync.py +++ b/flight/federation/fed_sync.py @@ -11,6 +11,7 @@ from .future_callbacks import all_futures_finished from .jobs.aggr import default_aggr_job from .jobs.types import AggrJobArgs +from .records import broadcast_records from .topologies.node import Node, NodeKind, AggrState from .topologies.topo import Topology from ..engine import Engine @@ -80,6 +81,15 @@ def federation_round(self, round_no: int) -> Result: self.engine.controller.shutdown() raise err + # TEST THE GLOBAL MODEL. + coord = self.topology.coordinator + test_data = self.data.test_data(coord) + if test_data: + _ = self.global_model.test_step(test_data) # TODO + test_results = {"test/acc": -1, "test/loss": -1} + broadcast_records(step_result.records, **test_results) + + # UPDATE PROGRESS BAR. self.global_model.set_params(step_result.params) if self._pbar: self._pbar.update() diff --git a/flight/federation/jobs/types.py b/flight/federation/jobs/types.py index a3e5961..acebf4a 100644 --- a/flight/federation/jobs/types.py +++ b/flight/federation/jobs/types.py @@ -11,7 +11,7 @@ if t.TYPE_CHECKING: from flight.types import Record - from flight.engine.data import AbstractTransfer + from flight.engine.transporters import AbstractTransporter from flight.strategies import AggrStrategy, TrainerStrategy, WorkerStrategy @@ -58,7 +58,7 @@ class AggrJobArgs: children: t.Sequence[Node] child_results: t.Sequence[Result] aggr_strategy: AggrStrategy - transfer: AbstractTransfer + transfer: AbstractTransporter @dataclass(slots=True, frozen=True) diff --git a/flight/federation/jobs/work.py b/flight/federation/jobs/work.py index 4cba90c..c123a71 100644 --- a/flight/federation/jobs/work.py +++ b/flight/federation/jobs/work.py @@ -66,7 +66,7 @@ def default_training_job(args: TrainJobArgs) -> Result: # TODO: Add this as an attr. of TrainArgJobs. trainer_init_params = dict(progress_bar=False) trainer_fit_params = dict() - trainer = TorchTrainer(**trainer_init_params) + trainer = TorchTrainer(node=args.node, **trainer_init_params) records = trainer.fit(node_state, local_model, data, **trainer_fit_params) case _: diff --git a/flight/federation/topologies/node.py b/flight/federation/topologies/node.py index 69cb069..d0fd470 100644 --- a/flight/federation/topologies/node.py +++ b/flight/federation/topologies/node.py @@ -150,7 +150,7 @@ def __getitem__(self, key: str) -> t.Any: def __setitem__(self, key: str, value: t.Any) -> None: """ - Setter function that stores a data item into the state's cache by key. + Setter function that stores a datum into the state's cache by key. Args: key (str): The key to store the data in cache for lookup.
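The controller/transporter refactor above reduces the data-movement abstraction to a single `transfer` method on `AbstractTransporter`, with `InMemoryTransporter` as the pass-through default. A minimal sketch of a custom transporter against that interface (the `LoggingTransporter` name and its bookkeeping are illustrative assumptions, not part of this change):

import typing as t

from flight.engine.transporters import AbstractTransporter


class LoggingTransporter(AbstractTransporter):
    """Hypothetical transporter that records each payload's type before handing it back."""

    def __init__(self) -> None:
        self.seen: list[type] = []

    def transfer(self, data: t.Any) -> t.Any:
        # Track what passes through, then return the data unchanged,
        # mirroring the pass-through behavior of InMemoryTransporter.
        self.seen.append(type(data))
        return data

Anything implementing this interface satisfies the `transmitter: AbstractTransporter` annotations used by `Engine` and `AggrJobArgs`.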
diff --git a/flight/fit.py b/flight/fit.py index f5e674d..311c9f4 100644 --- a/flight/fit.py +++ b/flight/fit.py @@ -4,14 +4,14 @@ import numpy as np -from .engine.control.base import AbstractController -from .engine.control.local import LocalController -from .engine.control.serial import SerialController +from .engine.controllers.base import AbstractController +from .engine.controllers.local import LocalController +from .engine.controllers.serial import SerialController from .federation import SyncFederation, Topology -from .federation.jobs.types import Result from .learning.base import AbstractDataModule, AbstractModule from .strategies import Strategy from .strategies.impl import FedSGD +from .types import Record def load_topology(raw_data: Topology | pathlib.Path | str | dict): @@ -49,7 +49,7 @@ def federated_fit( strategy: Strategy | str = "fedsgd", mode: str = "sync", fast_dev_run: bool = False, -) -> tuple[AbstractModule, list[Result]]: +) -> tuple[AbstractModule, list[Record]]: if strategy == "fedsgd": strategy = FedSGD() else: @@ -76,5 +76,5 @@ def federated_fit( case _: raise ValueError("Illegal value for argument `mode`.") - results = federation.start(rounds) - return module, results + records = federation.start(rounds) + return module, records diff --git a/flight/learning/params.py b/flight/learning/params.py index a9a7107..54864b6 100644 --- a/flight/learning/params.py +++ b/flight/learning/params.py @@ -1,21 +1,109 @@ -import abc +import typing as t +from enum import Enum, auto -import numpy.typing as npt +import numpy as np +import torch +from flight.learning import NpParams, TorchParams -class Params(abc.ABC): - @abc.abstractmethod - def numpy(self) -> dict[str, npt.NDArray]: - pass +class UnsupportedParameterKindError(ValueError): + """An Exception raised when an unsupported parameter kind is detected.""" -class NpParams(Params): - @abc.abstractmethod - def numpy(self) -> dict[str, npt.NDArray]: - pass + def __init__(self, message: str | None = None, *args): + if message is None: + message = ( + "The parameter kind is unknown or unsupported. " + "Please refer to the docs." + ) + super().__init__(message, *args) -class TorchParams(Params): - @abc.abstractmethod - def numpy(self) -> dict[str, npt.NDArray]: - pass +class InconsistentParamValuesError(ValueError): + """An Exception raised when the parameter value kinds are inconsistent.""" + + def __init__(self, message: str | None = None, *args): + if message is None: + message = "The parameter values are inconsistent. Please refer to the docs." + super().__init__(message, *args) + + +class ParamKinds(Enum): + NUMPY = auto() + TORCH = auto() + + +def infer_param_kind(param: t.Any) -> ParamKinds: + """ + Detect the kind of parameter. + + Args: + param (t.Any): The parameter to infer the type for. + + Returns: + The kind of parameter. + + Throws: + - `ValueError`: If the parameter kind is unknown or unsupported. + """ + if isinstance(param, np.ndarray): + return ParamKinds.NUMPY + elif isinstance(param, torch.Tensor): + return ParamKinds.TORCH + else: + raise UnsupportedParameterKindError() + + +def validate_param_kind(params: dict[str, t.Any]) -> ParamKinds: + """ + Validate the kind of parameters. + + This function returns the kind of parameters (similar to `infer_param_kind`), but + it will throw an error in the case where the parameters are not of the same kind. + + Args: + params: + + Returns: + The kind of parameters if they are of the same kind. Otherwise, an error is + thrown. 
+ + Throws: + - `InconsistentParamValuesError`: If the parameter values are inconsistent. + """ + param_kinds = set(map(infer_param_kind, params.values())) + if len(param_kinds) != 1: + raise InconsistentParamValuesError() + return param_kinds.pop() + + +class Params: + def __init__(self, raw_params: dict[str, t.Any]): + self._raw_params = raw_params + self._inferred_kind = validate_param_kind(raw_params) + + def numpy(self) -> NpParams: + match self._inferred_kind: + case ParamKinds.NUMPY: + return self._raw_params + case ParamKinds.TORCH: + return {k: v.numpy() for k, v in self._raw_params.items()} + + def torch(self) -> TorchParams: + match self._inferred_kind: + case ParamKinds.TORCH: + return self._raw_params + case ParamKinds.NUMPY: + return {k: torch.from_numpy(v) for k, v in self._raw_params.items()} + + +# class NpParams(Params): +# @abc.abstractmethod +# def numpy(self) -> dict[str, npt.NDArray]: +# pass +# +# +# class TorchParams(Params): +# @abc.abstractmethod +# def numpy(self) -> dict[str, npt.NDArray]: +# pass diff --git a/flight/learning/prototypes.py b/flight/learning/scikit/params.py similarity index 100% rename from flight/learning/prototypes.py rename to flight/learning/scikit/params.py diff --git a/flight/learning/torch/params.py b/flight/learning/torch/params.py new file mode 100644 index 0000000..e69de29 diff --git a/flight/learning/torch/trainer.py b/flight/learning/torch/trainer.py index bcbbb0b..cf8a568 100644 --- a/flight/learning/torch/trainer.py +++ b/flight/learning/torch/trainer.py @@ -117,6 +117,7 @@ def fit( ckpt_path: _PATH | None = None, ) -> list[Record]: """ + Fits (or trains) a PyTorch module on a given data module. Args: node_state (WorkerState): @@ -143,12 +144,13 @@ if not isinstance(train_dataloader, DataLoader): raise TypeError( - "Method for argument `data.train_data(.)` must return a `DataLoader`." + "Method for argument `data.train_data(.)` " + "must return a `DataLoader`." ) if not isinstance(valid_dataloader, DataLoader | None): raise TypeError( - "Method for argument `data.valid_data(.)` must return a `DataLoader` " - "or `None`." + "Method for argument `data.valid_data(.)` " + "must return a `DataLoader` or `None`." ) pbar_prefix = f"TorchTrainer(NodeID={self.node.idx})" diff --git a/flight/learning/torch/utils.py b/flight/learning/torch/utils.py index 72cac0b..2816250 100644 --- a/flight/learning/torch/utils.py +++ b/flight/learning/torch/utils.py @@ -26,9 +26,10 @@ class FederatedDataModule(TorchDataModule): topology. This is especially helpful for simulation-based federations that are run with - Flight. Rather than needing to manually define the logic to load data that is + Flight. Rather than needing to manually define the logic to load data that are sharded across workers in a federation, this class simply requires the original - dataset and the indices for training, testing, and validation data for each worker. + dataset and the indices for training, testing, and validation data for each + worker.
A good analogy for this class is to think of it as the federated version of PyTorch's [`Subset`](https://pytorch.org/docs/stable/data.html# diff --git a/testing/fixtures.py b/testing/fixtures.py index ce4e074..124ed7f 100644 --- a/testing/fixtures.py +++ b/testing/fixtures.py @@ -9,8 +9,8 @@ from torch.nn import functional as F from torch.utils.data import DataLoader, Subset, TensorDataset -from flight.engine.control.parsl import ParslController -from flight.engine.control.serial import SerialController +from flight.engine.controllers.parsl import ParslController +from flight.engine.controllers.serial import SerialController from flight.federation.topologies import Node from flight.federation.topologies.node import NodeKind, WorkerState from flight.learning.scikit import ScikitDataModule diff --git a/tests/federation/jobs/test_aggr_job.py b/tests/federation/jobs/test_aggr_job.py index b5dcf31..52741cc 100644 --- a/tests/federation/jobs/test_aggr_job.py +++ b/tests/federation/jobs/test_aggr_job.py @@ -3,7 +3,7 @@ import torch from torch.utils.data import DataLoader, Subset, TensorDataset -from flight.engine.data.base import BaseTransfer +from flight.engine.transporters.base import InMemoryTransporter from flight.federation.jobs.aggr import default_aggr_job from flight.federation.jobs.types import Result, AggrJobArgs from flight.federation.topologies import Node @@ -107,7 +107,7 @@ def aggr_args(node, parent, result) -> AggrJobArgs: children=[node], child_results=[result], aggr_strategy=DefaultAggrStrategy(), # TODO: We need to resolve this typing. - transfer=BaseTransfer(), + transfer=InMemoryTransporter(), ) diff --git a/tests/learning/test_params.py b/tests/learning/test_params.py new file mode 100644 index 0000000..ba19408 --- /dev/null +++ b/tests/learning/test_params.py @@ -0,0 +1,35 @@ +import numpy as np +import pytest +import torch + +from flight.learning.params import ( + validate_param_kind, + ParamKinds, + infer_param_kind, + InconsistentParamValuesError, +) + + +@pytest.fixture +def param_data() -> list[float]: + return [0.0, 1.0, 2.0] + + +def test_validate_numpy_params(param_data): + p = np.array(param_data) + params = {f"p{i}": p for i in range(10)} + assert infer_param_kind(p) == ParamKinds.NUMPY + assert validate_param_kind(params) == ParamKinds.NUMPY + + +def test_validate_torch_params(param_data): + p = torch.tensor(param_data) + params = {f"p{i}": p for i in range(10)} + assert infer_param_kind(p) == ParamKinds.TORCH + assert validate_param_kind(params) == ParamKinds.TORCH + + +def test_inconsistent_params(param_data): + with pytest.raises(InconsistentParamValuesError): + bad_params = {"p0": torch.tensor(param_data), "p1": np.array(param_data)} + validate_param_kind(bad_params)
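For reference, a quick usage sketch of the new `Params` wrapper from `flight/learning/params.py`, along the lines of `tests/learning/test_params.py` (this assumes `NpParams` and `TorchParams`, imported from `flight.learning`, are plain dict-style aliases; the tensor values below are illustrative):

import numpy as np
import torch

from flight.learning.params import Params, ParamKinds, infer_param_kind

# The parameter kind is inferred and validated once, at construction time.
params = Params({"weight": torch.ones(3), "bias": torch.zeros(1)})
assert infer_param_kind(params.torch()["weight"]) == ParamKinds.TORCH

# numpy() converts torch tensors to ndarrays; torch() returns them unchanged.
as_numpy = params.numpy()
assert isinstance(as_numpy["weight"], np.ndarray)

Mixing NumPy and torch values in the same dict raises `InconsistentParamValuesError`, as exercised by `test_inconsistent_params`.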