Add pytorch MNIST Workflow tutorial
ynonflumintel committed Nov 20, 2024
1 parent 1b586cb commit c73dd09
Showing 1 changed file with 343 additions and 0 deletions.
@@ -0,0 +1,343 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Federated FedProx PyTorch MNIST Workflow API Tutorial\n",
"This notebook mimics Federated_FedProx_PyTorch_MNIST_Tutorial using the workflow API.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the relevant libraries, choose an aggregation algorithm an optimizer and a loss function:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torch.utils\n",
"import torch.utils.data\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"\n",
"from openfl.utilities.optimizers.torch.fedprox import FedProxAdam\n",
"\n",
"from openfl.experimental.interface import FLSpec, Aggregator, Collaborator\n",
"from openfl.experimental.runtime import LocalRuntime\n",
"from openfl.experimental.placement import aggregator, collaborator\n",
"\n",
"def one_hot(labels, classes):\n",
" return np.eye(classes)[labels]\n",
"\n",
"def cross_entropy(output, target):\n",
" \"\"\"Binary cross-entropy loss function\"\"\"\n",
" return F.binary_cross_entropy_with_logits(input=output,target=target.float())\n",
"\n",
"def get_optimizer(model):\n",
" return FedProxAdam(model.parameters(), lr=1e-3, mu=0.01)\n",
"\n",
"# Aggregation algorithm\n",
"def FedAvg(models, weights=None):\n",
" new_model = models[0]\n",
" new_state_dict = dict()\n",
" for key in new_model.state_dict().keys():\n",
" new_state_dict[key] = torch.from_numpy(np.average([model.state_dict()[key].numpy() for model in models],\n",
" axis=0, \n",
" weights=weights))\n",
"\n",
" new_model.load_state_dict(new_state_dict)\n",
" return new_model"
]
},
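{
"cell_type": "markdown",
"metadata": {},
"source": [
"`FedProxAdam` follows the FedProx scheme: during local training each collaborator minimizes its local loss $F_k(w)$ plus a proximal term that keeps the local weights $w$ close to the global weights $w^t$ received at the start of the round,\n",
"\n",
"$$\\min_w \\; F_k(w) + \\frac{\\mu}{2}\\lVert w - w^t \\rVert^2,$$\n",
"\n",
"where $\\mu$ is the `mu` coefficient passed to the optimizer above.\n",
"\n",
"As a quick sanity check (illustrative, not part of the original tutorial), `FedAvg` applied to two tiny models whose weights are all zeros and all twos should return a model whose weights are all ones. Note that `FedAvg` updates `models[0]` in place:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sanity check: FedAvg returns the element-wise mean of the parameters\n",
"m1, m2 = nn.Linear(2, 2), nn.Linear(2, 2)\n",
"with torch.no_grad():\n",
" m1.weight.fill_(0.)\n",
" m2.weight.fill_(2.)\n",
"\n",
"averaged = FedAvg([m1, m2])\n",
"print(averaged.state_dict()['weight']) # expect a 2x2 matrix of ones"
]
},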
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define the model:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"class Net(nn.Module):\n",
" def __init__(self):\n",
" super(Net, self).__init__()\n",
" self.conv1 = nn.Conv2d(1, 16, 3)\n",
" self.pool = nn.MaxPool2d(2, 2)\n",
" self.conv2 = nn.Conv2d(16, 32, 3)\n",
" self.fc1 = nn.Linear(32 * 5 * 5, 32)\n",
" self.fc2 = nn.Linear(32, 84)\n",
" self.fc3 = nn.Linear(84, 10)\n",
"\n",
" def forward(self, x):\n",
" x = self.pool(F.relu(self.conv1(x)))\n",
" x = self.pool(F.relu(self.conv2(x)))\n",
" x = x.view(x.size(0),-1)\n",
" x = F.relu(self.fc1(x))\n",
" x = F.relu(self.fc2(x))\n",
" x = self.fc3(x)\n",
" return F.log_softmax(x, dim=1)"
]
},
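{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick shape check (illustrative): a 1x28x28 MNIST-sized input shrinks to 26x26 after `conv1`, 13x13 after pooling, 11x11 after `conv2`, and 5x5 after the second pooling, which is where the `32 * 5 * 5` input size of `fc1` comes from. The network emits 10 log-probabilities per sample:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Forward a dummy batch through the network to confirm the output shape\n",
"dummy = torch.zeros(4, 1, 28, 28)\n",
"print(Net()(dummy).shape) # expect torch.Size([4, 10])"
]
},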
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up the dataset:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"transform = transforms.Compose(\n",
" [transforms.ToTensor(),\n",
" transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
"\n",
"mnist_train = torchvision.datasets.MNIST(\n",
" \"./files/\",\n",
" train=True,\n",
" download=True,\n",
" transform=transform,\n",
")\n",
"\n",
"mnist_test = torchvision.datasets.MNIST(\n",
" \"./files/\",\n",
" train=False,\n",
" download=True,\n",
" transform=transform,\n",
")\n",
"\n",
"class CustomDataset(torch.utils.data.Dataset):\n",
" \"\"\"Dataset enumeration as tensors\"\"\"\n",
" def __init__(self, images, labels):\n",
" self.images = images\n",
" self.labels = labels\n",
"\n",
" def __len__(self):\n",
" return len(self.images)\n",
"\n",
" def __getitem__(self, idx):\n",
" image = self.images[idx]\n",
" label = self.labels[idx]\n",
" return image, label"
]
},
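{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the collaborators below are fed the raw `data` tensors directly through `CustomDataset`, so the `transform` above only applies when the torchvision datasets are indexed directly. A small usage sketch of `CustomDataset` (illustrative only):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: build a tiny loader from the first 8 training samples\n",
"sample_images = torch.from_numpy(np.expand_dims(mnist_train.data[:8], axis=1)).float()\n",
"sample_labels = one_hot(np.array(mnist_train.targets[:8]), 10)\n",
"loader = torch.utils.data.DataLoader(CustomDataset(sample_images, sample_labels), batch_size=4)\n",
"images, labels = next(iter(loader))\n",
"print(images.shape, labels.shape) # expect [4, 1, 28, 28] and [4, 10]"
]
},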
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next step is setting up the participants, an `Aggregator` and a few `Collaborator`s which will train the model, partition the dataset between the collaborators, and pass them to the appropriate runtime environment (in our case, a `LocalRuntime`).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Setup participants\n",
"aggregator_ = Aggregator()\n",
"aggregator_.private_attributes = {}\n",
"\n",
"# Setup collaborators with private attributes\n",
"collaborator_names = [f'collaborator{i}' for i in range(4)]\n",
"collaborators = [Collaborator(name=name) for name in collaborator_names]\n",
"batch_size_train = 1024\n",
"batch_size_test = 1024\n",
"log_interval = 10\n",
"\n",
"for idx, collaborator_ in enumerate(collaborators):\n",
" train_images, train_labels = mnist_train.train_data, np.array(mnist_train.train_labels)\n",
" train_images = torch.from_numpy(np.expand_dims(train_images, axis=1)).float()\n",
" train_labels = one_hot(train_labels, 10)\n",
"\n",
" valid_images, valid_labels = mnist_test.test_data, np.array(mnist_test.test_labels)\n",
" valid_images = torch.from_numpy(np.expand_dims(valid_images, axis=1)).float()\n",
"\n",
" collaborator_.private_attributes = {\n",
" 'train_loader': torch.utils.data.DataLoader(\n",
" CustomDataset(train_images[idx::len(collaborators)], \n",
" train_labels[idx::len(collaborators)]), \n",
" batch_size=batch_size_train, \n",
" shuffle=True),\n",
" 'test_loader': torch.utils.data.DataLoader(\n",
" CustomDataset(valid_images[idx::len(collaborators)], \n",
" valid_labels[idx::len(collaborators)]), \n",
" batch_size=batch_size_test, \n",
" shuffle=True)\n",
" }\n",
"\n",
"local_runtime = LocalRuntime(aggregator=aggregator_, collaborators=collaborators, backend='single_process')\n"
]
},
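{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `idx::len(collaborators)` slices above deal samples out round-robin, so the shards are disjoint and together cover the full dataset. A small illustration:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative only: stride-based sharding of 12 sample indices across the collaborators\n",
"indices = np.arange(12)\n",
"for idx, name in enumerate(collaborator_names):\n",
" print(name, indices[idx::len(collaborator_names)])"
]
},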
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up work to be executed by the aggregator and the collaborators by extending `FLSpec`:"
]
},
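{
"cell_type": "markdown",
"metadata": {},
"source": [
"In the flow below, `start` runs once on the aggregator and fans out to the collaborators. Each round then proceeds as `aggregated_model_validation` (each collaborator scores the incoming global model) → `train` (local FedProx training) → `local_model_validation` (each collaborator scores its own updated model) → `join` (the aggregator averages the models and metrics), after which `join` either starts the next round or transitions to `end`.\n"
]
},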
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class FederatedFlow(FLSpec):\n",
" def __init__(self, model=None, optimizer=None, rounds=10, **kwargs):\n",
" super().__init__(**kwargs)\n",
" self.model = model\n",
" self.optimizer = optimizer\n",
" self.rounds = rounds\n",
" self.loss = 0.\n",
"\n",
" @aggregator\n",
" def start(self):\n",
" print(f'Performing initialization for model')\n",
" self.collaborators = self.runtime.collaborators\n",
" self.private = 10\n",
" self.current_round = 0\n",
" self.next(self.aggregated_model_validation, foreach='collaborators', exclude=['private'])\n",
"\n",
" def compute_accuracy(self, data_loader):\n",
" self.model.eval()\n",
" test_loss = 0\n",
" correct = 0\n",
" with torch.no_grad():\n",
" for data, target in data_loader:\n",
" output = self.model(data)\n",
" test_loss += F.cross_entropy(output, target, size_average=False).item()\n",
" pred = output.data.max(1, keepdim=True)[1]\n",
" correct += pred.eq(target.data.view_as(pred)).sum()\n",
"\n",
" test_loss /= len(data_loader.dataset)\n",
" print('\\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
" test_loss, correct, len(data_loader.dataset),\n",
" 100. * correct / len(data_loader.dataset)))\n",
" accuracy = float(correct / len(data_loader.dataset))\n",
" return accuracy\n",
"\n",
" @collaborator\n",
" def aggregated_model_validation(self):\n",
" print(f'Performing aggregated model validation for collaborator {self.input}, model: {id(self.model)}')\n",
" self.agg_validation_score = self.compute_accuracy(self.test_loader)\n",
" self.next(self.train)\n",
"\n",
" @collaborator\n",
" def train(self):\n",
" # Log after processing a quarter of the samples\n",
" log_threshold = .25\n",
"\n",
" self.model.train()\n",
" self.optimizer = get_optimizer(self.model)\n",
" for batch_idx, (data, target) in enumerate(self.train_loader):\n",
" self.optimizer.zero_grad()\n",
" output = self.model(data)\n",
" loss = F.cross_entropy(output, target)\n",
" loss.backward()\n",
" self.optimizer.step()\n",
"\n",
" if (len(data) * batch_idx) / len(self.train_loader.dataset) >= log_threshold:\n",
" print('Train Epoch: [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
" batch_idx * len(data), len(self.train_loader.dataset),\n",
" 100. * batch_idx / len(self.train_loader), loss.item()))\n",
" self.loss = loss.item()\n",
" log_threshold += .25\n",
" torch.save(self.model.state_dict(), 'model.pth')\n",
" torch.save(self.optimizer.state_dict(), 'optimizer.pth')\n",
" \n",
" self.next(self.local_model_validation)\n",
"\n",
" @collaborator\n",
" def local_model_validation(self):\n",
" print(f'Performing local model validation for collaborator {self.input}')\n",
" self.local_validation_score = self.compute_accuracy(self.test_loader)\n",
" print(\n",
" f'Done with local model validation for collaborator {self.input}, Accuracy: {self.local_validation_score}')\n",
" self.next(self.join)\n",
"\n",
" @aggregator\n",
" def join(self, inputs):\n",
" self.model = FedAvg([input.model for input in inputs])\n",
" self.optimizer = inputs[0].optimizer\n",
" self.current_round += 1\n",
"\n",
" self.average_loss = sum(input.loss for input in inputs) / len(inputs)\n",
" self.aggregated_model_accuracy = sum(\n",
" input.agg_validation_score for input in inputs) / len(inputs)\n",
" self.local_model_accuracy = sum(\n",
" input.local_validation_score for input in inputs) / len(inputs)\n",
" print(f'Average aggregated model accuracy = {self.aggregated_model_accuracy}')\n",
" print(f'Average training loss = {self.average_loss}')\n",
" print(f'Average local model validation values = {self.local_model_accuracy}')\n",
"\n",
" if self.current_round < self.rounds:\n",
" self.next(self.aggregated_model_validation,\n",
" foreach='collaborators', exclude=['private'])\n",
" else:\n",
" self.next(self.end)\n",
"\n",
" @aggregator\n",
" def end(self):\n",
" print(f'Flow ended')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, run the federation:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = Net()\n",
"flflow = FederatedFlow(model, get_optimizer(model), rounds=3, checkpoint=False)\n",
"flflow.runtime = local_runtime\n",
"flflow.run()"
]
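},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After the run completes, the flow object still holds the final aggregated model and the metrics from the last `join`, so they can be inspected directly (illustrative):\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the final state left on the flow after the last round\n",
"print(f'Final aggregated model accuracy: {flflow.aggregated_model_accuracy}')\n",
"print(f'Final average training loss: {flflow.average_loss}')"
]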
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
