diff --git a/flox/data/core.py b/flox/data/core.py index 3b2950e..90a132e 100644 --- a/flox/data/core.py +++ b/flox/data/core.py @@ -11,7 +11,7 @@ from typing import TypeVar, Iterator from flox.flock.states import NodeState - from flox.flock import FlockNodeID + from flox.flock import NodeID T_co = TypeVar("T_co", covariant=True) @@ -25,7 +25,7 @@ def __init__(self): pass @abc.abstractmethod - def load(self, node: FlockNode | FlockNodeID): + def load(self, node: FlockNode | NodeID): pass @@ -34,13 +34,13 @@ class FederatedSubsets(FloxDataset): A subset... """ - def __init__(self, dataset: Dataset[T_co], indices: dict[FlockNodeID, list[int]]): + def __init__(self, dataset: Dataset[T_co], indices: dict[NodeID, list[int]]): super().__init__() self.dataset = dataset self.indices = indices self._num_subsets = len(list(self.indices)) - def load(self, node: FlockNode | FlockNodeID) -> Subset[T_co]: + def load(self, node: FlockNode | NodeID) -> Subset[T_co]: if isinstance(node, FlockNode): node = node.idx return Subset(self.dataset, self.indices[node]) @@ -49,13 +49,13 @@ def load(self, node: FlockNode | FlockNodeID) -> Subset[T_co]: def number_of_subsets(self): return self._num_subsets - def __getitem__(self, node: FlockNode | FlockNodeID) -> Subset[T_co]: + def __getitem__(self, node: FlockNode | NodeID) -> Subset[T_co]: return self.load(node) def __len__(self): return self._num_subsets - def __iter__(self) -> Iterator[tuple[FlockNodeID, Subset[T_co]]]: + def __iter__(self) -> Iterator[tuple[NodeID, Subset[T_co]]]: for idx in self.indices: yield idx, self.load(idx) @@ -68,11 +68,11 @@ class LocalDataset(FloxDataset): def __init__(self, state: NodeState, /, *args, **kwargs): super().__init__() - def load(self, node: FlockNode | FlockNodeID) -> Dataset[T_co]: + def load(self, node: FlockNode | NodeID) -> Dataset[T_co]: """Loads local dataset into a PyTorch object. Args: - node (FlockNode | FlockNodeID): ... + node (FlockNode | NodeID): ... 
Returns: Dataset object using local data. diff --git a/flox/data/utils.py b/flox/data/utils.py index ecefb70..59fc889 100644 --- a/flox/data/utils.py +++ b/flox/data/utils.py @@ -7,7 +7,7 @@ from torch.utils.data import Dataset, DataLoader from flox.data import FederatedSubsets -from flox.flock import Flock, FlockNodeID +from flox.flock import Flock, NodeID # TODO: Implement something similar for regression-based data. @@ -47,7 +47,7 @@ def federated_split( >>> data = MNIST() >>> subsets = federated_split(data, flock, num_classes=10, samples_alpha=1., labels_alpha=1.) >>> next(iter(subsets.items())) - >>> # (FlockNodeID(1), Subset(...)) # TODO: Run a real example and paste it here. + >>> # (NodeID(1), Subset(...)) # TODO: Run a real example and paste it here. Returns: A federated version of the dataset that is split up statistically based on the arguments alpha arguments. @@ -81,9 +81,9 @@ def federated_split( } label_probs = {w.idx: label_distr[i] for i, w in enumerate(flock.workers)} - indices: dict[FlockNodeID, list[int]] = defaultdict(list) + indices: dict[NodeID, list[int]] = defaultdict(list) loader = DataLoader(data, batch_size=1) - worker_samples: Counter[FlockNodeID] = Counter() + worker_samples: Counter[NodeID] = Counter() for idx, batch in enumerate(loader): _, y = batch label = y.item() diff --git a/flox/flock/__init__.py b/flox/flock/__init__.py index 37e7f2f..ca3f43e 100644 --- a/flox/flock/__init__.py +++ b/flox/flock/__init__.py @@ -3,14 +3,14 @@ """ from flox.flock.flock import Flock -from flox.flock.node import FlockNode, FlockNodeID, FlockNodeKind +from flox.flock.node import FlockNode, NodeID, NodeKind from flox.flock.states import AggrState, WorkerState, NodeState __all__ = [ "Flock", "FlockNode", - "FlockNodeID", - "FlockNodeKind", + "NodeID", + "NodeKind", "AggrState", "WorkerState", "NodeState", diff --git a/flox/flock/node.py b/flox/flock/node.py index f8cabc4..e622162 100644 --- a/flox/flock/node.py +++ b/flox/flock/node.py @@ -3,10 +3,10 
@@ from uuid import UUID -FlockNodeID = int | str # NewType("FlockNodeID", int | str) +NodeID = int | str # NewType("NodeID", int | str) -class FlockNodeKind(Enum): +class NodeKind(Enum): """ The different kinds of nodes that can exist in a Flock topology. """ @@ -16,12 +16,12 @@ class FlockNodeKind(Enum): WORKER = auto() # leaf @staticmethod - def from_str(s: str) -> "FlockNodeKind": + def from_str(s: str) -> "NodeKind": """ Converts a string (namely, 'leader', 'aggregator', and 'worker') into their respective item in this Enum. For convenience, this function is *not* sensitive to capitalization or trailing whitespace (i.e., - `FlockNodeKind.from_str('LeaAder ')` and `FlockNodeKind.from_str('leader')` are both valid and equivalent). + `NodeKind.from_str('LeAder ')` and `NodeKind.from_str('leader')` are both valid and equivalent). Args: s (str): String to convert into the respective Enum item. @@ -30,18 +30,18 @@ def from_str(s: str) -> "FlockNodeKind": ValueError: Thrown by illegal string values do not match the above description. Returns: - FlockNodeKind corresponding to the passed in String. + NodeKind corresponding to the passed in String. """ s = s.lower().strip() matches = { - "leader": FlockNodeKind.LEADER, - "aggregator": FlockNodeKind.AGGREGATOR, - "worker": FlockNodeKind.WORKER, + "leader": NodeKind.LEADER, + "aggregator": NodeKind.AGGREGATOR, + "worker": NodeKind.WORKER, } if s in matches: return matches[s] raise ValueError( - f"Illegal `str` value given to `FlockNodeKind.from_str()`. " + f"Illegal `str` value given to `NodeKind.from_str()`. " f"Must be one of the following: {list(matches.keys())}." ) @@ -50,12 +50,12 @@ def to_str(self) -> str: Returns the string representation of the Enum item. Returns: - String corresponding to the FlockNodeKind. + String corresponding to the NodeKind. 
""" matches = { - FlockNodeKind.LEADER: "leader", - FlockNodeKind.AGGREGATOR: "aggregator", - FlockNodeKind.WORKER: "worker", + NodeKind.LEADER: "leader", + NodeKind.AGGREGATOR: "aggregator", + NodeKind.WORKER: "worker", } return matches[self] @@ -66,18 +66,18 @@ class FlockNode: A node in a Flock. Args: - idx (FlockNodeID): The index of the node within the Flock as a whole (this is assigned by its `Flock`). - kind (FlockNodeKind): The kind of node. + idx (NodeID): The index of the node within the Flock as a whole (this is assigned by its `Flock`). + kind (NodeKind): The kind of node. globus_compute_endpoint (UUID | None): Required if you want to run fitting on Globus Compute; defaults to None. proxystore_endpoint (UUID | None): Required if you want to run fitting with Proxystore (recommended if you are using Globus Compute); defaults to None. """ - idx: FlockNodeID + idx: NodeID """Assigned during the Flock construction (i.e., not in .yaml/.json file)""" - kind: FlockNodeKind + kind: NodeKind """Which kind of node.""" globus_compute_endpoint: UUID | None = None diff --git a/flox/flock/states.py b/flox/flock/states.py index 7be7a60..bce0bf5 100644 --- a/flox/flock/states.py +++ b/flox/flock/states.py @@ -6,13 +6,13 @@ if typing.TYPE_CHECKING: from typing import Any, Iterable - from flox.flock import FlockNodeID + from flox.flock import NodeID from flox.nn import FloxModule @dataclass class NodeState: - idx: FlockNodeID + idx: NodeID """The ID of the ``FlockNode`` that the ``NodeState`` corresponds with.""" cache: dict[str, Any] = field(default_factory=dict) @@ -69,7 +69,7 @@ class AggrState(NodeState): """State of an Aggregator node in a ``Flock``.""" # TODO: If there is no difference between ``AggrState`` and ``NodeState``, then do we need the former at all? 
- def __init__(self, idx: FlockNodeID): + def __init__(self, idx: NodeID): super().__init__(idx) @@ -84,7 +84,7 @@ class WorkerState(NodeState): def __init__( self, - idx: FlockNodeID, + idx: NodeID, pre_local_train_model: FloxModule | None = None, post_local_train_model: FloxModule | None = None, ): diff --git a/flox/runtime/result.py b/flox/runtime/result.py index f7fb1cd..61a0d66 100644 --- a/flox/runtime/result.py +++ b/flox/runtime/result.py @@ -8,7 +8,7 @@ if typing.TYPE_CHECKING: from pandas import DataFrame - from flox.flock import FlockNodeID, FlockNodeKind + from flox.flock import NodeID, NodeKind from flox.flock.states import NodeState from flox.nn.typing import StateDict @@ -23,10 +23,10 @@ class JobResult: node_state: NodeState """The state of the ``Flock`` node based on its kind.""" - node_idx: FlockNodeID + node_idx: NodeID """The ID of the ``Flock`` node.""" - node_kind: FlockNodeKind + node_kind: NodeKind """The kind of the ``Flock`` node.""" state_dict: StateDict diff --git a/flox/runtime/transfer/base.py b/flox/runtime/transfer/base.py index 0427551..07e62c6 100644 --- a/flox/runtime/transfer/base.py +++ b/flox/runtime/transfer/base.py @@ -5,8 +5,8 @@ class BaseTransfer: # def report( # self, # node_state: NodeState | dict[str, Any] | None, - # node_idx: FlockNodeID | None, - # node_kind: FlockNodeKind | None, + # node_idx: NodeID | None, + # node_kind: NodeKind | None, # state_dict: StateDict | None, # history: DataFrame | None, # ) -> Result: