Skip to content

Commit

Permalink
FlockNodeKind ➡️ NodeKind, FlockNodeID ➡️ NodeID
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Mar 5, 2024
1 parent 8d3c4c3 commit 1f4a3c7
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 41 deletions.
16 changes: 8 additions & 8 deletions flox/data/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import TypeVar, Iterator

from flox.flock.states import NodeState
from flox.flock import FlockNodeID
from flox.flock import NodeID

T_co = TypeVar("T_co", covariant=True)

Expand All @@ -25,7 +25,7 @@ def __init__(self):
pass

@abc.abstractmethod
def load(self, node: FlockNode | FlockNodeID):
def load(self, node: FlockNode | NodeID):
pass


Expand All @@ -34,13 +34,13 @@ class FederatedSubsets(FloxDataset):
A subset...
"""

def __init__(self, dataset: Dataset[T_co], indices: dict[FlockNodeID, list[int]]):
def __init__(self, dataset: Dataset[T_co], indices: dict[NodeID, list[int]]):
super().__init__()
self.dataset = dataset
self.indices = indices
self._num_subsets = len(list(self.indices))

def load(self, node: FlockNode | FlockNodeID) -> Subset[T_co]:
def load(self, node: FlockNode | NodeID) -> Subset[T_co]:
if isinstance(node, FlockNode):
node = node.idx
return Subset(self.dataset, self.indices[node])
Expand All @@ -49,13 +49,13 @@ def load(self, node: FlockNode | FlockNodeID) -> Subset[T_co]:
def number_of_subsets(self):
return self._num_subsets

def __getitem__(self, node: FlockNode | FlockNodeID) -> Subset[T_co]:
def __getitem__(self, node: FlockNode | NodeID) -> Subset[T_co]:
return self.load(node)

def __len__(self):
return self._num_subsets

def __iter__(self) -> Iterator[tuple[FlockNodeID, Subset[T_co]]]:
def __iter__(self) -> Iterator[tuple[NodeID, Subset[T_co]]]:
for idx in self.indices:
yield idx, self.load(idx)

Expand All @@ -68,11 +68,11 @@ class LocalDataset(FloxDataset):
def __init__(self, state: NodeState, /, *args, **kwargs):
super().__init__()

def load(self, node: FlockNode | FlockNodeID) -> Dataset[T_co]:
def load(self, node: FlockNode | NodeID) -> Dataset[T_co]:
"""Loads local dataset into a PyTorch object.
Args:
node (FlockNode | FlockNodeID): ...
node (FlockNode | NodeID): ...
Returns:
Dataset object using local data.
Expand Down
8 changes: 4 additions & 4 deletions flox/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.utils.data import Dataset, DataLoader

from flox.data import FederatedSubsets
from flox.flock import Flock, FlockNodeID
from flox.flock import Flock, NodeID


# TODO: Implement something similar for regression-based data.
Expand Down Expand Up @@ -47,7 +47,7 @@ def federated_split(
>>> data = MNIST()
>>> subsets = federated_split(data, flock, num_classes=10, samples_alpha=1., labels_alpha=1.)
>>> next(iter(subsets.items()))
>>> # (FlockNodeID(1), Subset(...)) # TODO: Run a real example and paste it here.
>>> # (NodeID(1), Subset(...)) # TODO: Run a real example and paste it here.
Returns:
A federated version of the dataset that is split up statistically based on the arguments alpha arguments.
Expand Down Expand Up @@ -81,9 +81,9 @@ def federated_split(
}
label_probs = {w.idx: label_distr[i] for i, w in enumerate(flock.workers)}

indices: dict[FlockNodeID, list[int]] = defaultdict(list)
indices: dict[NodeID, list[int]] = defaultdict(list)
loader = DataLoader(data, batch_size=1)
worker_samples: Counter[FlockNodeID] = Counter()
worker_samples: Counter[NodeID] = Counter()
for idx, batch in enumerate(loader):
_, y = batch
label = y.item()
Expand Down
6 changes: 3 additions & 3 deletions flox/flock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
"""

from flox.flock.flock import Flock
from flox.flock.node import FlockNode, FlockNodeID, FlockNodeKind
from flox.flock.node import FlockNode, NodeID, NodeKind
from flox.flock.states import AggrState, WorkerState, NodeState

__all__ = [
"Flock",
"FlockNode",
"FlockNodeID",
"FlockNodeKind",
"NodeID",
"NodeKind",
"AggrState",
"WorkerState",
"NodeState",
Expand Down
34 changes: 17 additions & 17 deletions flox/flock/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from uuid import UUID


FlockNodeID = int | str # NewType("FlockNodeID", int | str)
NodeID = int | str # NewType("NodeID", int | str)


class FlockNodeKind(Enum):
class NodeKind(Enum):
"""
The different kinds of nodes that can exist in a Flock topology.
"""
Expand All @@ -16,12 +16,12 @@ class FlockNodeKind(Enum):
WORKER = auto() # leaf

@staticmethod
def from_str(s: str) -> "FlockNodeKind":
def from_str(s: str) -> "NodeKind":
"""
Converts a string (namely, 'leader', 'aggregator', and 'worker') into their respective item in this Enum.
For convenience, this function is *not* sensitive to capitalization or trailing whitespace (i.e.,
`FlockNodeKind.from_str('LeaAder ')` and `FlockNodeKind.from_str('leader')` are both valid and equivalent).
`NodeKind.from_str('LeaAder ')` and `NodeKind.from_str('leader')` are both valid and equivalent).
Args:
s (str): String to convert into the respective Enum item.
Expand All @@ -30,18 +30,18 @@ def from_str(s: str) -> "FlockNodeKind":
ValueError: Thrown by illegal string values do not match the above description.
Returns:
FlockNodeKind corresponding to the passed in String.
NodeKind corresponding to the passed in String.
"""
s = s.lower().strip()
matches = {
"leader": FlockNodeKind.LEADER,
"aggregator": FlockNodeKind.AGGREGATOR,
"worker": FlockNodeKind.WORKER,
"leader": NodeKind.LEADER,
"aggregator": NodeKind.AGGREGATOR,
"worker": NodeKind.WORKER,
}
if s in matches:
return matches[s]
raise ValueError(
f"Illegal `str` value given to `FlockNodeKind.from_str()`. "
f"Illegal `str` value given to `NodeKind.from_str()`. "
f"Must be one of the following: {list(matches.keys())}."
)

Expand All @@ -50,12 +50,12 @@ def to_str(self) -> str:
Returns the string representation of the Enum item.
Returns:
String corresponding to the FlockNodeKind.
String corresponding to the NodeKind.
"""
matches = {
FlockNodeKind.LEADER: "leader",
FlockNodeKind.AGGREGATOR: "aggregator",
FlockNodeKind.WORKER: "worker",
NodeKind.LEADER: "leader",
NodeKind.AGGREGATOR: "aggregator",
NodeKind.WORKER: "worker",
}
return matches[self]

