Skip to content

Commit

Permalink
Working refactored strategies
Browse files Browse the repository at this point in the history
TODO: resolve `mypy` issues
  • Loading branch information
nathaniel-hudson committed Mar 10, 2024
1 parent 730c06e commit 0942f7c
Show file tree
Hide file tree
Showing 44 changed files with 665 additions and 823 deletions.
1 change: 0 additions & 1 deletion fashion_mnist_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def main():
# where="local", # "globus_compute",
)
df.to_feather(Path("out/fashion_mnist_demo.feather"))
print(">>> Finished!")


if __name__ == "__main__":
Expand Down
16 changes: 8 additions & 8 deletions flox/flock/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,22 +76,22 @@ def __init__(self, idx: NodeID):
class WorkerState(NodeState):
"""State of a Worker node in a ``Flock``."""

pre_local_train_model: FloxModule | None = None
global_model: FloxModule | None = None
"""Global model."""

post_local_train_model: FloxModule | None = None
local_model: FloxModule | None = None
"""Local model after local fitting/training."""

def __init__(
self,
idx: NodeID,
pre_local_train_model: FloxModule | None = None,
post_local_train_model: FloxModule | None = None,
global_model: FloxModule | None = None,
local_model: FloxModule | None = None,
):
super().__init__(idx)
self.pre_local_train_model = pre_local_train_model
self.post_local_train_model = post_local_train_model
self.global_model = global_model
self.local_model = local_model

def __repr__(self) -> str:
template = "WorkerState(pre_local_train_model={}, post_local_train_model={})"
return template.format(self.pre_local_train_model, self.post_local_train_model)
template = "WorkerState(global_model={}, local_model={})"
return template.format(self.global_model, self.local_model)
4 changes: 2 additions & 2 deletions flox/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Job: t.TypeAlias = AggregableJob | TrainableJob | NodeCallable
"""
An umbrella typing that encapsulates both ``AggregableJob`` and ``TrainableJob`` protocols
for job implementations for both the aggregator and worker nodes (respectively).
for job implementations for both the aggregator and worker nodes (respectively).
"""


Expand All @@ -27,7 +27,7 @@
"AggregableJob",
"TrainableJob",
"NodeCallable",
# Job implementations.
# Job implementations.
"AggregateJob",
"DebugAggregateJob",
"LocalTrainJob",
Expand Down
16 changes: 8 additions & 8 deletions flox/jobs/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@
from flox.jobs.protocols import AggregableJob
from flox.runtime.result import Result
from flox.runtime.transfer import BaseTransfer
from flox.strategies_depr import Strategy
from flox.strategies import AggregatorStrategy


