-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
14e3dd1
commit 50468ac
Showing
40 changed files
with
878 additions
and
155 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import os | ||
|
||
import torch | ||
from torch import nn | ||
from torchvision.datasets import MNIST | ||
from torchvision.transforms import ToTensor | ||
|
||
from flox import federated_fit | ||
from flox.federation.topologies import Topology | ||
from flox.learn import FloxModule | ||
from flox.learn.data.utils import federated_split | ||
|
||
|
||
class MyModule(FloxModule): | ||
def __init__(self): | ||
super().__init__() | ||
self.flatten = nn.Flatten() | ||
self.linear_stack = nn.Sequential( | ||
nn.Linear(28 * 28, 512), | ||
nn.ReLU(), | ||
nn.Linear(512, 512), | ||
nn.ReLU(), | ||
nn.Linear(512, 10), | ||
) | ||
self.last_accuracy = torch.tensor([1.0]) | ||
|
||
def forward(self, x): | ||
x = self.flatten(x) | ||
logits = self.linear_stack(x) | ||
return logits | ||
|
||
def training_step(self, batch, batch_idx): | ||
inputs, targets = batch | ||
preds = self(inputs) | ||
loss = nn.functional.cross_entropy(preds, targets) | ||
return loss | ||
|
||
def configure_optimizers(self) -> torch.optim.Optimizer: | ||
return torch.optim.SGD(self.parameters(), lr=1e-3) | ||
|
||
|
||
def load_data(): | ||
return MNIST( | ||
root=os.environ["TORCH_DATASETS"], | ||
download=False, | ||
train=False, | ||
transform=ToTensor(), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
flock = Topology.from_yaml("examples/flocks/3-tier.yaml") | ||
fed_data = federated_split( | ||
load_data(), | ||
flock, | ||
10, | ||
samples_alpha=10.0, | ||
labels_alpha=10.0, | ||
) | ||
module, train_history = federated_fit( | ||
flock, MyModule(), fed_data, 2, strategy="fedavg", launcher_kind="thread" | ||
) | ||
train_history.to_csv("temp.csv") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 11, | ||
"id": "a009e51d-4aef-4a19-b540-61c37cfe1814", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import logging\n", | ||
"import os\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"import torchmetrics\n", | ||
"\n", | ||
"from flight import federated_fit\n", | ||
"from flight.data import federated_split, fed_barplot\n", | ||
"from flight.topo import Topology\n", | ||
"from flight.nn import FloxModule\n", | ||
"\n", | ||
"from torch.utils.data import Subset\n", | ||
"from torchvision import transforms\n", | ||
"from torchvision.datasets import FashionMNIST" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "95402f1a-ec50-49ab-afb4-a1f02550e1b0", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class SmallConvModel(FloxModule):\n", | ||
" def __init__(self, lr: float = 0.01, device: str | None = None):\n", | ||
" super().__init__()\n", | ||
" self.lr = lr\n", | ||
" self.conv1 = nn.Conv2d(1, 6, 5)\n", | ||
" self.pool = nn.MaxPool2d(2, 2)\n", | ||
" self.conv2 = nn.Conv2d(6, 16, 5)\n", | ||
" self.fc1 = nn.Linear(256, 120)\n", | ||
" self.fc2 = nn.Linear(120, 84)\n", | ||
" self.fc3 = nn.Linear(84, 10)\n", | ||
" self.accuracy = torchmetrics.Accuracy(task=\"multiclass\", num_classes=10)\n", | ||
" self.last_accuracy = None\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" x = self.pool(F.relu(self.conv1(x)))\n", | ||
" x = self.pool(F.relu(self.conv2(x)))\n", | ||
" x = torch.flatten(x, 1)\n", | ||
" x = F.relu(self.fc1(x))\n", | ||
" x = F.relu(self.fc2(x))\n", | ||
" x = self.fc3(x)\n", | ||
" return x\n", | ||
"\n", | ||
" def training_step(self, batch, batch_idx):\n", | ||
" inputs, targets = batch\n", | ||
" preds = self.forward(inputs)\n", | ||
" loss = F.cross_entropy(preds, targets)\n", | ||
" self.last_accuracy = self.accuracy(preds, targets)\n", | ||
" return loss\n", | ||
"\n", | ||
" def configure_optimizers(self) -> torch.optim.Optimizer:\n", | ||
" return torch.optim.SGD(self.parameters(), lr=self.lr)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"id": "2af5a207-1a58-4157-8c9b-b49d6818cd71", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "IndexError", | ||
"evalue": "list index out of range", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mIndexError\u001b[0m Traceback (most recent call last)", | ||
"Cell \u001b[0;32mIn[14], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m topo \u001b[38;5;241m=\u001b[39m Topology\u001b[38;5;241m.\u001b[39mfrom_yaml(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mflock.yaml\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mtopo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdraw\u001b[49m\u001b[43m(\u001b[49m\u001b[43mprog\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtwopi\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3\u001b[0m plt\u001b[38;5;241m.\u001b[39mshow()\n", | ||
"File \u001b[0;32m~/Development/MiscProjects/flox_demo/venv/lib/python3.12/site-packages/flight/topo/topology.py:174\u001b[0m, in \u001b[0;36mTopology.draw\u001b[0;34m(self, color_by_kind, with_labels, label_color, prog, node_kind_attrs, show_axis_border, ax)\u001b[0m\n\u001b[1;32m 171\u001b[0m \u001b[38;5;66;03m# TODO: We may want to remove this as a requirement. It produces nice \"tree\" positions\u001b[39;00m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# of the nodes. But it introduces a pretty restrictive dependency.\u001b[39;00m\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m prog \u001b[38;5;129;01min\u001b[39;00m PROGS:\n\u001b[0;32m--> 174\u001b[0m pos \u001b[38;5;241m=\u001b[39m \u001b[43mnx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnx_pydot\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpydot_layout\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtopo\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprog\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mprog\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 176\u001b[0m pos \u001b[38;5;241m=\u001b[39m nx\u001b[38;5;241m.\u001b[39mspring_layout(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtopo)\n", | ||
"File \u001b[0;32m~/Development/MiscProjects/flox_demo/venv/lib/python3.12/site-packages/networkx/drawing/nx_pydot.py:406\u001b[0m, in \u001b[0;36mpydot_layout\u001b[0;34m(G, prog, root)\u001b[0m\n\u001b[1;32m 403\u001b[0m node \u001b[38;5;241m=\u001b[39m Q\u001b[38;5;241m.\u001b[39mget_node(pydot_node)\n\u001b[1;32m 405\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(node, \u001b[38;5;28mlist\u001b[39m):\n\u001b[0;32m--> 406\u001b[0m node \u001b[38;5;241m=\u001b[39m \u001b[43mnode\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\n\u001b[1;32m 407\u001b[0m pos \u001b[38;5;241m=\u001b[39m node\u001b[38;5;241m.\u001b[39mget_pos()[\u001b[38;5;241m1\u001b[39m:\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m] \u001b[38;5;66;03m# strip leading and trailing double quotes\u001b[39;00m\n\u001b[1;32m 408\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pos \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n", | ||
"\u001b[0;31mIndexError\u001b[0m: list index out of range" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAGFCAYAAABg2vAPAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAGYUlEQVR4nO3WMQEAIAzAMMC/5yFjRxMFPXtnZg4AkPW2AwCAXWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiDMDABBnBgAgzgwAQJwZAIA4MwAAcWYAAOLMAADEmQEAiPsF9wcGCbd4pQAAAABJRU5ErkJggg==", | ||
"text/plain": [ | ||
"<Figure size 640x480 with 1 Axes>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"topo = Topology.from_yaml(\"flock.yaml\")\n", | ||
"topo.draw(prog=\"twopi\")\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "e0a4553b-97dd-4800-be91-07bf6b01892b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"Dataset FashionMNIST\n", | ||
" Number of datapoints: 60000\n", | ||
" Root location: /Users/Nathaniel/Research/Data/Torch-Data/\n", | ||
" Split: Train\n", | ||
" StandardTransform\n", | ||
"Transform: Compose(\n", | ||
" ToTensor()\n", | ||
" Normalize(mean=0.5, std=0.5)\n", | ||
" )" | ||
] | ||
}, | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"root_dir = \"~/Research/Data/Torch-Data/\"\n", | ||
"data = FashionMNIST(\n", | ||
" root=root_dir,\n", | ||
" train=True,\n", | ||
" download=False,\n", | ||
" transform=transforms.Compose(\n", | ||
" [transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]\n", | ||
" ),\n", | ||
")\n", | ||
"data = Subset(data, indices=list(range(1000)))\n", | ||
"data" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "092bc059-094c-4f62-80bf-122899e32a47", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.utils.data import TensorDataset | ||
|
||
import flight as fl | ||
from flight.learning import federated_split | ||
from flight.learning.torch import TorchModule | ||
from flight.learning.torch.types import TensorLoss | ||
|
||
NUM_LABELS = 10 | ||
|
||
|
||
class TrainingModule(TorchModule): | ||
def __init__(self): | ||
super().__init__() | ||
self.model = nn.Sequential( | ||
nn.Linear(1, 10), | ||
nn.Linear(10, 100), | ||
nn.Linear(100, NUM_LABELS), | ||
) | ||
|
||
def forward(self, x): | ||
return self.model(x) | ||
|
||
def training_step(self, batch) -> TensorLoss: | ||
x, y = batch | ||
y_hat = self(x) | ||
return nn.functional.nll_loss(y_hat, y) | ||
|
||
def configure_optimizers(self) -> torch.optim.Optimizer: | ||
return torch.optim.Adam(self.parameters(), lr=0.02) | ||
|
||
|
||
def main(): | ||
topo = fl.flat_topology(10) | ||
module = TrainingModule() | ||
fed_data = federated_split( | ||
topo=topo, | ||
data=TensorDataset( | ||
torch.randn(100, 1), torch.randint(low=0, high=NUM_LABELS, size=(100, 1)) | ||
), | ||
num_labels=NUM_LABELS, | ||
label_alpha=100.0, | ||
sample_alpha=100.0, | ||
) | ||
trained_module, records = fl.federated_fit(topo, module, fed_data, rounds=2) | ||
print(records) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.