Skip to content

Commit

Permalink
Rename NodeStates classes
Browse files Browse the repository at this point in the history
- `FloxAggregatorState` ➡️ `AggrState`
- `FloxWorkerState` ➡️ `WorkerState`
  • Loading branch information
nathaniel-hudson committed Feb 28, 2024
1 parent 871742c commit 887b382
Show file tree
Hide file tree
Showing 16 changed files with 63 additions and 65 deletions.
6 changes: 3 additions & 3 deletions flox/flock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@

from flox.flock.flock import Flock
from flox.flock.node import FlockNode, FlockNodeID, FlockNodeKind
from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState
from flox.flock.states import AggrState, WorkerState, NodeState

__all__ = [
"Flock",
"FlockNode",
"FlockNodeID",
"FlockNodeKind",
"FloxAggregatorState",
"FloxWorkerState",
"AggrState",
"WorkerState",
"NodeState",
]
19 changes: 7 additions & 12 deletions flox/flock/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __post_init__(self):
if type(self) is NodeState:
raise NotImplementedError(
"Cannot instantiate instance of ``NodeState`` (must instantiate instance of "
"subclasses: ``FloxAggregatorState`` or ``FloxWorkerState``)."
"subclasses: ``AggrState`` or ``WorkerState``)."
)

def __iter__(self) -> Iterable[str]:
Expand All @@ -42,7 +42,7 @@ def __setitem__(self, key: str, value: Any) -> None:
value (Any): Data to store in ``self.cache``.
Examples:
>>> state = FloxWorkerState(...)
>>> state = WorkerState(...)
>>> state["foo"] = "bar"
"""
self.cache[key] = value
Expand All @@ -54,7 +54,7 @@ def __getitem__(self, key: str) -> Any:
key (str): Key used to retrieve the data stored in ``self.cache``.
Examples:
>>> state = FloxWorkerState(...)
>>> state = WorkerState(...)
>>> state["foo"] = "bar" # Stores the data (see `__setitem__()`).
>>> print(state["foo"]) # Gets the item.
>>> # "bar"
Expand All @@ -65,14 +65,15 @@ def __getitem__(self, key: str) -> Any:
return self.cache[key]


class FloxAggregatorState(NodeState):
class AggrState(NodeState):
"""State of an Aggregator node in a ``Flock``."""

# TODO: If there is no difference between ``AggrState`` and ``NodeState``, then do we need the former at all?
def __init__(self, idx: FlockNodeID):
super().__init__(idx)


class FloxWorkerState(NodeState):
class WorkerState(NodeState):
"""State of a Worker node in a ``Flock``."""

pre_local_train_model: FloxModule | None = None
Expand All @@ -92,11 +93,5 @@ def __init__(
self.post_local_train_model = post_local_train_model

def __repr__(self) -> str:
template = (
"FloxWorkerState(pre_local_train_model={}, post_local_train_model={})"
)
template = "WorkerState(pre_local_train_model={}, post_local_train_model={})"
return template.format(self.pre_local_train_model, self.post_local_train_model)


# NodeState = NewType("NodeState", Union[FloxAggregatorState, FloxWorkerState])
# """A `Type` included for convenience. It is equivalent to ``Union[FloxAggregatorState, FloxWorkerState]``."""
4 changes: 2 additions & 2 deletions flox/nn/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
from torch.utils.data import DataLoader

from flox.flock.states import FloxWorkerState
from flox.flock.states import WorkerState
from flox.nn import FloxModule
from flox.nn.logger.csv import CSVLogger
from flox.strategies import Strategy
Expand All @@ -32,7 +32,7 @@ def fit(
train_dataloader: DataLoader,
num_epochs: int,
strategy: Strategy,
node_state: FloxWorkerState,
node_state: WorkerState,
valid_dataloader: DataLoader | None = None,
valid_ckpt_path: Path | str | None = None,
) -> pd.DataFrame:
Expand Down
2 changes: 1 addition & 1 deletion flox/runtime/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def federated_fit(
strategy=strategy,
)
case _:
raise ValueError
raise ValueError("Illegal value for the strategy `kind` parameter.")

start_time = datetime.datetime.now()
module, history = process.start(debug_mode)
Expand Down
8 changes: 4 additions & 4 deletions flox/runtime/jobs/aggr.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def aggregation_job(
Aggregation results.
"""
import pandas
from flox.flock.states import FloxAggregatorState, NodeState
from flox.flock.states import AggrState, NodeState
from flox.runtime import JobResult

