-
Notifications
You must be signed in to change notification settings - Fork 207
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1b586cb
commit c73dd09
Showing
1 changed file
with
343 additions
and
0 deletions.
There are no files selected for viewing
343 changes: 343 additions & 0 deletions
343
openfl-tutorials/Federated_FedProx_PyTorch_MNIST_Workflow_Tutorial.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,343 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Federated FedProx PyTorch MNIST Workflow API Tutorial\n", | ||
"This notebook mimics the `Federated_FedProx_PyTorch_MNIST_Tutorial` notebook using the Workflow API.\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Import the relevant libraries and choose an aggregation algorithm, an optimizer, and a loss function:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"\n", | ||
"import torch.utils\n", | ||
"import torch.utils.data\n", | ||
"import torchvision\n", | ||
"import torchvision.transforms as transforms\n", | ||
"\n", | ||
"from openfl.utilities.optimizers.torch.fedprox import FedProxAdam\n", | ||
"\n", | ||
"from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n", | ||
"from openfl.experimental.runtime import LocalRuntime\n", | ||
"from openfl.experimental.placement import aggregator, collaborator\n", | ||
"\n", | ||
"def one_hot(labels, classes):\n", | ||
" return np.eye(classes)[labels]\n", | ||
"\n", | ||
"def cross_entropy(output, target):\n", | ||
" \"\"\"Binary cross-entropy loss function\"\"\"\n", | ||
" return F.binary_cross_entropy_with_logits(input=output,target=target.float())\n", | ||
"\n", | ||
"def get_optimizer(model):\n", | ||
" return FedProxAdam(model.parameters(), lr=1e-3, mu=0.01)\n", | ||
"\n", | ||
"# Aggregation algorithm\n", | ||
"def FedAvg(models, weights=None):\n", | ||
" new_model = models[0]\n", | ||
" new_state_dict = dict()\n", | ||
" for key in new_model.state_dict().keys():\n", | ||
" new_state_dict[key] = torch.from_numpy(np.average([model.state_dict()[key].numpy() for model in models],\n", | ||
" axis=0, \n", | ||
" weights=weights))\n", | ||
"\n", | ||
" new_model.load_state_dict(new_state_dict)\n", | ||
" return new_model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Define the model:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class Net(nn.Module):\n", | ||
" def __init__(self):\n", | ||
" super(Net, self).__init__()\n", | ||
" self.conv1 = nn.Conv2d(1, 16, 3)\n", | ||
" self.pool = nn.MaxPool2d(2, 2)\n", | ||
" self.conv2 = nn.Conv2d(16, 32, 3)\n", | ||
" self.fc1 = nn.Linear(32 * 5 * 5, 32)\n", | ||
" self.fc2 = nn.Linear(32, 84)\n", | ||
" self.fc3 = nn.Linear(84, 10)\n", | ||
"\n", | ||
" def forward(self, x):\n", | ||
" x = self.pool(F.relu(self.conv1(x)))\n", | ||
" x = self.pool(F.relu(self.conv2(x)))\n", | ||
" x = x.view(x.size(0),-1)\n", | ||
" x = F.relu(self.fc1(x))\n", | ||
" x = F.relu(self.fc2(x))\n", | ||
" x = self.fc3(x)\n", | ||
" return F.log_softmax(x, dim=1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Set up the dataset:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"transform = transforms.Compose(\n", | ||
" [transforms.ToTensor(),\n", | ||
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", | ||
"\n", | ||
"mnist_train = torchvision.datasets.MNIST(\n", | ||
" \"./files/\",\n", | ||
" train=True,\n", | ||
" download=True,\n", | ||
" transform=transform,\n", | ||
")\n", | ||
"\n", | ||
"mnist_test = torchvision.datasets.MNIST(\n", | ||
" \"./files/\",\n", | ||
" train=False,\n", | ||
" download=True,\n", | ||
" transform=transform,\n", | ||
")\n", | ||
"\n", | ||
"class CustomDataset(torch.utils.data.Dataset):\n", | ||
" \"\"\"Dataset enumeration as tensors\"\"\"\n", | ||
" def __init__(self, images, labels):\n", | ||
" self.images = images\n", | ||
" self.labels = labels\n", | ||
"\n", | ||
" def __len__(self):\n", | ||
" return len(self.images)\n", | ||
"\n", | ||
" def __getitem__(self, idx):\n", | ||
" image = self.images[idx]\n", | ||
" label = self.labels[idx]\n", | ||
" return image, label" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The next step is to set up the participants: an `Aggregator` and a few `Collaborator`s that will train the model. We then partition the dataset among the collaborators and pass all participants to the appropriate runtime environment (in our case, a `LocalRuntime`).\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Setup participants\n", | ||
"aggregator_ = Aggregator()\n", | ||
"aggregator_.private_attributes = {}\n", | ||
"\n", | ||
"# Setup collaborators with private attributes\n", | ||
"collaborator_names = [f'collaborator{i}' for i in range(4)]\n", | ||
"collaborators = [Collaborator(name=name) for name in collaborator_names]\n", | ||
"batch_size_train = 1024\n", | ||
"batch_size_test = 1024\n", | ||
"log_interval = 10\n", | ||
"\n", | ||
"for idx, collaborator_ in enumerate(collaborators):\n", | ||
" train_images, train_labels = mnist_train.train_data, np.array(mnist_train.train_labels)\n", | ||
" train_images = torch.from_numpy(np.expand_dims(train_images, axis=1)).float()\n", | ||
" train_labels = one_hot(train_labels, 10)\n", | ||
"\n", | ||
" valid_images, valid_labels = mnist_test.test_data, np.array(mnist_test.test_labels)\n", | ||
" valid_images = torch.from_numpy(np.expand_dims(valid_images, axis=1)).float()\n", | ||
"\n", | ||
" collaborator_.private_attributes = {\n", | ||
" 'train_loader': torch.utils.data.DataLoader(\n", | ||
" CustomDataset(train_images[idx::len(collaborators)], \n", | ||
" train_labels[idx::len(collaborators)]), \n", | ||
" batch_size=batch_size_train, \n", | ||
" shuffle=True),\n", | ||
" 'test_loader': torch.utils.data.DataLoader(\n", | ||
" CustomDataset(valid_images[idx::len(collaborators)], \n", | ||
" valid_labels[idx::len(collaborators)]), \n", | ||
" batch_size=batch_size_test, \n", | ||
" shuffle=True)\n", | ||
" }\n", | ||
"\n", | ||
"local_runtime = LocalRuntime(aggregator=aggregator_, collaborators=collaborators, backend='single_process')\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Set up work to be executed by the aggregator and the collaborators by extending `FLSpec`:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class FederatedFlow(FLSpec):\n", | ||
" def __init__(self, model=None, optimizer=None, rounds=10, **kwargs):\n", | ||
" super().__init__(**kwargs)\n", | ||
" self.model = model\n", | ||
" self.optimizer = optimizer\n", | ||
" self.rounds = rounds\n", | ||
" self.loss = 0.\n", | ||
"\n", | ||
" @aggregator\n", | ||
" def start(self):\n", | ||
" print(f'Performing initialization for model')\n", | ||
" self.collaborators = self.runtime.collaborators\n", | ||
" self.private = 10\n", | ||
" self.current_round = 0\n", | ||
" self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n", | ||
"\n", | ||
" def compute_accuracy(self, data_loader):\n", | ||
" self.model.eval()\n", | ||
" test_loss = 0\n", | ||
" correct = 0\n", | ||
" with torch.no_grad():\n", | ||
" for data, target in data_loader:\n", | ||
" output = self.model(data)\n", | ||
" test_loss += F.cross_entropy(output, target, size_average=False).item()\n", | ||
" pred = output.data.max(1, keepdim=True)[1]\n", | ||
" correct += pred.eq(target.data.view_as(pred)).sum()\n", | ||
"\n", | ||
" test_loss /= len(data_loader.dataset)\n", | ||
" print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", | ||
" test_loss, correct, len(data_loader.dataset),\n", | ||
" 100. * correct / len(data_loader.dataset)))\n", | ||
" accuracy = float(correct / len(data_loader.dataset))\n", | ||
" return accuracy\n", | ||
"\n", | ||
" @collaborator\n", | ||
" def aggregated_model_validation(self):\n", | ||
" print(f'Performing aggregated model validation for collaborator {self.input}, model: {id(self.model)}')\n", | ||
" self.agg_validation_score = self.compute_accuracy(self.test_loader)\n", | ||
" self.next(self.train)\n", | ||
"\n", | ||
" @collaborator\n", | ||
" def train(self):\n", | ||
" # Log after processing a quarter of the samples\n", | ||
" log_threshold = .25\n", | ||
"\n", | ||
" self.model.train()\n", | ||
" self.optimizer = get_optimizer(self.model)\n", | ||
" for batch_idx, (data, target) in enumerate(self.train_loader):\n", | ||
" self.optimizer.zero_grad()\n", | ||
" output = self.model(data)\n", | ||
" loss = F.cross_entropy(output, target)\n", | ||
" loss.backward()\n", | ||
" self.optimizer.step()\n", | ||
"\n", | ||
" if (len(data) * batch_idx) / len(self.train_loader.dataset) >= log_threshold:\n", | ||
" print('Train Epoch: [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", | ||
" batch_idx * len(data), len(self.train_loader.dataset),\n", | ||
" 100. * batch_idx / len(self.train_loader), loss.item()))\n", | ||
" self.loss = loss.item()\n", | ||
" log_threshold += .25\n", | ||
" torch.save(self.model.state_dict(), 'model.pth')\n", | ||
" torch.save(self.optimizer.state_dict(), 'optimizer.pth')\n", | ||
" \n", | ||
" self.next(self.local_model_validation)\n", | ||
"\n", | ||
" @collaborator\n", | ||
" def local_model_validation(self):\n", | ||
" print(f'Performing local model validation for collaborator {self.input}')\n", | ||
" self.local_validation_score = self.compute_accuracy(self.test_loader)\n", | ||
" print(\n", | ||
" f'Done with local model validation for collaborator {self.input}, Accuracy: {self.local_validation_score}')\n", | ||
" self.next(self.join)\n", | ||
"\n", | ||
" @aggregator\n", | ||
" def join(self, inputs):\n", | ||
" self.model = FedAvg([input.model for input in inputs])\n", | ||
" self.optimizer = inputs[0].optimizer\n", | ||
" self.current_round += 1\n", | ||
"\n", | ||
" self.average_loss = sum(input.loss for input in inputs) / len(inputs)\n", | ||
" self.aggregated_model_accuracy = sum(\n", | ||
" input.agg_validation_score for input in inputs) / len(inputs)\n", | ||
" self.local_model_accuracy = sum(\n", | ||
" input.local_validation_score for input in inputs) / len(inputs)\n", | ||
" print(f'Average aggregated model accuracy = {self.aggregated_model_accuracy}')\n", | ||
" print(f'Average training loss = {self.average_loss}')\n", | ||
" print(f'Average local model validation values = {self.local_model_accuracy}')\n", | ||
"\n", | ||
" if self.current_round < self.rounds:\n", | ||
" self.next(self.aggregated_model_validation,\n", | ||
" foreach='collaborators', exclude=['private'])\n", | ||
" else:\n", | ||
" self.next(self.end)\n", | ||
"\n", | ||
" @aggregator\n", | ||
" def end(self):\n", | ||
" print(f'Flow ended')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Finally, run the federation:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"model = Net()\n", | ||
"flflow = FederatedFlow(model, get_optimizer(model), rounds=3, checkpoint=False)\n", | ||
"flflow.runtime = local_runtime\n", | ||
"flflow.run()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |