Skip to content

Commit

Permalink
Basic federation loop (demo.py)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathaniel-hudson committed Nov 12, 2024
1 parent 14e3dd1 commit 50468ac
Show file tree
Hide file tree
Showing 40 changed files with 878 additions and 155 deletions.
63 changes: 63 additions & 0 deletions #main.py#
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import os

import torch
from torch import nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from flox import federated_fit
from flox.federation.topologies import Topology
from flox.learn import FloxModule
from flox.learn.data.utils import federated_split


class MyModule(FloxModule):
    """Three-layer fully connected classifier used by this example script.

    Inputs are flattened and passed through a 784 -> 512 -> 512 -> 10 stack
    with ReLU activations between the linear layers; the output is one raw
    logit per class.
    """

    def __init__(self):
        super().__init__()
        hidden_width = 512
        self.flatten = nn.Flatten()
        # Layers are created in the same order as they are applied so that
        # parameter initialisation consumes the RNG in a fixed sequence.
        layers = [
            nn.Linear(28 * 28, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, 10),
        ]
        self.linear_stack = nn.Sequential(*layers)
        # NOTE(review): presumably read by the training loop elsewhere; this
        # class only ever initialises it — confirm against the caller.
        self.last_accuracy = torch.tensor([1.0])

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.linear_stack(self.flatten(x))

    def training_step(self, batch, batch_idx):
        """Return the cross-entropy loss for one ``(inputs, targets)`` batch.

        ``batch_idx`` is accepted for interface compatibility and is unused.
        """
        inputs, targets = batch
        return nn.functional.cross_entropy(self(inputs), targets)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return a plain SGD optimizer over all parameters (lr=1e-3)."""
        learning_rate = 1e-3
        return torch.optim.SGD(self.parameters(), lr=learning_rate)


def load_data():
    """Build the MNIST dataset object used by this example.

    The dataset root is read from the ``TORCH_DATASETS`` environment variable
    (a ``KeyError`` is raised if it is unset). The dataset is requested with
    ``train=False`` and ``download=False``, and samples are converted with
    ``ToTensor``.
    """
    dataset_root = os.environ["TORCH_DATASETS"]
    to_tensor = ToTensor()
    return MNIST(
        root=dataset_root,
        train=False,
        download=False,
        transform=to_tensor,
    )


if __name__ == "__main__":
    # Script entry point: load the topology, partition the dataset across it,
    # run a short federated fit, and dump the resulting history to CSV.
    flock = Topology.from_yaml("examples/flocks/3-tier.yaml")

    split_options = {"samples_alpha": 10.0, "labels_alpha": 10.0}
    fed_data = federated_split(load_data(), flock, 10, **split_options)

    fit_options = {"strategy": "fedavg", "launcher_kind": "thread"}
    module, train_history = federated_fit(
        flock,
        MyModule(),
        fed_data,
        2,
        **fit_options,
    )

    train_history.to_csv("temp.csv")
174 changes: 174 additions & 0 deletions Demo.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "a009e51d-4aef-4a19-b540-61c37cfe1814",
"metadata": {},
"outputs": [],
"source": [
"import logging\n",
"import os\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torchmetrics\n",
"\n",
"from flight import federated_fit\n",
"from flight.data import federated_split, fed_barplot\n",
"from flight.topo import Topology\n",
"from flight.nn import FloxModule\n",
"\n",
"from torch.utils.data import Subset\n",
"from torchvision import transforms\n",
"from torchvision.datasets import FashionMNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "95402f1a-ec50-49ab-afb4-a1f02550e1b0",
"metadata": {},
"outputs": [],
"source": [
"class SmallConvModel(FloxModule):\n",
" def __init__(self, lr: float = 0.01, device: str | None = None):\n",
" super().__init__()\n",
" self.lr = lr\n",
" self.conv1 = nn.Conv2d(1, 6, 5)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(6, 16, 5)\n",
" self.fc1 = nn.Linear(256, 120)\n",
" self.fc2 = nn.Linear(120, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
" self.accuracy = torchmetrics.Accuracy(task=\"multiclass\", num_classes=10)\n",
" self.last_accuracy = None\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = torch.flatten(x, 1)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
" def training_step(self, batch, batch_idx):\n",
" inputs, targets = batch\n",
" preds = self.forward(inputs)\n",
" loss = F.cross_entropy(preds, targets)\n",
" self.last_accuracy = self.accuracy(preds, targets)\n",
" return loss\n",
"\n",
" def configure_optimizers(self) -> torch.optim.Optimizer:\n",
" return torch.optim.SGD(self.parameters(), lr=self.lr)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "2af5a207-1a58-4157-8c9b-b49d6818cd71",
"metadata": {},
"outputs": [
{
"ename": "IndexError",
"evalue": "list index out of range",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[14], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m topo \u001b[38;5;241m=\u001b[39m Topology\u001b[38;5;241m.\u001b[39mfrom_yaml(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflock.yaml\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtopo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprog\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtwopi\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n",
"File \u001b[0;32m~/Development/MiscProjects/flox_demo/venv/lib/python3.12/site-packages/flight/topo/topology.py:174\u001b[0m, in \u001b[0;36mTopology.draw\u001b[0;34m(self, color_by_kind, with_labels, label_color, prog, node_kind_attrs, show_axis_border, ax)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# TODO: We may want to remove this as a requirement. It produces nice \"tree\" positions\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# of the nodes. But it introduces a pretty restrictive dependency.\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m prog \u001b[38;5;129;01min\u001b[39;00m PROGS:\n\u001b[0;32m--> 174\u001b[0m pos \u001b[38;5;241m=\u001b[39m \u001b[43mnx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnx_pydot\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydot_layout\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtopo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprog\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprog\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 176\u001b[0m pos \u001b[38;5;241m=\u001b[39m nx\u001b[38;5;241m.\u001b[39mspring_layout(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtopo)\n",
"File \u001b[0;32m~/Development/MiscProjects/flox_demo/venv/lib/python3.12/site-packages/networkx/drawing/nx_pydot.py:406\u001b[0m, in \u001b[0;36mpydot_layout\u001b[0;34m(G, prog, root)\u001b[0m\n\u001b[1;32m 403\u001b[0m node \u001b[38;5;241m=\u001b[39m Q\u001b[38;5;241m.\u001b[39mget_node(pydot_node)\n\u001b[1;32m 405\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(node, \u001b[38;5;28mlist\u001b[39m):\n\u001b[0;32m--> 406\u001b[0m node \u001b[38;5;241m=\u001b[39m \u001b[43mnode\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 407\u001b[0m pos \u001b[38;5;241m=\u001b[39m node\u001b[38;5;241m.\u001b[39mget_pos()[\u001b[38;5;241m1\u001b[39m:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;66;03m# strip leading and trailing double quotes\u001b[39;00m\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pos \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
"\u001b[0;31mIndexError\u001b[0m: list index out of range"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGFCAYAAABg2vAPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAGYUlEQVR4nO3WMQEAIAzAMMC/5yFjRxMFPXtnZg4AkPW2AwCAXWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAA
DEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiPsF9wcGCbd4pQAAAABJRU5ErkJggg==",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"topo = Topology.from_yaml(\"flock.yaml\")\n",
"topo.draw(prog=\"twopi\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "e0a4553b-97dd-4800-be91-07bf6b01892b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Dataset FashionMNIST\n",
" Number of datapoints: 60000\n",
" Root location: /Users/Nathaniel/Research/Data/Torch-Data/\n",
" Split: Train\n",
" StandardTransform\n",
"Transform: Compose(\n",
" ToTensor()\n",
" Normalize(mean=0.5, std=0.5)\n",
" )"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"root_dir = \"~/Research/Data/Torch-Data/\"\n",
"data = FashionMNIST(\n",
" root=root_dir,\n",
" train=True,\n",
" download=False,\n",
" transform=transforms.Compose(\n",
" [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]\n",
" ),\n",
")\n",
"data = Subset(data, indices=list(range(1000)))\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "092bc059-094c-4f62-80bf-122899e32a47",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
177 changes: 158 additions & 19 deletions Temp Notebook.ipynb

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset

import flight as fl
from flight.learning import federated_split
from flight.learning.torch import TorchModule
from flight.learning.torch.types import TensorLoss

# Number of distinct class labels: sizes the model's output layer and bounds
# the synthetic integer targets generated in main().
NUM_LABELS = 10


class TrainingModule(TorchModule):
    """Toy classifier for the federation demo.

    A stack of three linear layers maps a single input feature to
    ``NUM_LABELS`` raw, unnormalised logits.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(1, 10),
            nn.Linear(10, 100),
            nn.Linear(100, NUM_LABELS),
        )

    def forward(self, x):
        """Return the raw class logits for the input batch ``x``."""
        return self.model(x)

    def training_step(self, batch) -> TensorLoss:
        """Return the classification loss for one ``(x, y)`` batch.

        The model emits raw logits, so the loss is ``cross_entropy``
        (log-softmax followed by NLL). Calling ``nll_loss`` directly on the
        logits would treat them as log-probabilities and yield an objective
        that is unbounded below.
        """
        x, y = batch
        logits = self(x)
        # Class-index targets must be 1-D; accept a trailing singleton
        # dimension, i.e. integer targets of shape (batch, 1). Floating-point
        # (class-probability) targets are left untouched.
        if y.dim() == 2 and y.size(-1) == 1 and not y.is_floating_point():
            y = y.squeeze(-1)
        return nn.functional.cross_entropy(logits, y)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Return an Adam optimizer over all parameters (lr=0.02)."""
        return torch.optim.Adam(self.parameters(), lr=0.02)


def main():
    """Run a two-round federated fit on synthetic data and print the records.

    Builds a flat topology, splits a randomly generated dataset across it,
    trains a ``TrainingModule`` for two rounds, and prints whatever records
    ``federated_fit`` returns.
    """
    num_samples = 100
    # NOTE(review): the argument is presumably the number of worker nodes —
    # confirm against flat_topology's signature.
    topo = fl.flat_topology(10)
    module = TrainingModule()

    features = torch.randn(num_samples, 1)
    # Class-index targets must be 1-D (one integer per sample). A (N, 1)
    # tensor batches to (B, 1), which classification losses reject as
    # "multi-target".
    labels = torch.randint(low=0, high=NUM_LABELS, size=(num_samples,))

    fed_data = federated_split(
        topo=topo,
        data=TensorDataset(features, labels),
        num_labels=NUM_LABELS,
        label_alpha=100.0,
        sample_alpha=100.0,
    )
    # The trained module is not needed by this demo; only the records are shown.
    _, records = fl.federated_fit(topo, module, fed_data, rounds=2)
    print(records)


if __name__ == "__main__":
    main()
5 changes: 2 additions & 3 deletions depr/flox/federation/fed_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from copy import deepcopy

import pandas as pd
from tqdm import tqdm

from flox.federation.fed import Federation
from flox.logger import Logger
from tqdm import tqdm

if t.TYPE_CHECKING:
from pandas import DataFrame
Expand Down Expand Up @@ -92,7 +91,7 @@ def start(self, debug_mode: bool = False) -> tuple[FloxModule, DataFrame]:

self.global_model = DebugModule()
self.params = self.global_model.state_dict()
self.state.global_model = self.global_model
self.state.pre_module = self.global_model

if not self.topo.two_tier:
raise ValueError(
Expand Down
4 changes: 2 additions & 2 deletions depr/flox/federation/jobs/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def __call__(
history["parent/idx"] = parent.idx
history["parent/kind"] = parent.kind.to_str()

assert state.local_model is not None
local_params = state.local_model.state_dict()
assert state.module is not None
local_params = state.module.state_dict()
result = JobResult(state, node.idx, node.kind, local_params, history)

result = worker_strategy.work_end(result) # NOTE: Double-check.
Expand Down
2 changes: 1 addition & 1 deletion depr/flox/strategies/impl/fedasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def aggregate_params(
assert last_updated is not None
assert isinstance(last_updated, int | str)

global_model_params = state.global_model.state_dict()
global_model_params = state.pre_module.state_dict()
last_updated_params = children_state_dicts[last_updated]

aggr_params = []
Expand Down
5 changes: 2 additions & 3 deletions depr/flox/strategies/impl/fedprox.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import typing as t

import torch

from flox.strategies import Strategy
from flox.strategies.impl.fedavg import FedAvgWorker
from flox.strategies.impl.fedsgd import FedSGDAggr, FedSGDClient
Expand Down Expand Up @@ -38,8 +37,8 @@ def before_backprop(self, state: WorkerState, loss: Loss) -> Loss:
Returns:
Loss with the proximal term added to it.
"""
global_model = state.global_model
local_model = state.local_model
global_model = state.pre_module
local_model = state.module
assert global_model is not None
assert local_model is not None

