Skip to content

Commit

Permalink
Down to only 6 mypy errors in 4 files
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Feb 27, 2024
1 parent 511a66c commit d999d07
Show file tree
Hide file tree
Showing 38 changed files with 342 additions and 301 deletions.
67 changes: 21 additions & 46 deletions examples/notebooks/Quickstart.ipynb

Large diffs are not rendered by default.

7 changes: 2 additions & 5 deletions examples/notebooks/Testing Plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2023-11-08T16:31:05.227580Z",
"start_time": "2023-11-08T16:31:05.225559Z"
}
"is_executing": true
},
"outputs": [],
"source": [
Expand Down
17 changes: 12 additions & 5 deletions flox/data/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
from __future__ import annotations

import abc
from typing import TypeVar, Iterable
import typing

from torch.utils.data import Dataset, Subset

from flox.flock import FlockNodeID, FlockNode
from flox.flock.states import NodeState
from flox.flock import FlockNode

if typing.TYPE_CHECKING:
from typing import TypeVar, Iterator

from flox.flock.states import NodeState
from flox.flock import FlockNodeID

T_co = TypeVar("T_co", covariant=True)
T_co = TypeVar("T_co", covariant=True)


class FloxDataset(abc.ABC):
Expand Down Expand Up @@ -48,7 +55,7 @@ def __getitem__(self, node: FlockNode | FlockNodeID) -> Subset[T_co]:
def __len__(self):
return self._num_subsets

def __iter__(self) -> Iterable[tuple[FlockNodeID, Subset[T_co]]]:
def __iter__(self) -> Iterator[tuple[FlockNodeID, Subset[T_co]]]:
for idx in self.indices:
yield idx, self.load(idx)

Expand Down
32 changes: 21 additions & 11 deletions flox/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import warnings
from collections import defaultdict
from collections import defaultdict, Counter

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from scipy import stats
from torch.utils.data import Dataset, DataLoader

from flox.data import FederatedSubsets
from flox.flock import Flock
from flox.flock import Flock, FlockNodeID


# TODO: Implement something similar for regression-based data.
Expand Down Expand Up @@ -57,19 +56,30 @@ def federated_split(
assert labels_alpha > 0

num_workers = len(list(flock.workers))
sample_distr = stats.dirichlet(np.full(num_workers, samples_alpha))
label_distr = stats.dirichlet(np.full(num_classes, labels_alpha))

# pytorch intentionally doesn't define an empty __len__ for DataSet, even though
# most subclasses implement it
data_count = len(data) # type: ignore
# sample_distr = stats.dirichlet(np.full(num_workers, samples_alpha))
# label_distr = stats.dirichlet(np.full(num_classes, labels_alpha))

s_alpha = np.full(num_workers, samples_alpha)
sample_distr = np.random.dirichlet(s_alpha)

l_alpha = np.full(num_classes, labels_alpha)
label_distr = np.random.dirichlet(l_alpha, size=flock.number_of_workers)

# PyTorch intentionally doesn't define an empty __len__ for ``Dataset``, even though
# most subclasses implement it.
try:
data_count = len(data) # type: ignore
except NotImplementedError:
raise NotImplementedError(
"Provided ``Dataset`` does not override ``__len__``, which is required for ``federated_split()``."
)

num_samples_for_workers = (sample_distr.rvs()[0] * data_count).astype(int)
num_samples_for_workers = (sample_distr * data_count).astype(int)
num_samples_for_workers = {
worker.idx: num_samples
for worker, num_samples in zip(flock.workers, num_samples_for_workers)
}
label_probs = {w.idx: label_distr.rvs()[0] for w in flock.workers}
label_probs = {w.idx: label_distr[i] for i, w in enumerate(flock.workers)}

indices: dict[FlockNodeID, list[int]] = defaultdict(list)
loader = DataLoader(data, batch_size=1)
Expand Down
10 changes: 4 additions & 6 deletions flox/flock/flock.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

import functools
import json
from pathlib import Path
from typing import Any, Generator, Iterator
from typing import Any, Iterator
from uuid import UUID

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -317,7 +315,7 @@ def from_dict(content: dict[str, Any], _src: Path | str | None = None) -> "Flock
return Flock(topo=topo, _src=_src)

@staticmethod
def from_json(path: Path | str) -> Flock:
def from_json(path: Path | str) -> "Flock":
"""Imports a .json file as a Flock.
Examples:
Expand All @@ -335,7 +333,7 @@ def from_json(path: Path | str) -> Flock:
return Flock.from_dict(content, _src=path)

@staticmethod
def from_yaml(path: Path | str) -> Flock:
def from_yaml(path: Path | str) -> "Flock":
"""Imports a `.yaml` file as a Flock.
Examples:
Expand Down Expand Up @@ -443,7 +441,7 @@ def two_tier(self) -> bool:
return False
return True

def nodes(self, by_kind: FlockNodeKind | None = None) -> Generator[FlockNode]:
def nodes(self, by_kind: FlockNodeKind | None = None) -> Iterator[FlockNode]:
for idx, data in self.topo.nodes(data=True):
if by_kind is not None and data["kind"] != by_kind:
continue
Expand Down
25 changes: 14 additions & 11 deletions flox/flock/states.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from dataclasses import field
from typing import Any, Iterable
from __future__ import annotations

import torch
import typing
from dataclasses import dataclass, field

from flox.flock import FlockNodeID
if typing.TYPE_CHECKING:
from typing import Any, Iterable

from flox.flock import FlockNodeID
from flox.nn import FloxModule


@dataclass
class NodeState:
idx: FlockNodeID
"""The ID of the ``FlockNode`` that the ``NodeState`` corresponds with."""
Expand All @@ -14,14 +19,12 @@ class NodeState:
"""A dictionary containing extra data. This can be used as a temporary "store" to pass data between
callbacks for custom ``Strategy`` objects."""

def __init__(self, idx: FlockNodeID):
def __post_init__(self):
if type(self) is NodeState:
raise NotImplementedError(
"Cannot instantiate instance of ``NodeState`` (must instantiate instance of "
"subclasses: ``FloxAggregatorState`` or ``FloxWorkerState``)."
)
self.idx = idx
self.cache = {}

def __iter__(self) -> Iterable[str]:
"""Returns an iterator through the state's cache."""
Expand Down Expand Up @@ -72,17 +75,17 @@ def __init__(self, idx: FlockNodeID):
class FloxWorkerState(NodeState):
"""State of a Worker node in a ``Flock``."""

pre_local_train_model: torch.nn.Module | None = None
pre_local_train_model: FloxModule | None = None
"""Global model."""

post_local_train_model: torch.nn.Module | None = None
post_local_train_model: FloxModule | None = None
"""Local model after local fitting/training."""

def __init__(
self,
idx: FlockNodeID,
pre_local_train_model: torch.nn.Module | None = None,
post_local_train_model: torch.nn.Module | None = None,
pre_local_train_model: FloxModule | None = None,
post_local_train_model: FloxModule | None = None,
):
super().__init__(idx)
self.pre_local_train_model = pre_local_train_model
Expand Down
3 changes: 3 additions & 0 deletions flox/nn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def validation_step(
Returns:
"""
raise NotImplementedError

def test_step(self, batch: torch.Tensor | tuple[torch.Tensor, ...], batch_idx: int):
"""
Expand All @@ -58,6 +59,7 @@ def test_step(self, batch: torch.Tensor | tuple[torch.Tensor, ...], batch_idx: i
Returns:
"""
raise NotImplementedError

def predict_step(
self,
Expand All @@ -75,3 +77,4 @@ def predict_step(
Returns:
"""
raise NotImplementedError
4 changes: 1 addition & 3 deletions flox/nn/trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from __future__ import annotations

import datetime
from pathlib import Path
from typing import Any
Expand Down Expand Up @@ -34,7 +32,7 @@ def fit(
train_dataloader: DataLoader,
num_epochs: int,
strategy: Strategy,
node_state: FloxWorkerState | None = None,
node_state: FloxWorkerState,
valid_dataloader: DataLoader | None = None,
valid_ckpt_path: Path | str | None = None,
) -> pd.DataFrame:
Expand Down
4 changes: 0 additions & 4 deletions flox/nn/types.py

This file was deleted.

4 changes: 4 additions & 0 deletions flox/typing.py → flox/nn/typing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from typing import TypeAlias, Literal

import torch

Kind: TypeAlias = Literal["async", "sync"]
Where: TypeAlias = Literal["local", "globus_compute"]
StateDict = dict[str, torch.Tensor]
"""The state dict of PyTorch ``torch.nn.Module`` (see ``torch.nn.Module.state_dict()``)."""
46 changes: 26 additions & 20 deletions flox/runtime/fit.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import datetime
from typing import Any
import typing

import numpy as np
from pandas import DataFrame

from flox.data import FloxDataset
from flox.flock import Flock
from flox.nn import FloxModule

# from flox.run.fit_sync import sync_federated_fit
from flox.nn.types import Kind
from flox.nn.typing import Kind
from flox.runtime.launcher import (
GlobusComputeLauncher,
Launcher,
LocalLauncher,
ParslLauncher,
)
from flox.runtime.process.proc import BaseProcess
from flox.runtime.process.proc_async import AsyncProcess
from flox.runtime.process.proc_sync import SyncProcess
from flox.runtime.runtime import Runtime
Expand Down Expand Up @@ -47,8 +47,8 @@ def federated_fit(
num_global_rounds: int,
strategy: Strategy | str | None = None,
kind: Kind = "sync",
launcher: str = "process",
launcher_cfg: dict[str, Any] | None = None,
launcher_kind: str = "process",
launcher_cfg: dict[str, typing.Any] | None = None,
debug_mode: bool = False,
) -> tuple[FloxModule, DataFrame]:
"""
Expand All @@ -60,17 +60,18 @@ def federated_fit(
num_global_rounds (int):
strategy (Strategy | str | None):
kind (Kind):
launcher (Where):
launcher_cfg (dict[str, Any] | None):
launcher_kind (str):
launcher_cfg (dict[str, typing.Any] | None):
debug_mode (bool): ...
Returns:
The trained global module hosted on the leader of `flock`.
The history metrics from training.
"""
launcher_cfg = dict() if launcher_cfg is None else launcher_cfg
launcher = create_launcher(launcher, **launcher_cfg)
launcher = create_launcher(launcher_kind, **launcher_cfg)
transfer = BaseTransfer()
runtime = Runtime(launcher, transfer)

if strategy is None:
strategy = "fedsgd"
Expand All @@ -79,21 +80,26 @@ def federated_fit(

# runner = runner_factory.build(kind, ...)
# runner.start()

common_kwargs = {
"flock": flock,
"num_global_rounds": num_global_rounds,
"runtime": Runtime(launcher, transfer),
"module": module,
"dataset": datasets,
"strategy": strategy,
}

process: BaseProcess
match kind:
case "sync":
process = SyncProcess(**common_kwargs)
process = SyncProcess(
runtime=runtime,
flock=flock,
num_global_rounds=num_global_rounds,
module=module,
dataset=datasets,
strategy=strategy,
)
case "async":
process = AsyncProcess(**common_kwargs)
process = AsyncProcess(
runtime=runtime,
flock=flock,
num_global_rounds=num_global_rounds,
module=module,
dataset=datasets,
strategy=strategy,
)
case _:
raise ValueError

Expand Down
18 changes: 11 additions & 7 deletions flox/runtime/jobs/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def aggregation_job(
"""
import pandas
from flox.flock.states import FloxAggregatorState
from flox.runtime import JobResult

child_states, child_state_dicts = {}, {}
for result in results:
Expand Down Expand Up @@ -47,7 +48,8 @@ def aggregation_job(
raise ValueError

history = pandas.concat(histories)
return transfer.report(node_state, node.idx, node.kind, avg_state_dict, history)
result = JobResult(node_state, node.idx, node.kind, avg_state_dict, history)
return transfer.report(result)


def debug_aggregation_job(
Expand All @@ -56,10 +58,13 @@ def debug_aggregation_job(
import datetime
import numpy
import pandas
from flox.flock.states import FloxAggregatorState
from flox.runtime import JobResult

result = next(iter(results))
module = result.module
node_state = dict(idx=node.idx)
state_dict = result.state_dict
state_dict = {} if state_dict is None else state_dict
node_state = FloxAggregatorState(node.idx)
history = {
"node/idx": [node.idx],
"node/kind": [node.kind.to_str()],
Expand All @@ -69,7 +74,6 @@ def debug_aggregation_job(
"train/time": [datetime.datetime.now()],
"mode": "debug",
}
history = pandas.DataFrame.from_dict(history)
return transfer.report(
node_state, node.idx, node.kind, module.state_dict(), history
)
history_df = pandas.DataFrame.from_dict(history)
result = JobResult(node_state, node.idx, node.kind, state_dict, history_df)
return transfer.report(result)
Loading

0 comments on commit d999d07

Please sign in to comment.