
Commit: minor changes

nathaniel-hudson committed Nov 15, 2024
1 parent 5de98db commit 46b3432
Showing 14 changed files with 202 additions and 170 deletions.
23 changes: 13 additions & 10 deletions demo.py
@@ -15,7 +15,7 @@
 NUM_LABELS = 10
 
 
-class TrainingModule(TorchModule):
+class MyModule(TorchModule):
     def __init__(self):
         super().__init__()
         self.model = nn.Sequential(
@@ -35,7 +35,7 @@ def forward(self, x):
     def training_step(self, batch, batch_idx) -> TensorLoss:
         x, y = batch
         y_hat = self(x)
-        return nn.functional.nll_loss(y_hat, y)
+        return nn.functional.cross_entropy(y_hat, y)
 
     def configure_optimizers(self) -> torch.optim.Optimizer:
         return torch.optim.Adam(self.parameters(), lr=0.02)
@@ -48,24 +48,27 @@ def main():
         train=False,
         transform=ToTensor(),
     )
-    data = Subset(data, indices=list(range(2000)))
+    data = Subset(data, indices=list(range(200)))
     topo = fl.flat_topology(10)
-    # exit(0)
-    module = TrainingModule()
+    module = MyModule()
     fed_data = federated_split(
         topo=topo,
-        # data=TensorDataset(
-        #     torch.randn(100, 1), torch.randint(low=0, high=NUM_LABELS, size=(100, 1))
-        # ),
         data=data,
         num_labels=NUM_LABELS,
         label_alpha=100.0,
         sample_alpha=100.0,
     )
-    trained_module, records = fl.federated_fit(topo, module, fed_data, rounds=2)
+    trained_module, records = fl.federated_fit(topo, module, fed_data, rounds=10)
 
     df = pd.DataFrame.from_records(records)
-    sns.lineplot(df, x="round", y="train/loss")
+    print(df.head())
+    sns.lineplot(
+        df,
+        x="train/time",
+        y="train/loss",
+        hue="node/idx",
+        # errorbar=None,
+    ).set(yscale="linear")
     plt.show()
 
 
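Note on the loss change in training_step: nn.functional.cross_entropy applies log_softmax internally and expects raw logits, whereas nn.functional.nll_loss expects log-probabilities, so the switch is appropriate assuming the model's forward pass ends without a LogSoftmax. A minimal standalone sketch of the equivalence (illustrative only, not part of this commit):

import torch
import torch.nn.functional as F

# cross_entropy(logits, y) == nll_loss(log_softmax(logits), y)
logits = torch.randn(8, 10)     # raw, unnormalized scores for 10 classes
y = torch.randint(0, 10, (8,))  # integer class labels

a = F.cross_entropy(logits, y)
b = F.nll_loss(F.log_softmax(logits, dim=1), y)
assert torch.allclose(a, b)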
@@ -5,8 +5,8 @@
 
 import globus_compute_sdk
 
-from ...federation.topologies import Node
-from .base import AbstractController
+from flight.engine.controllers.base import AbstractController
+from flight.federation.topologies import Node
 
 if t.TYPE_CHECKING:
     from flight.types import P, T
2 changes: 1 addition & 1 deletion flight/federation/jobs/types.py
@@ -7,11 +7,11 @@
 
 from flight.federation.topologies.node import Node, NodeState, WorkerState
 from flight.learning.base import AbstractDataModule, AbstractModule
-from flight.learning.types import Params
 
 if t.TYPE_CHECKING:
     from flight.types import Record
     from flight.engine.transporters import AbstractTransporter
+    from flight.learning.params import Params
     from flight.strategies import AggrStrategy, TrainerStrategy, WorkerStrategy
 
 
2 changes: 1 addition & 1 deletion flight/learning/__init__.py
@@ -40,8 +40,8 @@
 """
 
 from .base import AbstractDataModule, AbstractModule, AbstractTrainer
+from .params import NpParams, TorchParams
 from .torch.utils import federated_split
-from .types import NpParams, Params, TorchParams
 
 __all__ = [
     "AbstractModule",
9 changes: 3 additions & 6 deletions flight/learning/base.py
@@ -6,7 +6,8 @@
 if t.TYPE_CHECKING:
     from ..federation.topologies import Node
     from ..types import Record
-    from .types import Data, DataIterable, DataKinds, FrameworkKind, Params
+    from .params import Params
+    from .types import Data, DataIterable, DataKinds, FrameworkKind
 
 
 # DataType = t.TypeVar("DataType", bound="AbstractDataModule")
@@ -83,14 +84,10 @@ def size(self, node: Node | None = None, kind: DataKinds = "train") -> int | None:
 
 
 class AbstractModule(abc.ABC):
     @abc.abstractmethod
-    def get_params(self, to_numpy: bool = True) -> Params:
+    def get_params(self) -> Params:
         """
         Getter method for the parameters of a trainable module (i.e., neural network).
-        Args:
-            to_numpy (bool): Flag to convert the parameters to numpy `ndarray`s.
-                Defaults to `True`.
-
         Returns:
             The parameters of the module.
         """
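With the to_numpy flag removed, conversion responsibility moves to the returned Params object (see flight/learning/params.py below). A hedged sketch of an implementor under the new contract (hypothetical subclass, not from this commit, assuming get_params/set_params are the abstract members that matter here):

import numpy as np

from flight.learning.base import AbstractModule
from flight.learning.params import Params


class ToyModule(AbstractModule):
    """Hypothetical module whose only parameter is a single array."""

    def __init__(self):
        self._w = np.zeros((2, 2))

    def get_params(self) -> Params:
        # No to_numpy flag anymore; callers convert via .numpy() / .torch().
        return Params({"weight_0": self._w})

    def set_params(self, params: Params) -> None:
        self._w = params.numpy()["weight_0"]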
100 changes: 74 additions & 26 deletions flight/learning/params.py
@@ -1,14 +1,30 @@
 from __future__ import annotations
 
+import functools
 import typing as t
+from collections import OrderedDict
 from enum import Enum, auto
 
 import numpy as np
 import torch
 
-from flight.learning import NpParams, TorchParams
+NpParams: t.TypeAlias = dict[str, np.ndarray]
+"""
+Type alias for model parameters as a mapping where the keys are strings and
+the values are NumPy `ndarray`s.
+"""
+
+TorchParams: t.TypeAlias = dict[str, torch.Tensor]
+"""
+Type alias for model parameters as a mapping where the keys are strings and
+the values are parameters as PyTorch `Tensor`s.
+"""
 
 
 class UnsupportedParameterKindError(ValueError):
-    """An Exception raised when an unsupported parameter kind is detected."""
+    """
+    An Exception raised when an unsupported parameter kind is detected.
+    """
 
     def __init__(self, message: str | None = None, *args):
         if message is None:
@@ -20,7 +36,9 @@ def __init__(self, message: str | None = None, *args):
 
 
 class InconsistentParamValuesError(ValueError):
-    """An Exception raised when the parameter value kinds are inconsistent."""
+    """
+    An Exception raised when the parameter value kinds are inconsistent.
+    """
 
     def __init__(self, message: str | None = None, *args):
         if message is None:
@@ -29,8 +47,19 @@ def __init__(self, message: str | None = None, *args):
 
 
 class ParamKinds(Enum):
+    """
+    An enumeration of the kinds of parameters supported by Flight.
+    """
+
     NUMPY = auto()
+    """
+    Parameters implemented as NumPy `ndarray`s.
+    """
+
     TORCH = auto()
+    """
+    Parameters implemented as PyTorch `Tensor`s.
+    """
 
 
 def infer_param_kind(param: t.Any) -> ParamKinds:
@@ -44,7 +73,7 @@ def infer_param_kind(param: t.Any) -> ParamKinds:
         The kind of parameter.
 
     Throws:
-        - `ValueError`: If the parameter kind is unknown or unsupported.
+        - `UnsupportedParameterKindError`: If the parameter kind is unknown/unsupported.
     """
     if isinstance(param, np.ndarray):
         return ParamKinds.NUMPY
@@ -70,40 +99,59 @@ def validate_param_kind(params: dict[str, t.Any]) -> ParamKinds:
 
     Throws:
         - `InconsistentParamValuesError`: If the parameter values are inconsistent.
+        - `UnsupportedParameterKindError`: If the parameter kind is unknown/unsupported.
+          This will be thrown by the `infer_param_kind` function.
     """
     param_kinds = set(map(infer_param_kind, params.values()))
     if len(param_kinds) != 1:
         raise InconsistentParamValuesError()
     return param_kinds.pop()
 
 
-class Params:
-    def __init__(self, raw_params: dict[str, t.Any]):
-        self._raw_params = raw_params
-        self._inferred_kind = validate_param_kind(raw_params)
+class Params(OrderedDict):
+    """
+    A wrapper class for model parameters, implemented as an `OrderedDict`.
+
+    Throws:
+        - `InconsistentParamValuesError`: If the parameter values are inconsistent.
+        - `UnsupportedParameterKindError`: If the parameter kind is unknown/unsupported.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
 
     def numpy(self) -> NpParams:
-        match self._inferred_kind:
+        """
+        Convert the parameters to NumPy `ndarray`s.
+
+        Returns:
+            The parameters as NumPy `ndarray`s.
+        """
+        match self.inferred_kind:
             case ParamKinds.NUMPY:
-                return self._raw_params
+                return self
             case ParamKinds.TORCH:
-                return {k: v.numpy() for k, v in self._raw_params.items()}
+                return OrderedDict((k, v.numpy()) for k, v in self.items())
 
     def torch(self) -> TorchParams:
-        match self._inferred_kind:
+        """
+        Convert the parameters to PyTorch `Tensor`s.
+
+        Returns:
+            The parameters as PyTorch `Tensor`s.
+        """
+        match self.inferred_kind:
             case ParamKinds.TORCH:
-                return self._raw_params
+                return self
             case ParamKinds.NUMPY:
-                return {k: torch.from_numpy(v) for k, v in self._raw_params.items()}
-
-
-# class NpParams(Params):
-#     @abc.abstractmethod
-#     def numpy(self) -> dict[str, npt.NDArray]:
-#         pass
-#
-#
-# class TorchParams(Params):
-#     @abc.abstractmethod
-#     def numpy(self) -> dict[str, npt.NDArray]:
-#         pass
+                return OrderedDict((k, torch.from_numpy(v)) for k, v in self.items())
+
+    @functools.cached_property
+    def inferred_kind(self) -> ParamKinds:
+        """
+        The inferred kind of the parameters.
+
+        Returns:
+            The kind of parameters.
+        """
+        return validate_param_kind(self)
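A brief usage sketch of the reworked Params container (illustrative values, not from this commit):

import torch

from flight.learning.params import ParamKinds, Params

params = Params({"layer.weight": torch.ones(2, 3), "layer.bias": torch.zeros(3)})

assert params.inferred_kind is ParamKinds.TORCH
np_params = params.numpy()       # an OrderedDict of NumPy ndarrays
assert params.torch() is params  # already torch-backed, so self is returned

Because Params subclasses OrderedDict, it drops in anywhere a plain mapping of parameters was used before, and inferred_kind is cached after the first validation.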
35 changes: 24 additions & 11 deletions flight/learning/scikit/module.py
@@ -1,17 +1,18 @@
 from __future__ import annotations
 
 import typing as t
-from collections import OrderedDict
-
-from sklearn.neural_network import MLPClassifier, MLPRegressor
 
 from flight.learning import AbstractModule
-from flight.learning.types import FrameworkKind, Params
+from flight.learning.params import Params
+from flight.learning.types import FrameworkKind
+
+if t.TYPE_CHECKING:
+    from sklearn.neural_network import MLPClassifier, MLPRegressor
 
 
 class ScikitModule(AbstractModule):
-    WEIGHT_KEY_PREFIX = "weight"
-    BIAS_KEY_PREFIX = "bias"
+    WEIGHT_KEY_PREFIX: t.Final[str] = "weight"
+    BIAS_KEY_PREFIX: t.Final[str] = "bias"
 
     def __init__(self, module: MLPClassifier | MLPRegressor):
         self.module = module
@@ -24,20 +25,33 @@ def __init__(self, module: MLPClassifier | MLPRegressor):
     def kind(self) -> FrameworkKind:
         return "scikit"
 
-    def get_params(self, _: bool = True) -> Params:
+    def get_params(self) -> Params:
+        """
+        Getter method for the parameters of a trainable module (i.e., neural network)
+        implemented in Scikit-Learn.
+
+        Returns:
+            The parameters of the module.
+        """
         params = []
         for i in range(self._n_layers):
             params.append((f"{self.WEIGHT_KEY_PREFIX}_{i}", self.module.coefs_[i]))
             params.append((f"{self.BIAS_KEY_PREFIX}_{i}", self.module.intercepts_[i]))
-        return OrderedDict(params)
+        return Params(params)
 
     def set_params(self, params: Params):
+        """
+        Setter method for the parameters of a trainable module (i.e., neural network)
+        implemented in Scikit-Learn.
+
+        Args:
+            params (Params): The parameters to set.
+        """
+        params = params.numpy()
         param_keys = list(params.keys())
         layer_nums = set(map(lambda txt: int(txt.split("_")[-1]), param_keys))
         num_layers = max(layer_nums) + 1
 
-        weights = []
-        biases = []
+        weights, biases = [], []
         for i in range(num_layers):
             w_i = params[f"{self.WEIGHT_KEY_PREFIX}_{i}"]
             b_i = params[f"{self.BIAS_KEY_PREFIX}_{i}"]
@@ -57,5 +71,4 @@ def _n_layers(self) -> int:
                 "ScikitModule :: Inconsistent number of layers between "
                 "coefficients/weights and intercepts/biases."
             )
-
         return n
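For illustration, a hedged round-trip through the new get_params/set_params API (hypothetical usage, not from this commit; the estimator must be fitted first, since scikit-learn only creates coefs_ and intercepts_ during fit):

import numpy as np
from sklearn.neural_network import MLPClassifier

from flight.learning.scikit.module import ScikitModule

X = np.random.rand(32, 4)
y = np.random.randint(0, 2, size=32)
clf = MLPClassifier(hidden_layer_sizes=(8,), max_iter=50).fit(X, y)

module = ScikitModule(clf)
params = module.get_params()  # Params keyed weight_0, bias_0, weight_1, bias_1
module.set_params(params)     # numpy-backed, so params.numpy() returns params itself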