Expand All @@ -66,18 +66,18 @@ class FlockNode:
A node in a Flock.
Args:
idx (FlockNodeID): The index of the node within the Flock as a whole (this is assigned by its `Flock`).
kind (FlockNodeKind): The kind of node.
idx (NodeID): The index of the node within the Flock as a whole (this is assigned by its `Flock`).
kind (NodeKind): The kind of node.
globus_compute_endpoint (UUID | None): Required if you want to run fitting on Globus Compute;
defaults to None.
proxystore_endpoint (UUID | None): Required if you want to run fitting with Proxystore
(recommended if you are using Globus Compute); defaults to None.
"""

idx: FlockNodeID
idx: NodeID
"""Assigned during the Flock construction (i.e., not in .yaml/.json file)"""

kind: FlockNodeKind
kind: NodeKind
"""Which kind of node."""

globus_compute_endpoint: UUID | None = None
Expand Down
8 changes: 4 additions & 4 deletions flox/flock/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
if typing.TYPE_CHECKING:
from typing import Any, Iterable

from flox.flock import FlockNodeID
from flox.flock import NodeID
from flox.nn import FloxModule


@dataclass
class NodeState:
idx: FlockNodeID
idx: NodeID
"""The ID of the ``FlockNode`` that the ``NodeState`` corresponds with."""

cache: dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -69,7 +69,7 @@ class AggrState(NodeState):
"""State of an Aggregator node in a ``Flock``."""

# TODO: If there is no difference between ``AggrState`` and ``NodeState``, then do we need the former at all?
def __init__(self, idx: FlockNodeID):
def __init__(self, idx: NodeID):
super().__init__(idx)


Expand All @@ -84,7 +84,7 @@ class WorkerState(NodeState):

def __init__(
self,
idx: FlockNodeID,
idx: NodeID,
pre_local_train_model: FloxModule | None = None,
post_local_train_model: FloxModule | None = None,
):
Expand Down
6 changes: 3 additions & 3 deletions flox/runtime/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
if typing.TYPE_CHECKING:
from pandas import DataFrame

from flox.flock import FlockNodeID, FlockNodeKind
from flox.flock import NodeID, NodeKind
from flox.flock.states import NodeState
from flox.nn.typing import StateDict

Expand All @@ -23,10 +23,10 @@ class JobResult:
node_state: NodeState
"""The state of the ``Flock`` node based on its kind."""

node_idx: FlockNodeID
node_idx: NodeID
"""The ID of the ``Flock`` node."""

node_kind: FlockNodeKind
node_kind: NodeKind
"""The kind of the ``Flock`` node."""

state_dict: StateDict
Expand Down
4 changes: 2 additions & 2 deletions flox/runtime/transfer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ class BaseTransfer:
# def report(
# self,
# node_state: NodeState | dict[str, Any] | None,
# node_idx: FlockNodeID | None,
# node_kind: FlockNodeKind | None,
# node_idx: NodeID | None,
# node_kind: NodeKind | None,
# state_dict: StateDict | None,
# history: DataFrame | None,
# ) -> Result:
Expand Down

0 comments on commit 1f4a3c7

Please sign in to comment.