class AggregateJob(AggregableJob):
@staticmethod
def __call__(
node: FlockNode,
transfer: BaseTransfer,
strategy: Strategy,
aggr_strategy: AggregatorStrategy,
results: list[Result],
) -> Result:
"""Aggregate the state dicts from each of the results.
Args:
node (FlockNode): The aggregator node.
transfer (BaseTransfer): ...
strategy (Strategy): ...
aggr_strategy (AggregatorStrategy): ...
results (list[Result]): Results from children of ``node``.
Returns:
Expand All @@ -33,10 +33,10 @@ def __call__(
for result in results:
idx: NodeID = result.node_idx
child_states[idx] = result.node_state
child_state_dicts[idx] = result.state_dict
child_state_dicts[idx] = result.params

node_state = AggrState(node.idx)
avg_state_dict = strategy.agg_param_aggregation(
avg_state_dict = aggr_strategy.aggregate_params(
node_state, child_states, child_state_dicts
)

Expand Down Expand Up @@ -64,15 +64,15 @@ class DebugAggregateJob(AggregableJob):
def __call__(
node: FlockNode,
transfer: BaseTransfer,
strategy: Strategy,
aggr_strategy: AggregatorStrategy,
results: list[Result],
) -> Result:
"""
Args:
node ():
transfer ():
strategy ():
aggr_strategy ():
results ():
Returns:
Expand All @@ -85,7 +85,7 @@ def __call__(
from flox.runtime import JobResult

result = next(iter(results))
state_dict = result.state_dict
state_dict = result.params
state_dict = {} if state_dict is None else state_dict
node_state = AggrState(node.idx)
history = {
Expand Down
52 changes: 28 additions & 24 deletions flox/jobs/local_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from flox.data import FloxDataset
from flox.flock import FlockNode
from flox.nn import FloxModule
from flox.nn.typing import StateDict
from flox.nn.typing import Params
from flox.runtime import Result
from flox.runtime.transfer import BaseTransfer
from flox.strategies import WorkerStrategy, TrainerStrategy
Expand All @@ -19,8 +19,8 @@ class LocalTrainJob(TrainableJob):
def __call__(
node: FlockNode,
parent: FlockNode,
module: FloxModule,
module_state_dict: StateDict,
global_model: FloxModule,
module_state_dict: Params,
dataset: FloxDataset,
transfer: BaseTransfer,
worker_strategy: WorkerStrategy,
Expand All @@ -35,7 +35,7 @@ def __call__(
parent (FlockNode):
strategy (Strategy):
module (FloxModule):
module_state_dict (StateDict):
module_state_dict (Params):
dataset (Dataset | Subset | None):
**train_hyper_params ():
Expand All @@ -44,53 +44,57 @@ def __call__(
"""
from copy import deepcopy
from flox.flock.states import WorkerState
from flox.nn.trainer import Trainer
from flox.nn.model_trainer import Trainer
from torch.utils.data import DataLoader
from flox.runtime import JobResult

global_model = module
global_state_dict = module.state_dict()
local_model = deepcopy(module)
# global_state_dict = global_model.state_dict()
local_model = deepcopy(global_model)
global_model.load_state_dict(module_state_dict)
local_model.load_state_dict(module_state_dict)

node_state = WorkerState(
state = WorkerState(
node.idx,
pre_local_train_model=global_model,
post_local_train_model=local_model,
global_model=global_model,
local_model=local_model,
)
state = worker_strategy.work_start(state) # NOTE: Double-check.

worker_strategy.work_start()
data = dataset.load(node)
train_dataloader = DataLoader(
data,
batch_size=train_hyper_params.get("batch_size", 32),
shuffle=train_hyper_params.get("shuffle", True),
)

# Add optimizer to this strategy.
worker_strategy.before_training(node_state, data)
trainer = Trainer()
optimizer = local_model.configure_optimizer()
trainer = Trainer(trainer_strategy)
optimizer = local_model.configure_optimizers()

state, data = worker_strategy.before_training(state, data)
history = trainer.fit(
local_model,
optimizer,
train_dataloader,
# TODO: Include `trainer_params` as an argument to
# this so users can easily customize Trainer.
num_epochs=train_hyper_params.get("num_epochs", 2),
node_state=node_state,
trainer_strategy=trainer_strategy,
node_state=state,
)

local_params = worker_strategy.after_training(node_state)
state = worker_strategy.after_training(state) # NOTE: Double-check.

################################################################################
# TRAINING DATA POST-PROCESSING
################################################################################
history["node/idx"] = node.idx
history["node/kind"] = node.kind.to_str()
history["parent/idx"] = parent.idx
history["parent/kind"] = parent.kind.to_str()

result = JobResult(node_state, node.idx, node.kind, local_params, history)
local_params = state.local_model.state_dict()
result = JobResult(state, node.idx, node.kind, local_params, history)

result = worker_strategy.work_end(result) # NOTE: Double-check.
return transfer.report(result)


Expand All @@ -99,8 +103,8 @@ class DebugLocalTrainJob(TrainableJob):
def __call__(
node: FlockNode,
parent: FlockNode,
module: FloxModule,
module_state_dict: StateDict,
global_model: FloxModule,
module_state_dict: Params,
dataset: FloxDataset,
transfer: BaseTransfer,
worker_strategy: WorkerStrategy,
Expand Down Expand Up @@ -128,8 +132,8 @@ def __call__(
local_module = module
node_state = WorkerState(
node.idx,
pre_local_train_model=local_module,
post_local_train_model=local_module,
global_model=local_module,
local_model=local_module,
)
history = {
"node/idx": [node.idx],
Expand Down
27 changes: 15 additions & 12 deletions flox/jobs/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
1. aggregation jobs (``AggregableJob``)
2. local training jobs (``TrainableJob``)
These protocols can be used to define custom implementations of aggregation jobs for highly-customized FLoX processes.
These protocols can be used to define custom implementations of aggregation jobs for highly-customized FLoX processes.
However, this is not necessary for the vast majority of imaginable cases.
Should users choose to do this, it is up to the user's discretion to do so safely and correctly.
Expand All @@ -27,11 +27,14 @@
from flox.data import FloxDataset
from flox.flock import FlockNode
from flox.nn import FloxModule
from flox.nn.typing import StateDict
from flox.nn.typing import Params
from flox.runtime import Result
from flox.runtime.transfer import BaseTransfer
from flox.strategies import WorkerStrategy, TrainerStrategy
from flox.strategies_depr import Strategy
from flox.strategies import (
WorkerStrategy,
TrainerStrategy,
AggregatorStrategy,
)


class NodeCallable(t.Protocol):
Expand All @@ -46,11 +49,11 @@ def __call__(self, node: FlockNode, *args, **kwargs) -> t.Any:
@t.runtime_checkable
class AggregableJob(t.Protocol):
"""
A protocol that defines functions that are valid implementations to be used for model aggregation in
A protocol that defines functions that are valid implementations to be used for model aggregation in
launching FLoX processes.
Notes:
FLoX provides default implementations of this protocol via
FLoX provides default implementations of this protocol via
[AggregateJob][flox.jobs.aggregation.AggregateJob] and
[DebugAggregateJob][flox.jobs.aggregation.DebugAggregateJob].
"""
Expand All @@ -59,7 +62,7 @@ class AggregableJob(t.Protocol):
def __call__(
node: FlockNode,
transfer: BaseTransfer,
strategy: Strategy,
aggr_strategy: AggregatorStrategy,
results: list[Result],
) -> Result:
"""
Expand All @@ -68,7 +71,7 @@ def __call__(
Args:
node (FlockNode):
transfer (BaseTransfer):
strategy (Strategy):
aggr_strategy (AggregatorStrategy):
results (list[Result]):
Returns:
Expand All @@ -79,11 +82,11 @@ def __call__(
@t.runtime_checkable
class TrainableJob(t.Protocol):
"""
A protocol that defines functions that are valid implementations to be used for local training in
A protocol that defines functions that are valid implementations to be used for local training in
launching FLoX processes.
Notes:
FLoX provides default implementations of this protocol via
FLoX provides default implementations of this protocol via
[LocalTrainJob][flox.jobs.local_training.LocalTrainJob] and
[DebugLocalTrainJob][flox.jobs.local_training.DebugLocalTrainJob].
"""
Expand All @@ -92,8 +95,8 @@ class TrainableJob(t.Protocol):
def __call__(
node: FlockNode,
parent: FlockNode,
module: FloxModule,
module_state_dict: StateDict,
global_model: FloxModule,
module_state_dict: Params,
dataset: FloxDataset,
transfer: BaseTransfer,
worker_strategy: WorkerStrategy,
Expand Down
Loading

0 comments on commit 0942f7c

Please sign in to comment.