diff --git a/.flake8 b/.flake8
index f8418ad..00d1444 100644
--- a/.flake8
+++ b/.flake8
@@ -5,7 +5,10 @@ per-file-ignores =
 # TODO: Change to 88 later for black
 max-line-length = 120
-exclude:
-    quickstart/
+exclude =
+    quickstart
+    flox
+    flox_examples
+    flox_tests
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index ace757b..8f38837 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -33,3 +33,11 @@ repos:
     hooks:
       - id: flake8
         additional_dependencies: ['flake8-bugbear==22.10.27']
+
+exclude: |
+    (?x)^(
+        flox/|
+        flox_examples/|
+        flox_tests/|
+        quickstart/
+    )
diff --git a/flight/federation/fed_abs.py b/flight/federation/fed_abs.py
index 60828ad..af435cc 100644
--- a/flight/federation/fed_abs.py
+++ b/flight/federation/fed_abs.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import abc
 import typing as t
 from concurrent.futures import Future
@@ -6,18 +8,19 @@ from flight.strategies.coord import CoordStrategy
 from flight.strategies.trainer import TrainerStrategy
 from flight.strategies.worker import WorkerStrategy
+
+from ..learning.datasets.loadable import DataLoadable
+from ..types import Record
 from .jobs.types import Result, TrainJob, TrainJobArgs
 from .jobs.work import default_training_job
 from .topologies.node import Node
 from .topologies.topo import Topology
-from ..learning.datasets import DataLoadable

 if t.TYPE_CHECKING:
     from .fed_sync import Engine

 Strategy: t.TypeAlias = t.Any
 Module: t.TypeAlias = t.Any
-Record: t.TypeAlias = dict[str, t.Any]


 class Federation(abc.ABC):
diff --git a/flight/federation/fed_async.py b/flight/federation/fed_async.py
index cec3924..acc80b7 100644
--- a/flight/federation/fed_async.py
+++ b/flight/federation/fed_async.py
@@ -1,5 +1,21 @@
+import typing as t
+from concurrent.futures import Future
+
 from .fed_abs import Federation
+from .jobs.types import Result
+from .topologies.node import Node, NodeKind


 class AsyncFederation(Federation):
-    pass
+    def __init__(self):
+        pass
+
+    def start_aggregator_task(
+        self,
+        node: Node,
+        selected_children: t.Sequence[Node],
+    ) -> Future[Result]:
+        raise NotImplementedError(
+            "This method is not implemented. Async federations only support 2-tier "
+            "topologies (i.e., there are no intermediate aggregators)."
+        )
diff --git a/flight/federation/fed_sync.py b/flight/federation/fed_sync.py
index 32ed5ff..0631cdf 100644
--- a/flight/federation/fed_sync.py
+++ b/flight/federation/fed_sync.py
@@ -2,15 +2,16 @@ import typing as t
 from concurrent.futures import Future

-from flox.federation.topologies import Node, NodeKind
-from .fed_abs import Federation
-from .topologies.topo import Topology
+from flight.learning.modules.base import Trainable
+
 from ..learning.datasets import DataLoadable
-from ..learning.module import Trainable
 from ..strategies.base import Strategy
+from .fed_abs import Federation
+from .topologies.node import Node, NodeKind
+from .topologies.topo import Topology

 if t.TYPE_CHECKING:
-    from .jobs.types import Result, AggrJobArgs
+    from .jobs.types import AggrJobArgs, Result

 Engine: t.TypeAlias = t.Any
diff --git a/flight/federation/jobs/result.py b/flight/federation/jobs/result.py
index 029ac9e..26a3f8b 100644
--- a/flight/federation/jobs/result.py
+++ b/flight/federation/jobs/result.py
@@ -3,22 +3,22 @@
 from proxystore.proxy import Proxy
 from pydantic.dataclasses import dataclass

-from flight.learning.module import RecordList
-from flight.strategies import Params
+from flight.federation.topologies.node import NodeID, NodeState
+from flight.learning.types import Params
+from flight.types import Record

-NodeID: t.TypeAlias = t.Hashable
-NodeState: t.TypeAlias = tuple
-
-# TODO: Remove config when all type definitions have been resolvedß
+# TODO: Remove config when all type definitions have been resolved
 @dataclass(config={"arbitrary_types_allowed": True})
 class Result:
     state: NodeState
     node_idx: NodeID
     params: Params
-    records: RecordList
+    records: list[Record]
     cache: dict[str, t.Any]


 AbstractResult: t.TypeAlias = Result | Proxy[Result]
-"""Helper type alias for a `Result` or a proxy to a `Result`."""
+"""
+Helper type alias for a `Result` or a proxy to a `Result`.
+"""
+""" diff --git a/flight/federation/jobs/types.py b/flight/federation/jobs/types.py index 69e70ce..b49049b 100644 --- a/flight/federation/jobs/types.py +++ b/flight/federation/jobs/types.py @@ -1,32 +1,32 @@ +from __future__ import annotations + import typing as t from concurrent.futures import Future -from dataclasses import dataclass - -import pydantic as pyd +from dataclasses import dataclass, field -from flight.federation.topologies.node import Node -from flight.learning.module import RecordList +from flight.federation.topologies.node import Node, NodeState, WorkerState +from flight.learning.datasets.loadable import DataLoadable +from flight.learning.modules.base import Record +from flight.learning.modules.torch import FlightModule +from flight.learning.types import Params if t.TYPE_CHECKING: - from flight.learning.datasets import DataLoadable - from flight.learning.module import FlightModule from flight.strategies.trainer import TrainerStrategy from flight.strategies.worker import WorkerStrategy - NodeState: t.TypeAlias = t.Any - Params: t.TypeAlias = t.Any - - -@pyd.dataclasses.dataclass -class Result(pyd.BaseModel): - node: Node = pyd.Field() - node_state: NodeState = pyd.Field() - params: Params = pyd.Field() - records: RecordList = pyd.Field() - cache: dict[str, t.Any] = pyd.Field(default_factory=dict, init=False) - -AggrJob: t.TypeAlias = t.Callable[[Node, Node], Result] +@dataclass +class Result: + node: Node + """The node that produced this result during a federation.""" + node_state: NodeState + """The current state of the node that returned a given result during a federation.""" + params: Params + """Parameters returned as part of a result from a single Node in a federation.""" + records: list[Record] = field(default_factory=list) + """List of records for model training/aggregation metrics.""" + extra: dict[str, t.Any] = field(default_factory=dict) + """Extra data recorded by a node during the runtime of its job.""" # class TrainJob(t.Protocol): @@ -57,6 +57,7 @@ class TrainJobArgs: node: Node parent: Node + node_state: WorkerState model: FlightModule data: DataLoadable worker_strategy: WorkerStrategy @@ -64,4 +65,7 @@ class TrainJobArgs: AggrJob: t.TypeAlias = t.Callable[[AggrJobArgs], Result] +"""Function signature for aggregation jobs.""" + TrainJob: t.TypeAlias = t.Callable[[TrainJobArgs], Result] +"""Function signature for loca training jobs.""" diff --git a/flight/federation/jobs/work.py b/flight/federation/jobs/work.py index 9a7e4d5..3a0d886 100644 --- a/flight/federation/jobs/work.py +++ b/flight/federation/jobs/work.py @@ -1,5 +1,6 @@ -import typing as t +from __future__ import annotations +import typing as t if t.TYPE_CHECKING: from flight.federation.jobs.types import Result, TrainJobArgs @@ -11,15 +12,17 @@ def default_training_job(args: TrainJobArgs) -> Result: from torch.utils.data import DataLoader - hparams = trainer_strategy.trainer_hparams() + from flight.learning.trainers.torch import TorchTrainer + + hparams = args.trainer_strategy.trainer_hparams() training_start = datetime.now() - state = worker_strategy.start_work() + state = args.worker_strategy.start_work() data = { - "train": data.load(node, "train"), - "valid": data.load(node, "valid"), + "train": args.data.load(args.node, "train"), + "valid": args.data.load(args.node, "valid"), } train_dataloader = DataLoader( @@ -27,24 +30,34 @@ def default_training_job(args: TrainJobArgs) -> Result: **{key: val for (key, val) in hparams if key.startswith("dataloader.train.")}, ) - trainer = 
diff --git a/flight/federation/topologies/__init__.py b/flight/federation/topologies/__init__.py
index e69de29..022a916 100644
--- a/flight/federation/topologies/__init__.py
+++ b/flight/federation/topologies/__init__.py
@@ -0,0 +1,3 @@
+from .node import Node
+
+__all__ = ["Node"]
diff --git a/flight/federation/topologies/node.py b/flight/federation/topologies/node.py
index 7c2d9b1..7e885df 100644
--- a/flight/federation/topologies/node.py
+++ b/flight/federation/topologies/node.py
@@ -1,11 +1,16 @@
 import typing as t
+from dataclasses import dataclass, field
 from enum import Enum
 from uuid import UUID

 import pydantic as pyd

-NodeID: t.TypeAlias = int | str
-"""ID of nodes in Flight topologies; can either be of type `int` or `str`."""
+from flight.learning.modules.base import Trainable
+
+NodeID: t.TypeAlias = t.Union[int, str]
+"""
+ID of nodes in Flight topologies; can either be of type `int` or `str`.
+"""


 class NodeKind(str, Enum):
@@ -27,12 +32,80 @@ class Node(pyd.BaseModel):
     [`Topology`][flight.federation.topologies.topo.Topology] class."""

     idx: NodeID
-    """The ID of the node."""
+    """
+    The ID of the node.
+    """
+
     kind: NodeKind
-    """The kind of Node---indicates its *role* in a federation."""
+    """
+    The kind of Node---indicates its *role* in a federation.
+    """
+
     globus_comp_id: UUID | None = pyd.Field(default=None)
-    """Globus Compute UUID for remote execution."""
+    """
+    Globus Compute UUID for remote execution.
+    """
+
     proxystore_id: UUID | None = pyd.Field(default=None)
-    """ProxyStore UUID for data transfer for remote execution with Globus Compute."""
-    extra: dict[str, t.Any] | None = pyd.Field(default=None)
-    """Any extra parameters users wish to give to Nodes (e.g., parameters or settings around system resource use)."""
+    """
+    ProxyStore UUID for data transfer for remote execution with Globus Compute.
+    """
+
+    extra: dict[str, t.Any] = pyd.Field(default_factory=dict)
+    """
+    Any extra parameters users wish to give to Nodes (e.g., parameters or settings
+    around system resource use).
+    """
+
+
+@dataclass
+class NodeState:
+    """
+    Dataclass that wraps the state of a node during a federation.
+
+    Args:
+        idx (NodeID): The ID of the node.
+
+    Raises:
+        - TypeError: This class cannot be directly instantiated. Only its children
+          classes can be instantiated.
+    """
+
+    idx: NodeID
+    cache: dict[str, t.Any] = field(
+        init=False, default_factory=dict, repr=False, hash=False
+    )
+
+    def __post_init__(self):
+        if type(self) is NodeState:
+            raise TypeError(
+                "Cannot instantiate an instance of `NodeState`. "
+                "Instead, you must instantiate instances of `WorkerState` or `AggrState`."
+            )
+
+
+@dataclass
+class AggrState(NodeState):
+    """
+    The state of an Aggregator node.
+
+    Args:
+        children (t.Iterable[Node]): Child nodes in the topology.
+        aggr_model (t.Optional[Trainable]): Aggregated model.
+    """
+
+    children: t.Iterable[Node]
+    aggr_model: t.Optional[Trainable] = None
+
+
+@dataclass
+class WorkerState(NodeState):
+    """
+    The state of a Worker node.
+
+    Args:
+        global_model (t.Optional[Trainable]): The most recent global model received from the worker's parent.
+        local_model (t.Optional[Trainable]): The model trained locally on the worker.
+    """
+
+    global_model: t.Optional[Trainable] = None
+    local_model: t.Optional[Trainable] = None
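
For orientation, a quick sketch of how the node types above compose (values are made up; `AggrState` and `WorkerState` are used exactly as defined in this patch):

```python
# Illustrative use of `Node`, `NodeKind`, and the node-state dataclasses above.
from flight.federation.topologies.node import AggrState, Node, NodeKind, WorkerState

workers = [
    Node(idx=1, kind=NodeKind.WORKER),
    Node(idx=2, kind="worker"),  # pydantic coerces the string into `NodeKind`
]

aggr_state = AggrState(0, children=workers)
worker_state = WorkerState(1)

aggr_state.cache["round"] = 0  # `cache` is free-form scratch space, not an init arg
assert isinstance(worker_state, WorkerState)
```
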
diff --git a/flight/federation/topologies/topo.py b/flight/federation/topologies/topo.py
index e2bed26..a424c12 100644
--- a/flight/federation/topologies/topo.py
+++ b/flight/federation/topologies/topo.py
@@ -21,7 +21,7 @@ def resolve_node_or_idx(node_or_idx: Node | NodeID) -> NodeID:
     if isinstance(node_or_idx, Node):
         return node_or_idx.idx
-    elif isinstance(node_or_idx, int | str):  # mypy doesn't accept just using `NodeID`
+    elif isinstance(node_or_idx, NodeID):  # type: ignore  # (mypy wants `int | str`)
         return node_or_idx
     else:
         raise ValueError("Argument `node_or_idx` must be of type `Node` or `NodeID`.")
@@ -105,10 +105,12 @@ def nodes(self, kind: NodeKind | str | None = None) -> t.Iterator[Node]:
         included in the returned iterator.

         Args:
-            kind:
+            kind (NodeKind | str | None): The kind of nodes to include in the iterator. If `None`,
+                then all nodes in the topology are included in the returned iterator.

         Raises:
-            - `ValueError` in the event the user provides an illegal `str` (see docs for `NodeKind` enum).
+            - `ValueError` in the event the user provides an illegal `str` for arg `kind`
+              (see docs for `NodeKind` enum).

         Examples:
             >>> nodes: list[Node] = ...
@@ -357,7 +359,9 @@ def from_yaml(cls, path: pathlib.Path | str) -> Topology:

 def validate(topo: Topology) -> None:
     """
-    Validates
+    Validates whether the provided topology is structurally legal or not. If the topology is not
+    structurally legal, then `TopologyException` is thrown. If no exception is thrown, then the
+    topology is legal.

     Args:
         topo (Topology): The `Topology` instance to validate.
@@ -366,9 +370,6 @@ def validate(topo: Topology) -> None:
     - `TopologyException` if an illegal topology has been defined based on Nodes,
       edges/links, and underlying graph. Exception messages will more explicitly state the
       exact issue. Refer to the docs for more information about the requirements for a legal Flight topology.
-
-    Returns:
-        `True` if the graph is legitimate; `False` otherwise.
     """
     nodes: t.Mapping[NodeID, Node] = topo._nodes  # noqa
     edges: list[NodeLink] = topo._edges  # noqa
diff --git a/flight/launch/__init__.py b/flight/launch/__init__.py
index e69de29..0a16576 100644
--- a/flight/launch/__init__.py
+++ b/flight/launch/__init__.py
@@ -0,0 +1,15 @@
+"""
+This module provides a Flight CLI for launching federations.
+
+```sh title="Basic use of Flight CLI."
+python3 -m flight.launch --config my-setup.yaml
+```
+
+```sh title="Configuring federation with Flight CLI args."
+python3 -m flight.launch \
+    --topology.kind hub-spoke \
+    --topology.num_workers 10 \
+    --dataset mnist \
+    --output.results 'my_results.csv'
+```
+"""
diff --git a/flight/learning/__init__.py b/flight/learning/__init__.py
index e69de29..bd05aee 100644
--- a/flight/learning/__init__.py
+++ b/flight/learning/__init__.py
@@ -0,0 +1,23 @@
+"""
+This module provides the standard interfaces, protocols, and classes for defining and training neural networks in Flight.
+It also provides classes and utility functions for loading (and setting up) datasets for Flight federations.
+
+Flight currently provides support for the following deep learning frameworks:
+
+- PyTorch
+- PyTorch Lightning
+- Scikit-Learn (namely, `MLPRegressor` and `MLPClassifier`)
+
+## Trainable Models
+...
+
+## Trainers
+...
+
+| Trainer          | Module |
+| :--------------- | :----- |
+| TorchTrainer     | ...    |
+| ScikitTrainer    | ...    |
+| LightningTrainer | ...    |
+| CustomTrainer    | ...    |
+"""
diff --git a/flight/learning/datasets.py b/flight/learning/datasets.py
deleted file mode 100644
index 6d34807..0000000
--- a/flight/learning/datasets.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import typing as t
-
-from flight.federation.topologies.node import Node
-
-
-class DataLoadable(t.Protocol):
-    def load(self, node: Node):
-        pass
diff --git a/flight/learning/datasets/__init__.py b/flight/learning/datasets/__init__.py
new file mode 100644
index 0000000..7c7f780
--- /dev/null
+++ b/flight/learning/datasets/__init__.py
@@ -0,0 +1,3 @@
+from .loadable import DataLoadable
+
+__all__ = ["DataLoadable"]
diff --git a/flight/learning/datasets/base.py b/flight/learning/datasets/base.py
new file mode 100644
index 0000000..e69de29
diff --git a/flight/learning/datasets/loadable.py b/flight/learning/datasets/loadable.py
new file mode 100644
index 0000000..0ddebab
--- /dev/null
+++ b/flight/learning/datasets/loadable.py
@@ -0,0 +1,25 @@
+import typing as t
+
+from flight.federation.topologies.node import Node
+
+
+class DataLoadable(t.Protocol):
+    """
+    The `DataLoadable` is a protocol that defines the key functionality necessary to load data
+    into a federation with Flight.
+
+    Data in federated learning are naturally decentralized across multiple nodes/endpoints. In
+    real-world settings, we do not need to model this decentralization ourselves. But in simulated
+    settings for rapid prototyping, we do need some way to break a central dataset up into a
+    simulated decentralized/federated data distribution.
+
+    A `DataLoadable` object will need to support two main use cases:
+
+    1. simulated workflows where data are centrally located on one machine and we split them into
+       separate subsets for training across multiple nodes/endpoints,
+    2. real-world workflows where data already exist on the nodes/endpoints and we only need to
+       **load them from disk**.
+    """
+
+    def load(self, node: Node, mode: t.Literal["train", "test", "validation"]):
+        pass
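
To make the first (simulated) use case concrete, here is a minimal sketch of a `DataLoadable` implementation. The `MySimulatedData` class and its round-robin split are hypothetical illustrations, not part of this patch:

```python
# A minimal simulated `DataLoadable`: one central in-memory dataset is split
# into per-node subsets keyed on `node.idx`.
import typing as t

import torch
from torch.utils.data import Subset, TensorDataset

from flight.federation.topologies.node import Node


class MySimulatedData:
    """Hypothetical loadable that splits a central dataset across nodes."""

    def __init__(self, num_nodes: int = 2):
        self.num_nodes = num_nodes
        self.data = TensorDataset(torch.randn(100, 10), torch.randn(100, 1))

    def load(self, node: Node, mode: t.Literal["train", "test", "validation"]):
        # Naive round-robin split; assumes integer-like node IDs. A real
        # loadable would also key the returned subset on `mode`.
        indices = [
            i for i in range(len(self.data)) if i % self.num_nodes == int(node.idx)
        ]
        return Subset(self.data, indices)
```
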
+ """ + + def load(self, node: Node, mode: t.Literal["train", "test", "validation"]): + pass diff --git a/flight/learning/datasets/torch.py b/flight/learning/datasets/torch.py new file mode 100644 index 0000000..ccb1811 --- /dev/null +++ b/flight/learning/datasets/torch.py @@ -0,0 +1,6 @@ +from __future__ import annotations + +import typing as t + +if t.TYPE_CHECKING: + pass diff --git a/flight/learning/metrics.py b/flight/learning/metrics.py index cc1a8b0..ffd95f7 100644 --- a/flight/learning/metrics.py +++ b/flight/learning/metrics.py @@ -1,17 +1,71 @@ +from __future__ import annotations + +import datetime as dt +import sys import typing as t +DATE_RECORD_KEY = "date" + +# if sys.version_info >= (3, 10): +# from typing import TypeAlias +# else: +# from typing_extensions import TypeAlias + + +if t.TYPE_CHECKING: + from types import TracebackType -class MetricLogger(t.Protocol): - def log(self): + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + + +class RecordLogger(t.Protocol): + def log(self, **kwargs: t.Mapping[str, t.Any]) -> None: pass - def log_dict(self): + +class NullLogger: + def __init__(self): pass + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ): + return + + def log(self, **kwargs: t.Mapping[str, t.Any]) -> None: + """Logs and records nothing.""" + return + class InMemoryRecordLogger: - pass + def __init__(self): + self.records = [] + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + exc_traceback: TracebackType | None, + ): + pass + + def log(self, **kwargs: t.Mapping[str, t.Any]) -> None: + records = {name: value for name, value in kwargs.items()} + records.update({DATE_RECORD_KEY: dt.datetime.now()}) + self.records.append(records) -class DiscRecordLogger: +class JsonRecordLogger: pass diff --git a/flight/learning/module.py b/flight/learning/module.py deleted file mode 100644 index 33ffe8c..0000000 --- a/flight/learning/module.py +++ /dev/null @@ -1,31 +0,0 @@ -import typing as t - -import torch - -FlightDataset: t.TypeAlias = t.Any - -Record: t.TypeAlias = t.Dict[str, t.Any] -RecordList: t.TypeAlias = t.List[Record] -Params: t.TypeAlias = t.Mapping[str, torch.Tensor] - - -class Trainable(t.Protocol): - module: FlightModule - - def get_params(self, include_buffers: bool = False) -> Params: - pass - - def set_params(self, params: Params) -> None: - pass - - def train(self, data: FlightDataset) -> RecordList: - pass - - def test(self): - pass - - def evaluate(self): - pass - - def predict(self): - pass diff --git a/flight/learning/modules/__init__.py b/flight/learning/modules/__init__.py new file mode 100644 index 0000000..772456b --- /dev/null +++ b/flight/learning/modules/__init__.py @@ -0,0 +1,3 @@ +from .base import Trainable + +__all__ = ["Trainable"] diff --git a/flight/learning/modules/base.py b/flight/learning/modules/base.py new file mode 100644 index 0000000..23f4864 --- /dev/null +++ b/flight/learning/modules/base.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import typing as t +from collections import OrderedDict + +from sklearn.neural_network import MLPClassifier, MLPRegressor + +FlightDataset: t.TypeAlias = t.Any +""" +... 
+""" + +SciKitModule: t.TypeAlias = t.Union[MLPClassifier, MLPRegressor] +""" +Utility type alias for any MLP classifier or regressor implemented in Scikit-Learn. +""" + + +Record: t.TypeAlias = t.Dict[str, t.Any] +""" +... +""" + +RecordList: t.TypeAlias = t.List[Record] +""" +... +""" + +if t.TYPE_CHECKING: + from flight.learning.types import Params + + +@t.runtime_checkable +class Trainable(t.Protocol): + def get_params(self, include_state: bool = False) -> Params: + pass + + def set_params(self, params: Params) -> None: + pass + + +class ScikitTrainable: + WEIGHT_KEY_PREFIX = "weight" + BIAS_KEY_PREFIX = "bias" + + def __init__(self, module: SciKitModule): + self.module = module + + def get_params(self) -> Params: + """ + + Throws: + - ValueError: Occurs when the `len()` of the coefficient and intercept vectors (i.e., `module.coefS_` and + `module.intercepts_`) are not equal. + + Returns: + + """ + num_layers = len(self.module.coefs_) + if num_layers != len(self.module.intercepts_): + raise ValueError( + "ScikitTrainable - Inconsistent number of layers between coefficients/weights and intercepts/biases." + ) + + params = [] + for i in range(num_layers): + params.append((f"{self.WEIGHT_KEY_PREFIX}_{i}", self.module.coefs_[i])) + params.append((f"{self.BIAS_KEY_PREFIX}_{i}", self.module.intercepts_[i])) + + return OrderedDict(params) + + def set_params(self, params: Params): + param_keys = list(params.keys()) + layer_nums = map(lambda txt: int(txt.split("_")[-1]), param_keys) + layer_nums = set(layer_nums) + num_layers = max(layer_nums) + 1 + + weights = [] + biases = [] + for i in range(num_layers): + w_i = params[f"{self.WEIGHT_KEY_PREFIX}_{i}"] + b_i = params[f"{self.BIAS_KEY_PREFIX}_{i}"] + weights.append(w_i) + biases.append(b_i) + + self.module.coefs_ = weights + self.module.intercepts_ = biases diff --git a/flight/learning/modules/torch.py b/flight/learning/modules/torch.py new file mode 100644 index 0000000..43a11e3 --- /dev/null +++ b/flight/learning/modules/torch.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import abc +import typing as t +from collections import OrderedDict + +import torch +from torch import nn +from torch.utils.data import DataLoader + +if t.TYPE_CHECKING: + from flight.federation.topologies import Node + from flight.learning.types import LocalStepOutput, Params + + +_DEFAULT_INCLUDE_STATE = False + + +# class TorchTrainable: +# def __init__(self, module: torch.nn.Module, include_state: bool = False) -> None: +# self.module = module +# self.include_state = include_state + + +class TorchDataModule(abc.ABC): + # def __init__(self, *args, **kwargs): + # pass + + @abc.abstractmethod + def train_data(self, node: Node | None = None) -> DataLoader: + pass + + # noinspection PyMethodMayBeStatic + def test_data(self, node: Node | None = None) -> DataLoader | None: + return None + + # noinspection PyMethodMayBeStatic + def valid_data(self, node: Node | None = None) -> DataLoader | None: + return None + + +class FlightModule(nn.Module, abc.ABC): + """ + Wrapper class for a PyTorch model (i.e., `torch.nn.Module`). + + Based on PyTorch Lightning's + [LightningModule](https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/core/module.html#LightningModule). 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.include_state = kwargs.get("include_state", _DEFAULT_INCLUDE_STATE) + + @abc.abstractmethod + def training_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: + """ + Hello + + Args: + *args: + **kwargs: + + Returns: + + """ + + @abc.abstractmethod + def configure_optimizers(self) -> torch.optim.Optimizer: + """ + Helo + + Returns: + + """ + + def predict_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: + """ + Hello + + Args: + *args: + **kwargs: + + Returns: + + """ + raise NotImplementedError() + + def test_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: + """ + Hello + + Args: + *args: + **kwargs: + + Returns: + + """ + raise NotImplementedError() + + def validation_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: + """ + Hello + + Args: + *args: + **kwargs: + + Returns: + + """ + raise NotImplementedError() + + def get_params(self) -> Params: + params = self.module.state_dict() + if not self.include_state: + params = OrderedDict( + [(name, value.data) for (name, value) in params if value.requires_grad] + ) + + return params + + def set_params(self, params: Params) -> None: + if self.include_state: + self.module.load_state_dict( + params, + strict=True, + assign=False, + ) + else: + self.module.load_state_dict( + params, + strict=False, + assign=False, + ) diff --git a/flight/learning/modules/torch_depr.py b/flight/learning/modules/torch_depr.py new file mode 100644 index 0000000..c2d8bed --- /dev/null +++ b/flight/learning/modules/torch_depr.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing as t +from collections import OrderedDict + +if t.TYPE_CHECKING: + from flight.learning.types import Params + + +class TorchTrainableMixins: + def get_params(self, include_state: bool = False) -> Params: + if include_state: + params = {name: value.clone() for name, value in self.state_dict().items()} + else: + params = { + name: param.data.clone() for name, param in self.named_parameters() + } + return OrderedDict(params) + + def set_params(self, params: Params) -> None: + state_dict = self.state_dict() + for name in state_dict: + state_dict[name] = params[name].clone() diff --git a/flight/learning/torch.py b/flight/learning/torch.py deleted file mode 100644 index f9d8220..0000000 --- a/flight/learning/torch.py +++ /dev/null @@ -1,75 +0,0 @@ -import abc -import typing as t - -import torch -from torch import nn - -if t.TYPE_CHECKING: - from .types import LocalStepOutput - - -class FlightModule(nn.Module, abc.ABC): - """ - Wrapper class for a PyTorch model (i.e., `torch.nn.Module`). - - Based on PyTorch Lightning's - [LightningModule](https://lightning.ai/docs/pytorch/stable/_modules/lightning/pytorch/core/module.html#LightningModule). 
- """ - - @abc.abstractmethod - def training_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: - """ - Hello - - Args: - *args: - **kwargs: - - Returns: - - """ - - @abc.abstractmethod - def configure_optimizers(self) -> torch.optim.Optimizer: - """ - Helo - - Returns: - - """ - - def predict_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: - """ - Hello - - Args: - *args: - **kwargs: - - Returns: - - """ - - def test_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: - """ - Hello - - Args: - *args: - **kwargs: - - Returns: - - """ - - def validation_step(self, *args: t.Any, **kwargs: t.Any) -> LocalStepOutput: - """ - Hello - - Args: - *args: - **kwargs: - - Returns: - - """ diff --git a/flight/learning/trainers/__init__.py b/flight/learning/trainers/__init__.py new file mode 100644 index 0000000..35013d0 --- /dev/null +++ b/flight/learning/trainers/__init__.py @@ -0,0 +1,8 @@ +""" +This module defines the protocol for [`Trainer`][flight.learning.trainers.base.Trainer] objects. + +A `Trainer` is responsible for facilitating the training/fitting, testing, and evaluation of +[`Trainable`][flight.learning.modules.base.Trainable] models. Users are able to use one of the +`Trainer` implementations provided by Flight. Additionally, for more specific use cases, users can +implement their own `Trainer` object by simply implementing the `Trainer` protocol. +""" diff --git a/flight/learning/trainers/base.py b/flight/learning/trainers/base.py new file mode 100644 index 0000000..0ea5865 --- /dev/null +++ b/flight/learning/trainers/base.py @@ -0,0 +1,50 @@ +import typing as t + +from flight.learning.datasets import DataLoadable +from flight.learning.modules import Trainable + + +class Trainer(t.Protocol): + """ + Object class that is responsible for training `Trainable` objects. 
+ """ + + def fit(self, model: Trainable, data: DataLoadable, *args, **kwargs): + """ + fit + + Args: + model (Trainable): + data (DataLoadable): + *args: + **kwargs: + + Returns: + + """ + + def test(self, model: Trainable, *args, **kwargs): + """ + test + + Args: + model: + *args: + **kwargs: + + Returns: + + """ + + def validate(self, model: Trainable, data: DataLoadable, *args, **kwargs): + """ + evaluate + + Args: + model: + *args: + **kwargs: + + Returns: + + """ diff --git a/flight/learning/trainers/lightning.py b/flight/learning/trainers/lightning.py new file mode 100644 index 0000000..d1cbbf8 --- /dev/null +++ b/flight/learning/trainers/lightning.py @@ -0,0 +1,10 @@ +import lightning.pytorch as L + + +class LightningTrainer: + def __init__(self, *args, **kwargs) -> None: + self.trainer = L.Trainer(*args, **kwargs) + + def fit(self, *args, **kwargs): + results = self.trainer(*args, **kwargs) + return results diff --git a/flight/learning/trainers/scikit.py b/flight/learning/trainers/scikit.py new file mode 100644 index 0000000..481a567 --- /dev/null +++ b/flight/learning/trainers/scikit.py @@ -0,0 +1,17 @@ +from flight.learning.datasets import DataLoadable +from flight.learning.modules.base import SciKitModule + + +class ScikitTrainer: + def __init__(self, partial: bool = True): + self.partial = partial + + def fit(self, model: SciKitModule, data: DataLoadable): + inputs, targets = data.load() + model.fit(inputs, targets) + + def test(self): + pass + + def validate(self): + pass diff --git a/flight/learning/trainers/torch.py b/flight/learning/trainers/torch.py new file mode 100644 index 0000000..ad26fae --- /dev/null +++ b/flight/learning/trainers/torch.py @@ -0,0 +1,149 @@ +from __future__ import annotations + +import typing as t + +import torch +from torch.utils.data import DataLoader + +from ..modules.torch import FlightModule, TorchDataModule + +if t.TYPE_CHECKING: + from ...federation.topologies.node import Node, WorkerState + from ...strategies.trainer import TrainerStrategy + + EVAL_DATALOADERS: t.TypeAlias = t.Any # + TRAIN_DATALOADERS: t.TypeAlias = t.Any # + _PATH: t.TypeAlias = t.Any # + LightningDataModule: t.TypeAlias = t.Any # + + +class TorchTrainer: + def __init__(self, node: Node, strategy: TrainerStrategy, max_epochs: int): + self.node = node + self.strategy = strategy + self.max_epochs = max_epochs + self._device = torch.device(node.extra.get("device", "cpu")) + # self.logger = + + def fit( + self, + node_state: WorkerState, + model: FlightModule, + data: TorchDataModule, + validate_every_n_epochs: int = 1, + # train_dataloaders: TRAIN_DATALOADERS | LightningDataModule | None = None, + # val_dataloaders: EVAL_DATALOADERS | None = None, + # datamodule: LightningDataModule | None = None, + ckpt_path: _PATH | None = None, + ): + """ + + Args: + node_state: + model: + data: + validate_every_n_epochs: + ckpt_path: + + Raises: + - ValueError: Thrown when illegal values are given to arguments. + + Returns: + + """ + # TODO: Run the argument validation in a separate utility function to keep this function light. + if validate_every_n_epochs < 1: + raise ValueError("Illegal value for argument `validate_every_n_epochs`.") + + model.to(self._device) + + results = [] + + train_dataloader = data.train_data(self.node) + valid_dataloader = data.valid_data(self.node) + + if not isinstance(train_dataloader, DataLoader): + raise TypeError( + "Method for argument `data.train_data(.)` must return a `DataLoader`." 
diff --git a/flight/learning/types.py b/flight/learning/types.py
index bb52985..f323152 100644
--- a/flight/learning/types.py
+++ b/flight/learning/types.py
@@ -1,5 +1,6 @@
 import typing as t

-import torch
+from torch import Tensor

-LocalStepOutput: t.TypeAlias = t.Optional[torch.Tensor | t.Mapping[str, t.Any]]
+LocalStepOutput: t.TypeAlias = t.Optional[Tensor | t.Mapping[str, t.Any]]
+Params: t.TypeAlias = t.Mapping[str, Tensor]
diff --git a/flight/strategies/__init__.py b/flight/strategies/__init__.py
index ef5f39f..131036b 100644
--- a/flight/strategies/__init__.py
+++ b/flight/strategies/__init__.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import typing as t

 import torch
@@ -8,9 +10,8 @@ from flight.strategies.trainer import TrainerStrategy
 from flight.strategies.worker import WorkerStrategy

-Loss: t.TypeAlias = torch.Tensor
-Params: t.TypeAlias = dict[str, torch.Tensor]
-NodeState: t.TypeAlias = t.Any
+if t.TYPE_CHECKING:
+    Loss: t.TypeAlias = torch.Tensor


 def load_strategy(strategy_name: str, **kwargs) -> Strategy:
diff --git a/flight/strategies/commons/averaging.py b/flight/strategies/commons/averaging.py
index eb44c44..1d159eb 100644
--- a/flight/strategies/commons/averaging.py
+++ b/flight/strategies/commons/averaging.py
@@ -5,12 +5,12 @@
 import numpy
 import torch

+from flight.federation.topologies.node import NodeID
+from flight.learning.types import Params
+
 if t.TYPE_CHECKING:
     from collections.abc import Mapping

-    NodeID: t.TypeAlias = t.Any
-    Params: t.TypeAlias = t.Any
-

 def average_state_dicts(
     state_dicts: Mapping[NodeID, Params], weights: Mapping[NodeID, float] | None = None
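
A hedged usage sketch for `average_state_dicts`, assuming it computes a (weighted) mean over the per-node parameter mappings; the function body is not shown in this patch, so the expected values below are based on that assumption:

```python
# Averaging two workers' parameters (FedAvg-style); shapes are arbitrary.
import torch

from flight.strategies.commons.averaging import average_state_dicts

params = {
    1: {"layer.weight": torch.ones(2, 2)},
    2: {"layer.weight": torch.zeros(2, 2)},
}

avg = average_state_dicts(params)  # plain mean -> entries of 0.5, if unweighted
weighted = average_state_dicts(params, weights={1: 3.0, 2: 1.0})  # skewed to node 1
```
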
diff --git a/flight/strategies/trainer.py b/flight/strategies/trainer.py
index 236b571..c42bb93 100644
--- a/flight/strategies/trainer.py
+++ b/flight/strategies/trainer.py
@@ -3,7 +3,6 @@
 import typing as t

 if t.TYPE_CHECKING:
-    from flight.strategies import Loss, NodeState
diff --git a/flight/types.py b/flight/types.py
new file mode 100644
index 0000000..3767f1f
--- /dev/null
+++ b/flight/types.py
@@ -0,0 +1,5 @@
+from __future__ import annotations
+
+import typing as t
+
+Record: t.TypeAlias = dict[str, t.Any]
diff --git a/flox/learn/trainer.py b/flox/learn/trainer.py
index f4b19c9..14c2fd1 100644
--- a/flox/learn/trainer.py
+++ b/flox/learn/trainer.py
@@ -12,8 +12,8 @@
 from flox.learn.logger import ModelLogger

 if t.TYPE_CHECKING:
-    from flox.federation.topologies import WorkerState
     from flox.strategies import TrainerStrategy
+    from flox.topos.types import WorkerState


 class Trainer:
@@ -49,7 +49,7 @@ def fit(
         model.to(self.device)
         with torch.set_grad_enabled(True):
             for epoch in range(num_epochs):
-                _ = self._epoch(
+                avg_loss = self._epoch(
                     epoch,
                     model,
                     node_state,
@@ -72,7 +72,7 @@ def fit(
                     self.validate(model, valid_dataloader, epoch, valid_ckpt_path)

         # model.to("cpu")
-        return self.logger.dataframe()
+        return self.logger.to_pandas()

     def test(
         self,
@@ -94,7 +94,7 @@ def validate(
         model.eval()
         with torch.no_grad():
             for batch_idx, batch in enumerate(valid_dataloader):
-                _ = model.validation_step(batch, batch_idx)
+                loss = model.validation_step(batch, batch_idx)
                 # self.logger.log_dict(
                 #     {
                 #         "valid/loss": loss.item(),
@@ -114,21 +114,6 @@ def _epoch(
         valid_ckpt_path: Path | str | None = None,
         valid_dataloader: DataLoader | None = None,
     ):
-        """
-
-        Args:
-            epoch_index:
-            model:
-            node_state:
-            optimizer:
-            train_dataloader:
-            valid_ckpt_path:
-            valid_dataloader:
-
-        Returns:
-            The average loss of the epoch.
-        """
-
         def log_condition(batch_idx: int):
             conditions = [
                 batch_idx % self.log_every_n_batches == self.log_every_n_batches - 1,
- """ - def log_condition(batch_idx: int): conditions = [ batch_idx % self.log_every_n_batches == self.log_every_n_batches - 1, diff --git a/flox_examples/aggr_comparison.py b/flox_examples/aggr_comparison.py index 2ccae9a..70243e9 100644 --- a/flox_examples/aggr_comparison.py +++ b/flox_examples/aggr_comparison.py @@ -13,19 +13,19 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - import flox import os + import pandas as pd import torchvision.transforms as transforms - + from models import * from torchvision.datasets import FashionMNIST + import flox + from flox import Topology + from flox.data import FloxDataset from flox.data.utils import federated_split from flox.federation.topologies import hierarchical_topology from flox.strategies import load_strategy - from flox import Topology - from flox.data import FloxDataset - from models import * def train_experiment( diff --git a/flox_examples/async_demo.py b/flox_examples/async_demo.py index c438e1c..c52aae1 100644 --- a/flox_examples/async_demo.py +++ b/flox_examples/async_demo.py @@ -15,18 +15,18 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - import flox import os + import pandas as pd import torchvision.transforms as transforms - + from models import * from torchvision.datasets import FashionMNIST - from flox.data.utils import federated_split - from flox.strategies import load_strategy + import flox from flox import Topology from flox.data import FloxDataset - from models import * + from flox.data.utils import federated_split + from flox.strategies import load_strategy def train_experiment( diff --git a/flox_examples/async_test_demo.py b/flox_examples/async_test_demo.py index 01ec36b..2c6e740 100644 --- a/flox_examples/async_test_demo.py +++ b/flox_examples/async_test_demo.py @@ -7,16 +7,16 @@ import pandas as pd import torch import torchmetrics -from flox.data import federated_split from torch import nn from torch.nn import functional as F from torchvision import transforms from torchvision.datasets import FashionMNIST import flox +from flox.data import federated_split +from flox.federation.topologies import two_tier_topology from flox.learn import FloxModule from flox.strategies import load_strategy -from flox.federation.topologies import two_tier_topology logging.basicConfig( format="(%(levelname)s - %(asctime)s) ❯ %(message)s", level=logging.INFO diff --git a/flox_examples/simple_hierarchical.py b/flox_examples/simple_hierarchical.py index 80f4aa2..958c44d 100644 --- a/flox_examples/simple_hierarchical.py +++ b/flox_examples/simple_hierarchical.py @@ -7,17 +7,17 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - import flox import os - import pandas as pd + import pandas as pd + from models import * from torchvision.datasets import FashionMNIST from torchvision.transforms import ToTensor + import flox from flox.data.utils import federated_split from flox.federation.topologies import create_hier_flock from flox.strategies import load_strategy - from models import * def main(): diff --git a/flox_examples/test.py b/flox_examples/test.py index 894c24e..111d368 100644 --- a/flox_examples/test.py +++ b/flox_examples/test.py @@ -2,22 +2,22 @@ from pathlib import Path import torch -from flox.data.utils import federated_split from torch import nn from torch.nn import functional as F from torchvision import transforms from torchvision.datasets import FashionMNIST import flox -from flox.learn import FloxModule +from flox.data.utils import federated_split from flox.federation.topologies 
diff --git a/flox_examples/topo_experiment.py b/flox_examples/topo_experiment.py
index 062a29e..238f56a 100644
--- a/flox_examples/topo_experiment.py
+++ b/flox_examples/topo_experiment.py
@@ -12,15 +12,15 @@
 with warnings.catch_warnings():
     warnings.simplefilter("ignore")

-    import flox
     import os
-    import torchvision.transforms as transforms

+    import torchvision.transforms as transforms
+    from models import *
     from torchvision.datasets import FashionMNIST

-    from flox.federation.topologies import hierarchical_topology
+    import flox
     from flox.data import federated_split
-    from models import *
+    from flox.federation.topologies import hierarchical_topology


 def load_data() -> FashionMNIST:
diff --git a/flox_examples/yadu_async_test.py b/flox_examples/yadu_async_test.py
index 114456c..72fdcff 100644
--- a/flox_examples/yadu_async_test.py
+++ b/flox_examples/yadu_async_test.py
@@ -11,7 +11,6 @@
 from flox.federation.topologies import two_tier_topology
 from flox.strategies import load_strategy

-
 # from flox_classes import Net, KyleNet
diff --git a/mkdocs.yml b/mkdocs.yml
index c7251dd..bfa10e7 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -1,14 +1,13 @@
-site_name: FLoX
+site_name: Flight
 #site_url: https://flox.dev/

-repo_name: nathaniel-hudson/FLoX
-repo_url: https://github.com/nathaniel-hudson/FLoX
+repo_name: h-flox/FLoX
+repo_url: https://github.com/h-flox/FLoX

 ########################################################################################################################
 # NAVIGATION
 ########################################################################################################################
-
 nav:
   - Home:
       - index.md
@@ -27,19 +26,18 @@ nav:
   - Publications: publications/index.md
   - Docs: docs/

+copyright:
+  Copyright © 2022 - 2024 Globus Labs, University of Chicago.

 ########################################################################################################################
 ########################################################################################################################
 ########################################################################################################################

-copyright:
-  Copyright © 2022 - 2024 Globus Labs, University of Chicago.
-
 watch:
-  - mkdocs.yml
-  - README.md
   - docs/
   - flight/
+  - mkdocs.yml
+  - README.md

 extra:
   social:
@@ -93,7 +91,7 @@ theme:
         toggle:
           icon: material/brightness-4
           name: Switch to light mode
-      primary: black  # indigo
+      primary: indigo  # black

 accent: amber
diff --git a/tests/federation/topologies/test_node.py b/tests/federation/topologies/test_node.py
index 22b7845..89ba004 100644
--- a/tests/federation/topologies/test_node.py
+++ b/tests/federation/topologies/test_node.py
@@ -1,31 +1,71 @@
+import pytest
 from pydantic import ValidationError

-from flight.federation.topologies.node import Node, NodeKind
+from flight.federation.topologies.node import (
+    AggrState,
+    Node,
+    NodeKind,
+    NodeState,
+    WorkerState,
+)


-def test_node_inits():
+class TestNodeInits:
+    @staticmethod
     def mini_test(should_work: bool, **kwargs):
-        try:
-            Node(**kwargs)
-            assert should_work
-        except ValidationError:
-            assert not should_work
-
-    mini_test(
-        True,
-        idx=123,
-        kind=NodeKind.WORKER,
-        extra={"battery_cap": 10},
-    )
-
-    mini_test(
-        False,
-        extra={"battery_cap": 10},
-    )
-
-    mini_test(
-        False,
-        idx=10,
-        kind="hello",
-        extra={"battery_cap": 10},
-    )
+        if should_work:
+            node = Node(**kwargs)
+            assert isinstance(node, Node)
+        else:
+            with pytest.raises(ValidationError):
+                Node(**kwargs)
+
+    def test_valid_inits(self):
+        should_init = True
+        TestNodeInits.mini_test(
+            should_init,
+            idx=123,
+            kind=NodeKind.WORKER,
+            extra={"battery_cap": 10},
+        )
+
+        TestNodeInits.mini_test(
+            should_init,
+            idx=123,
+            kind="worker",
+            extra={"battery_cap": 10},
+        )
+
+    def test_invalid_inits(self):
+        should_init = False
+        TestNodeInits.mini_test(
+            should_init,
+            extra={"battery_cap": 10},
+        )
+
+        TestNodeInits.mini_test(
+            should_init,
+            idx=10,
+            kind="hello",
+            extra={"battery_cap": 10},
+        )
+
+
+class TestNodeState:
+    def test_state_init(self):
+        with pytest.raises(TypeError):
+            NodeState(1)
+
+        children = [Node(idx=1, kind=NodeKind.WORKER), Node(idx=2, kind=NodeKind.AGGR)]
+        state = AggrState(1, children)
+        assert isinstance(state, NodeState)
+        assert isinstance(state, AggrState)
+
+        state = WorkerState(1)
+        assert isinstance(state, NodeState)
+        assert isinstance(state, WorkerState)
diff --git a/tests/learning/test_model.py b/tests/learning/test_model.py
index 65dae86..e59d9e5 100644
--- a/tests/learning/test_model.py
+++ b/tests/learning/test_model.py
@@ -1,7 +1,8 @@
 import pytest
 import torch

-from flight.learning.torch import FlightModule
+from flight.learning.modules.base import Trainable
+from flight.learning.modules.torch import FlightModule


 @pytest.fixture
@@ -28,7 +29,7 @@ def configure_optimizers(self):

 @pytest.fixture
 def invalid_module():
-    class TestModule(FlightModule): # noqa
+    class TestModule(FlightModule):  # noqa
         def __init__(self):
             super().__init__()
             self.model = torch.nn.Sequential(
@@ -43,15 +44,25 @@ def forward(self, x):


 class TestModelInit:
-    def test_1(self, valid_module):
+    def test_valid_model_init(self, valid_module):
         model = valid_module()
         assert isinstance(model, FlightModule)
+        assert isinstance(model, Trainable)
         assert isinstance(model, torch.nn.Module)

         x = torch.tensor([[1.0]])
         y = model(x)
         assert isinstance(y, torch.Tensor)

-    def test_2(self, invalid_module):
+    def test_invalid_model_init(self, invalid_module):
         with pytest.raises(TypeError):
             invalid_module()
+
+    def test_model_get_params(self, valid_module):
+        model = valid_module()
+        try:
+            _ = model.get_params(include_state=False)
+            params = model.get_params(include_state=True)
+            model.set_params(params)
+        except Exception as exc:
+            pytest.fail(f"Unexpected error/exception: {exc}")
diff --git a/tests/learning/test_trainer.py b/tests/learning/test_trainer.py
new file mode 100644
index 0000000..5316a3b
--- /dev/null
+++ b/tests/learning/test_trainer.py
@@ -0,0 +1,135 @@
+import pytest
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, TensorDataset
+
+from flight.federation.topologies.node import Node, WorkerState
+from flight.learning.datasets import DataLoadable
+from flight.learning.modules.torch import FlightModule, TorchDataModule
+from flight.learning.trainers.torch import TorchTrainer
+from flight.strategies.base import DefaultTrainerStrategy
+
+NUM_FEATURES = 10
+
+
+@pytest.fixture
+def node() -> Node:
+    node = Node(idx=0, kind="worker")
+    return node
+
+
+@pytest.fixture
+def worker_state() -> WorkerState:
+    return WorkerState(0, None, None)
+
+
+@pytest.fixture
+def module_cls() -> type[FlightModule]:
+    class MyModule(FlightModule):
+        def __init__(self):
+            super().__init__()
+            self.model = torch.nn.Sequential(
+                torch.nn.Linear(NUM_FEATURES, 100),
+                torch.nn.Linear(100, 1),
+            )
+
+        def forward(self, x):
+            return self.model(x)
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=1e-3)
+
+        def training_step(self, batch, batch_nb):
+            inputs, targets = batch
+            preds = self(inputs)
+            return F.l1_loss(preds, targets)
+
+    return MyModule
+
+
+@pytest.fixture
+def data_cls() -> type[DataLoadable]:
+    class MyDataLoadable(DataLoadable):
+        def __init__(self):
+            torch.manual_seed(0)
+            n = 1000  # total number of data samples
+            f = NUM_FEATURES  # number of features
+            n_train, n_test, n_valid = n * 0.8, n * 0.1, n * 0.1
+            n_train, n_test, n_valid = int(n_train), int(n_test), int(n_valid)
+            self.train = TensorDataset(
+                torch.randn((n_train, f)),
+                torch.randn((n_train, 1)),
+            )
+            self.test = TensorDataset(
+                torch.randn((n_test, f)),
+                torch.randn((n_test, 1)),
+            )
+            self.valid = TensorDataset(
+                torch.randn((n_valid, f)),
+                torch.randn((n_valid, 1)),
+            )
+
+        def load(self, node: Node, mode: str):
+            match mode:
+                case "train":
+                    return self.train
+                case "test":
+                    return self.test
+                case "valid" | "validation":
+                    return self.valid
+                case _:
+                    raise ValueError("Illegal `mode` literal value.")
+
+    return MyDataLoadable
+
+
+class TestTrainer:
+    def _default_torch_trainer(self, node, worker_state, module_cls, data_cls):
+        """Tests a basic setup of using the `TorchTrainer` class for PyTorch-based models."""
+        model = module_cls()
+        data = data_cls()
+        trainer = TorchTrainer(node, DefaultTrainerStrategy(), 1)
+        assert isinstance(model, FlightModule)
+        assert isinstance(trainer, TorchTrainer)
+
+        trainer.fit(worker_state, model, data)
+
+    def test_node_device_specifier(self, node):
+        """Confirms that the trainer's device is taken from the node's `extra` config."""
+        trainer = TorchTrainer(node, DefaultTrainerStrategy(), 1)
+        assert str(trainer._device) == "cpu"
+
+        node.extra["device"] = "cuda"
+        trainer = TorchTrainer(node, DefaultTrainerStrategy(), 1)
+        assert str(trainer._device) == "cuda"
+
+        node.extra["device"] = "mps"
+        trainer = TorchTrainer(node, DefaultTrainerStrategy(), 1)
+        assert str(trainer._device) == "mps"
+
+    def test_temp(self, node):
+        class Foo(TorchDataModule):
+            def __init__(self):
+                super().__init__()
+                torch.manual_seed(0)
+                n = 1000  # total number of data samples
+                f = NUM_FEATURES  # number of features
+                n_train, n_test, n_valid = n * 0.8, n * 0.1, n * 0.1
+                n_train, n_test, n_valid = int(n_train), int(n_test), int(n_valid)
+                self.train = TensorDataset(
+                    torch.randn((n_train, f)),
+                    torch.randn((n_train, 1)),
+                )
+                self.test = TensorDataset(
+                    torch.randn((n_test, f)),
+                    torch.randn((n_test, 1)),
+                )
+                self.valid = TensorDataset(
+                    torch.randn((n_valid, f)),
+                    torch.randn((n_valid, 1)),
+                )
+
+            def train_data(self, node: Node | None = None):
+                return DataLoader(self.train, batch_size=8)
+
+        assert isinstance(Foo().train_data(node), DataLoader)