child_states: dict[FlockNodeID, NodeState] = {}
Expand All @@ -29,7 +29,7 @@ def aggregation_job(
child_states[idx] = result.node_state
child_state_dicts[idx] = result.state_dict

node_state = FloxAggregatorState(node.idx)
node_state = AggrState(node.idx)
avg_state_dict = strategy.agg_param_aggregation(
node_state, child_states, child_state_dicts
)
Expand Down Expand Up @@ -59,13 +59,13 @@ def debug_aggregation_job(
import datetime
import numpy
import pandas
from flox.flock.states import FloxAggregatorState
from flox.flock.states import AggrState
from flox.runtime import JobResult

result = next(iter(results))
state_dict = result.state_dict
state_dict = {} if state_dict is None else state_dict
node_state = FloxAggregatorState(node.idx)
node_state = AggrState(node.idx)
history = {
"node/idx": [node.idx],
"node/kind": [node.kind.to_str()],
Expand Down
8 changes: 4 additions & 4 deletions flox/runtime/jobs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def local_training_job(
Local fitting results.
"""
from copy import deepcopy
from flox.flock.states import FloxWorkerState
from flox.flock.states import WorkerState
from flox.nn.trainer import Trainer
from torch.utils.data import DataLoader
from flox.runtime import JobResult
Expand All @@ -66,7 +66,7 @@ def local_training_job(
global_model.load_state_dict(module_state_dict)
local_model.load_state_dict(module_state_dict)

node_state = FloxWorkerState(
node_state = WorkerState(
node.idx, pre_local_train_model=global_model, post_local_train_model=local_model
)

Expand Down Expand Up @@ -122,11 +122,11 @@ def debug_training_job(
import datetime
import numpy as np
import pandas
from flox.flock.states import FloxWorkerState
from flox.flock.states import WorkerState
from flox.runtime import JobResult

local_module = module
node_state = FloxWorkerState(
node_state = WorkerState(
node.idx,
pre_local_train_model=local_module,
post_local_train_model=local_module,
Expand Down
1 change: 1 addition & 0 deletions flox/runtime/launcher/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ def submit(self, fn, node: FlockNode, /, *args, **kwargs) -> Future:

@abstractmethod
def collect(self):
# TODO: Check if this is needed at all.
raise NotImplementedError()
16 changes: 10 additions & 6 deletions flox/runtime/process/future_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

import functools
import typing
from concurrent.futures import Future

from flox.flock import FlockNode
from flox.runtime.jobs import aggregation_job
from flox.runtime.runtime import Runtime
from flox.runtime.utils import set_parent_future
from flox.strategies import Strategy
if typing.TYPE_CHECKING:
from concurrent.futures import Future

from flox.flock import FlockNode
from flox.runtime.jobs import aggregation_job
from flox.runtime.runtime import Runtime
from flox.runtime.utils import set_parent_future
from flox.strategies import Strategy


def all_child_futures_finished_cbk(
Expand Down
6 changes: 3 additions & 3 deletions flox/runtime/process/proc_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from flox.data import FloxDataset
from flox.flock import Flock, FlockNodeID
from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState
from flox.flock.states import AggrState, WorkerState, NodeState
from flox.nn import FloxModule
from flox.runtime.jobs import local_training_job
from flox.runtime.process.proc import BaseProcess
Expand Down Expand Up @@ -56,7 +56,7 @@ def __init__(

self.state_dict = None
self.debug_mode = False
self.state = FloxAggregatorState(self.flock.leader.idx)
self.state = AggrState(self.flock.leader.idx)

def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]:
if not self.flock.two_tier:
Expand All @@ -68,7 +68,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]:
worker_state_dicts: dict[FlockNodeID, StateDict] = {}
for worker in self.flock.workers:
worker_rounds[worker.idx] = 0
worker_states[worker.idx] = FloxWorkerState(worker.idx)
worker_states[worker.idx] = WorkerState(worker.idx)
worker_state_dicts[worker.idx] = self.global_module.state_dict()

futures = set()
Expand Down
4 changes: 2 additions & 2 deletions flox/runtime/process/proc_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from flox.data import FloxDataset
from flox.flock import Flock, FlockNode, FlockNodeKind
from flox.flock.states import FloxAggregatorState
from flox.flock.states import AggrState
from flox.nn import FloxModule
from flox.runtime.jobs import local_training_job, debug_training_job
from flox.runtime.process.future_callbacks import all_child_futures_finished_cbk
Expand Down Expand Up @@ -123,7 +123,7 @@ def step(
raise ValueError(value_err_template.format(kind, idx))

def _aggr_job(self, node: FlockNode) -> Future[Result]:
aggr_state = FloxAggregatorState(node.idx)
aggr_state = AggrState(node.idx)
self.strategy.cli_worker_selection(aggr_state, list(self.flock.children(node)))
# FIXME: This (^^^) shouldn't be run on the aggregator
children_futures = [
Expand Down
24 changes: 12 additions & 12 deletions flox/strategies/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from typing import Iterable, MutableMapping, TypeAlias
from flox.flock import FlockNode, FlockNodeID
from flox.flock.states import FloxWorkerState, FloxAggregatorState, NodeState
from flox.flock.states import WorkerState, AggrState, NodeState
from flox.nn.typing import StateDict

Loss: TypeAlias = torch.Tensor
Expand Down Expand Up @@ -65,7 +65,7 @@ def cli_get_node_statuses(self):

@abstractmethod
def cli_worker_selection(
self, state: FloxAggregatorState, children: Iterable[FlockNode], *args, **kwargs
self, state: AggrState, children: Iterable[FlockNode], *args, **kwargs
) -> Iterable[FlockNode]:
"""
Expand All @@ -81,15 +81,15 @@ def cli_worker_selection(
return children

def cli_before_share_params(
self, state: FloxAggregatorState, state_dict: StateDict, *args, **kwargs
self, state: AggrState, state_dict: StateDict, *args, **kwargs
) -> StateDict:
"""Callback before sharing parameters to child nodes.
This is mostly done to modify the global model's StateDict. This can be done to encrypt the
model parameters, apply noise, personalize, etc.
Args:
state (FloxAggregatorState): The current state of the aggregator.
state (AggrState): The current state of the aggregator.
state_dict (StateDict): The global model's current StateDict (i.e., parameters) before
sharing with workers.
Expand All @@ -102,18 +102,18 @@ def cli_before_share_params(
# AGGREGATOR CALLBACKS. #
####################################################################################

def agg_before_round(self, state: FloxAggregatorState) -> None:
def agg_before_round(self, state: AggrState) -> None:
"""
Some process to run at the start of a round.
Args:
state (FloxAggregatorState): The current state of the Aggregator FloxNode.
state (AggrState): The current state of the Aggregator FloxNode.
"""
raise NotImplementedError

def agg_param_aggregation(
self,
state: FloxAggregatorState,
state: AggrState,
children_states: MutableMapping[FlockNodeID, NodeState],
children_state_dicts: MutableMapping[FlockNodeID, StateDict],
*args,
Expand All @@ -122,7 +122,7 @@ def agg_param_aggregation(
"""
Args:
state (FloxAggregatorState):
state (AggrState):
children_states (Mapping[FlockNodeID, NodeState]):
children_state_dicts (Mapping[FlockNodeID, StateDict]):
*args ():
Expand All @@ -138,7 +138,7 @@ def agg_param_aggregation(
####################################################################################

def wrk_on_recv_params(
self, state: FloxWorkerState, params: StateDict, *args, **kwargs
self, state: WorkerState, params: StateDict, *args, **kwargs
):
"""
Expand All @@ -153,7 +153,7 @@ def wrk_on_recv_params(
"""
return params

def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs):
def wrk_before_train_step(self, state: WorkerState, *args, **kwargs):
"""
Args:
Expand All @@ -167,7 +167,7 @@ def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs):
raise NotImplementedError()

def wrk_after_train_step(
self, state: FloxWorkerState, loss: Loss, *args, **kwargs
self, state: WorkerState, loss: Loss, *args, **kwargs
) -> Loss:
"""
Expand All @@ -183,7 +183,7 @@ def wrk_after_train_step(
return loss

def wrk_before_submit_params(
self, state: FloxWorkerState, *args, **kwargs
self, state: WorkerState, *args, **kwargs
) -> StateDict:
"""
Expand Down
6 changes: 3 additions & 3 deletions flox/strategies/registry/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Mapping

from flox.flock import FlockNodeID
from flox.flock.states import FloxAggregatorState, FloxWorkerState, NodeState
from flox.flock.states import AggrState, WorkerState, NodeState
from flox.nn.typing import StateDict


Expand Down Expand Up @@ -46,14 +46,14 @@ def __init__(
participation, probabilistic, always_include_child_aggregators, seed
)

def wrk_before_train_step(self, state: FloxWorkerState, *args, **kwargs):
def wrk_before_train_step(self, state: WorkerState, *args, **kwargs):
if "dataset" not in kwargs:
raise ValueError("`dataset` must be provided")
state["num_data_samples"] = len(kwargs["dataset"])

def agg_param_aggregation(
self,
state: FloxAggregatorState,
state: AggrState,
children_states: Mapping[FlockNodeID, NodeState],
children_state_dicts: Mapping[FlockNodeID, StateDict],
*_args,
Expand Down
6 changes: 3 additions & 3 deletions flox/strategies/registry/fedprox.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from flox.flock.states import FloxWorkerState
from flox.flock.states import WorkerState
from flox.strategies import FedAvg


Expand Down Expand Up @@ -51,7 +51,7 @@ def __init__(

def wrk_after_train_step(
self,
state: FloxWorkerState,
state: WorkerState,
loss: torch.Tensor,
*args,
**kwargs,
Expand All @@ -66,7 +66,7 @@ def wrk_after_train_step(
$$
Args:
state (FloxWorkerState):
state (WorkerState):
loss (torch.Tensor):
**kwargs ():
Expand Down
Loading

0 comments on commit 887b382

Please sign in to comment.