Expand Down
24 changes: 15 additions & 9 deletions docs/getting_started/strategies/custom.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Defining Your Own Custom Strategies

FLoX was designed with customizability in mind. FL is a new research area that invites countless questions about how to
best perform FL. Additionally, the best FL approach will vary depending on the data, network connectivity, other
requirements, etc. As such, we aimed to make defining original Strategies to be as pain-free as possible.

Implementing a custom ``Strategy`` simply requires defining a new class that extends/subclasses the ``Strategy`` protocol
(as seen above). The ``Strategy`` protocol provides a handful of callbacks for you to inject custom logic to adjust how the
FLoX was designed with customizability in mind. FL is a new research area that invites
countless questions about how to
best perform FL. Additionally, the best FL approach will vary depending on the data,
network connectivity, other
requirements, etc. As such, we aimed to make defining original Strategies as
pain-free as possible.

Implementing a custom ``Strategy`` simply requires defining a new class that
extends/subclasses the ``Strategy`` protocol
(as seen above). The ``Strategy`` protocol provides a handful of callbacks for you to
inject custom logic to adjust how the
FL process runs.

As an example, let's use our source code for the implementation of ``FedProx`` as an example.
As an example, let's look at our source code for the implementation of
``FedProx``.

```python
class FedProx(FedAvg):
Expand Down Expand Up @@ -38,8 +44,8 @@ class FedProx(FedAvg):
**kwargs,
) -> torch.Tensor:
"""..."""
global_model = state.global_model
local_model = state.local_model
global_model = state.pre_module
local_model = state.module

params = list(local_model.params().values())
params0 = list(global_model.params().values())
Expand Down
6 changes: 6 additions & 0 deletions flight/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,9 @@
**F**ederated **L**earning **I**n **G**eneral **H**ierarchical **T**opologies (FLIGHT)
is a lightweight framework for federated learning workflows for complex systems.
"""

from flight.federation.topologies import Topology
from flight.federation.topologies.utils import flat_topology
from flight.fit import federated_fit

__all__ = ["Topology", "flat_topology", "federated_fit"]
Loading

0 comments on commit 50468ac

Please sign in to comment.