fix mypy (wip)
chris-janidlo committed Feb 23, 2024
1 parent bed8625 commit a7b298d
Showing 25 changed files with 153 additions and 99 deletions.
10 changes: 9 additions & 1 deletion flox/backends/launcher/impl_base.py
@@ -1,9 +1,15 @@
from abc import ABC, abstractmethod
from concurrent.futures import Future
from typing import Any, Protocol

from flox.flock import FlockNode


class LauncherFunction(Protocol):
    def __call__(self, node: FlockNode, *args: Any, **kwargs: Any) -> Any:
        ...


class Launcher(ABC):
"""
Base class for launching functions in an FL process.
Expand All @@ -14,7 +20,9 @@ def __init__(self):
pass

@abstractmethod
def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future:
def submit(
self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs
) -> Future:
raise NotImplementedError()

@abstractmethod
Expand Down
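
For context, a minimal sketch of how an ordinary function can satisfy the new `LauncherFunction` protocol via structural typing (the `local_training` name and its body are hypothetical, not part of this commit):

from typing import Any

from flox.backends.launcher.impl_base import LauncherFunction
from flox.flock import FlockNode


def local_training(node: FlockNode, *args: Any, **kwargs: Any) -> Any:
    # Any callable with this shape structurally matches LauncherFunction.__call__,
    # so mypy accepts it for the now-typed `fn` parameter of Launcher.submit().
    return {"node": node, "args": args, "kwargs": kwargs}


fn: LauncherFunction = local_training  # accepted by mypy; no inheritance needed
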
8 changes: 2 additions & 6 deletions flox/backends/launcher/impl_globus.py
@@ -1,10 +1,8 @@
from collections.abc import Callable
from concurrent.futures import Future
from typing import Any

import globus_compute_sdk

from flox.backends.launcher.impl_base import Launcher
from flox.backends.launcher.impl_base import Launcher, LauncherFunction
from flox.flock import FlockNode


@@ -13,15 +11,13 @@ class GlobusComputeLauncher(Launcher):
    Class that executes tasks on Globus Compute.
    """

    _globus_compute_executor: globus_compute_sdk.Executor | None = None

    def __init__(self):
        super().__init__()
        if self._globus_compute_executor is None:
            self._globus_compute_executor = globus_compute_sdk.Executor()

    def submit(
        self, fn: Callable[[FlockNode, ...], Any], node: FlockNode, /, *args, **kwargs
        self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs
    ) -> Future:
        endpoint_id = node.globus_compute_endpoint
        self._globus_compute_executor.endpoint_id = endpoint_id
9 changes: 6 additions & 3 deletions flox/backends/launcher/impl_local.py
@@ -1,6 +1,6 @@
from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor

from flox.backends.launcher.impl_base import Launcher
from flox.backends.launcher.impl_base import Launcher, LauncherFunction
from flox.flock import FlockNode


@@ -12,6 +12,7 @@ class LocalLauncher(Launcher):
    def __init__(self, pool: str, n_workers: int = 1):
        super().__init__()
        self.n_workers = n_workers
        self.pool: Executor
        if pool == "process":
            self.pool = ProcessPoolExecutor(n_workers)
        elif pool == "thread":
@@ -21,7 +22,9 @@ def __init__(self, pool: str, n_workers: int = 1):
                "Illegal value for argument `pool`. Must be either 'pool' or 'thread'."


    def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future:
    def submit(
        self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs
    ) -> Future:
        return self.pool.submit(fn, node, *args, **kwargs)

    def collect(self):
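
The new `self.pool: Executor` line is the usual way to give mypy one type for an attribute that is assigned in several branches; a generic sketch of the pattern, independent of flox:

from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor


class Example:
    def __init__(self, pool: str) -> None:
        # Declaring the attribute type up front lets mypy unify the branch
        # assignments to their common base class, Executor.
        self.pool: Executor
        if pool == "process":
            self.pool = ProcessPoolExecutor()
        else:
            self.pool = ThreadPoolExecutor()
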
10 changes: 6 additions & 4 deletions flox/backends/launcher/impl_parsl.py
@@ -1,6 +1,6 @@
from concurrent.futures import Future

from flox.backends.launcher.impl_base import Launcher
from flox.backends.launcher.impl_base import Launcher, LauncherFunction
from flox.flock import FlockNode


@@ -13,8 +13,10 @@ def __init__(self):
        super().__init__()
        raise NotImplementedError(f"{self.__name__} yet implemented")

    def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future:
        pass
    def submit(
        self, fn: LauncherFunction, node: FlockNode, /, *args, **kwargs
    ) -> Future:
        raise NotImplementedError()

    def collect(self):
        pass
        raise NotImplementedError()
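
Replacing `pass` with `raise NotImplementedError()` matters to mypy here: a body that falls through implicitly returns None, which conflicts with the declared `-> Future` return type, while raising satisfies any annotation. A standalone sketch:

from concurrent.futures import Future


def submit_stub() -> Future:
    # With a `pass` body, mypy reports a missing return statement because the
    # function would return None rather than a Future.
    raise NotImplementedError()
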
4 changes: 2 additions & 2 deletions flox/data/__init__.py
@@ -61,7 +61,7 @@
FLoX includes utility functions to simplify the conversion from a standard, centralized PyTorch dataset to a
simulated, decentralized dataset.
"""
from flox.data.core import FloxDataset
from flox.data.core import FederatedSubsets, FloxDataset
from flox.data.utils import fed_barplot, federated_split

__all__ = ["FloxDataset", "fed_barplot", "federated_split"]
__all__ = ["FloxDataset", "FederatedSubsets", "fed_barplot", "federated_split"]
20 changes: 15 additions & 5 deletions flox/data/core.py
@@ -1,5 +1,6 @@
from collections.abc import Mapping
from enum import IntEnum, auto
from typing import NewType, TypeVar, Union
from typing import NewType, Union, get_args

from torch.utils.data import Dataset, Subset

@@ -16,11 +17,21 @@ class FloxDatasetKind(IntEnum):
    def from_obj(obj) -> "FloxDatasetKind":
        if isinstance(obj, Dataset):
            return FloxDatasetKind.STANDARD
        elif isinstance(obj, FederatedSubsets):
        elif FloxDatasetKind.is_federated_dataset(obj):
            return FloxDatasetKind.FEDERATED
        else:
            return FloxDatasetKind.INVALID

    @staticmethod
    def is_federated_dataset(obj) -> bool:
        if not isinstance(obj, Mapping):
            return False

        return all(
            isinstance(k, get_args(FlockNodeID)) and isinstance(v, (Dataset, Subset))
            for k, v in obj.items()
        )


def flox_compatible_data(obj) -> bool:
    kind = FloxDatasetKind.from_obj(obj)
@@ -29,9 +40,8 @@ def flox_compatible_data(obj) -> bool:
    return True


T_co = TypeVar("T_co", covariant=True)
FederatedSubsets = NewType(
"FederatedSubsets", dict[FlockNodeID, Union[Dataset[T_co], Subset[T_co]]]
"FederatedSubsets", Mapping[FlockNodeID, Union[Dataset, Subset]]
)


@@ -41,4 +51,4 @@ def __init__(self, state: NodeState, /, *args, **kwargs):
        self.state = state


FloxDataset = NewType("FloxDataset", Union[MyFloxDataset, FederatedSubsets])
FloxDataset = Union[MyFloxDataset, FederatedSubsets]
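
A small sketch of what the `get_args`-based runtime check relies on, assuming `FlockNodeID` is the `Union[int, str]` alias now defined in flox/flock/node.py:

from typing import Union, get_args

FlockNodeID = Union[int, str]

# get_args unpacks the Union members into a tuple that isinstance() accepts.
assert get_args(FlockNodeID) == (int, str)
assert isinstance(3, get_args(FlockNodeID))
assert isinstance("worker-1", get_args(FlockNodeID))
assert not isinstance(3.5, get_args(FlockNodeID))
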
19 changes: 12 additions & 7 deletions flox/data/utils.py
@@ -1,5 +1,5 @@
import warnings
from collections import defaultdict
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
import numpy as np
@@ -9,6 +9,7 @@

from flox.data import FederatedSubsets
from flox.flock import Flock
from flox.flock.node import FlockNodeID


# TODO: Implement something similar for regression-based data.
@@ -59,16 +60,20 @@ def federated_split(
    sample_distr = stats.dirichlet(np.full(num_workers, samples_alpha))
    label_distr = stats.dirichlet(np.full(num_classes, labels_alpha))

    num_samples_for_workers = (sample_distr.rvs()[0] * len(data)).astype(int)
    # pytorch intentionally doesn't define an empty __len__ for DataSet, even though
    # most subclasses implement it
    data_count = len(data)  # type: ignore

    num_samples_for_workers = (sample_distr.rvs()[0] * data_count).astype(int)
    num_samples_for_workers = {
        worker.idx: num_samples
        for worker, num_samples in zip(flock.workers, num_samples_for_workers)
    }
    label_probs = {w.idx: label_distr.rvs()[0] for w in flock.workers}

    indices: dict[int, list[int]] = defaultdict(list)
    indices: dict[FlockNodeID, list[int]] = defaultdict(list)
    loader = DataLoader(data, batch_size=1)
    worker_samples = defaultdict(int)
    worker_samples: Counter[FlockNodeID] = Counter()
    for idx, batch in enumerate(loader):
        _, y = batch
        label = y.item()
@@ -89,11 +94,11 @@
            )
            raise err

        probs = np.array(probs)
        probs = probs / probs.sum()
        probs_norm = np.array(probs)
        probs_norm = probs_norm / probs_norm.sum()

        if len(temp_workers) > 0:
            chosen_worker = np.random.choice(temp_workers, p=probs)
            chosen_worker = np.random.choice(temp_workers, p=probs_norm)
            indices[chosen_worker].append(idx)
            worker_samples[chosen_worker] += 1

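
For the `worker_samples` change, a `Counter` parameterized by its key type carries the information mypy could not infer from a bare `defaultdict(int)`, while keeping the same missing-key-is-zero behavior; a sketch assuming the `Union[int, str]` form of `FlockNodeID`:

from collections import Counter
from typing import Union

FlockNodeID = Union[int, str]

worker_samples: Counter[FlockNodeID] = Counter()
worker_samples["worker-1"] += 1  # missing keys start at 0, like defaultdict(int)
worker_samples[3] += 2
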
15 changes: 8 additions & 7 deletions flox/flock/flock.py
@@ -2,7 +2,7 @@

import functools
import json
from collections.abc import Generator
from collections.abc import Iterator
from pathlib import Path
from typing import Any
from uuid import UUID
@@ -55,7 +55,6 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non
        """
        self.node_counter: int = 0
        self._src = "interactive" if _src is None else _src
        self.leader = None

        if topo is None:
            # By default (i.e., `topo is None`),
@@ -84,6 +83,8 @@ def __init__(self, topo: nx.DiGraph | None = None, _src: Path | str | None = Non
                raise ValueError(
                    "A legal Flock cannot have more than one leader."
                )
        if not found_leader:
            raise ValueError("A legal Flock must have a leader.")

    def add_node(
        self,
@@ -102,7 +103,7 @@ def add_node(
            proxystore_endpoint_id=proxystore_endpoint_id,
        )
        self.node_counter += 1
        return FlockNodeID(idx)
        return idx

    def add_edge(self, u: FlockNodeID, v: FlockNodeID, **attrs) -> None:
        """
@@ -218,7 +219,7 @@ def validate_topo(self) -> bool:

        return True

    def children(self, node: FlockNode | FlockNodeID | int) -> Generator[FlockNode]:
    def children(self, node: FlockNode | FlockNodeID | int) -> Iterator[FlockNode]:
        if isinstance(node, FlockNode):
            idx = node.idx
        else:
@@ -384,7 +385,7 @@ def proxystore_ready(self) -> bool:
    # return self.nodes(by_kind=FlockNodeKind.LEADER)

    @property
    def aggregators(self) -> Generator[FlockNode]:
    def aggregators(self) -> Iterator[FlockNode]:
        """
        The aggregator nodes of the Flock.
@@ -394,7 +395,7 @@ def aggregators(self) -> Generator[FlockNode]:
        return self.nodes(by_kind=FlockNodeKind.AGGREGATOR)

    @property
    def workers(self) -> Generator[FlockNode]:
    def workers(self) -> Iterator[FlockNode]:
        """
        The worker nodes of the Flock.
@@ -413,7 +414,7 @@ def number_of_workers(self) -> int:
        """The number of worker nodes in the Flock."""
        return len(list(self.workers))

    def nodes(self, by_kind: FlockNodeKind | None = None) -> Generator[FlockNode]:
    def nodes(self, by_kind: FlockNodeKind | None = None) -> Iterator[FlockNode]:
        for idx, data in self.topo.nodes(data=True):
            if by_kind is not None and data["kind"] != by_kind:
                continue
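
The Generator → Iterator swap is a common mypy fix: `collections.abc.Generator` is parameterized by yield, send, and return types, so a method that only yields is simpler to annotate with the single-parameter `Iterator`. A standard-library-only sketch:

from collections.abc import Iterator


def count_up_to(n: int) -> Iterator[int]:
    # A generator function that never uses send() or a return value can be
    # annotated as Iterator[int] instead of Generator[int, None, None].
    for i in range(n):
        yield i
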
4 changes: 2 additions & 2 deletions flox/flock/node.py
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from enum import Enum, auto
from typing import NewType, Union
from typing import Union
from uuid import UUID

FlockNodeID = NewType("FlockNodeID", Union[int, str])
FlockNodeID = Union[int, str]


class FlockNodeKind(Enum):
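
The move from `NewType` to a plain alias is presumably because mypy requires the second argument of `NewType` to be a subclassable class, so a Union is rejected; a minimal illustration:

from typing import NewType, Union

# Rejected by mypy: NewType cannot wrap a Union.
# FlockNodeID = NewType("FlockNodeID", Union[int, str])

# Accepted: an ordinary type alias.
FlockNodeID = Union[int, str]
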
2 changes: 1 addition & 1 deletion flox/nn/logger/base.py
@@ -19,7 +19,7 @@ def clear(self) -> None:

    def to_pandas(self) -> pd.DataFrame:
        df = pd.DataFrame.from_records(self.records)
        for col in df:
        for col in df.columns:
            if "time" in col:
                df[col] = pd.to_datetime(df[col])
        return df
6 changes: 4 additions & 2 deletions flox/nn/model.py
@@ -1,9 +1,9 @@
from __future__ import annotations
from abc import ABC, abstractmethod

import torch


class FloxModule(torch.nn.Module):
class FloxModule(torch.nn.Module, ABC):
"""
The ``FloxModule`` is a wrapper for the standard ``torch.nn.Module`` class from PyTorch, with
a lot of inspiration from the ``lightning.LightningModule`` class from PyTorch Lightning.
Expand All @@ -12,6 +12,7 @@ class FloxModule(torch.nn.Module):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

@abstractmethod
def training_step(
self, batch: torch.Tensor | tuple[torch.Tensor, torch.Tensor], batch_idx: int
) -> torch.Tensor:
Expand All @@ -26,6 +27,7 @@ def training_step(
Loss from the training step.
"""

@abstractmethod
def configure_optimizers(self) -> torch.optim.Optimizer:
"""Configures, initializes, and returns the optimizer used to train the model.
Expand Down
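
With `training_step` and `configure_optimizers` now abstract, a minimal sketch of a concrete subclass (the layer sizes and learning rate below are arbitrary assumptions, not part of flox):

from __future__ import annotations

import torch

from flox.nn.model import FloxModule


class TinyModel(FloxModule):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def training_step(
        self, batch: torch.Tensor | tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        x, y = batch  # this sketch assumes (inputs, labels) pairs
        return torch.nn.functional.cross_entropy(self.linear(x), y)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        return torch.optim.SGD(self.parameters(), lr=0.01)
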
7 changes: 5 additions & 2 deletions flox/nn/trainer.py
@@ -50,10 +50,13 @@ def fit(
                loss.backward()

                try:
                    assert strategy is not None
                    assert node_state is not None
                    strategy.wrk_on_after_train_step(node_state, loss)
                except NotImplementedError:
                except (AttributeError, AssertionError):
                    """
                    The current strategy does not override the `wrk_on_after_train_step()` callback.
                    node_state is None, strategy is None, or the strategy doesn't
                    implement `wrk_on_after_train_step()`.
                    """
                    pass

2 changes: 1 addition & 1 deletion flox/run/fit.py
@@ -61,7 +61,7 @@ def federated_fit(
    """
    launcher_cfg = dict() if launcher_cfg is None else launcher_cfg
    launcher = create_launcher(launcher, **launcher_cfg)
    # launcher = create_launcher(launcher, **launcher_cfg) # not used

    strategy = "fedsgd" if strategy is None else strategy
