Showing 3 changed files with 372 additions and 29 deletions.
@@ -0,0 +1,342 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Lecture 7: Convolutional networks\n", | ||
"\n", | ||
"Notebook adapted from [Deep Learning (with PyTorch)](https://github.com/Atcold/pytorch-Deep-Learning) by Alfredo Canziani. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"import torch.nn.functional as F\n", | ||
"from torchvision import datasets, transforms\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", | ||
"device" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# MNIST" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tf = transforms.Compose([transforms.ToTensor(),\n", | ||
" transforms.Normalize((0.1307,), (0.3081,))])\n", | ||
"\n", | ||
"train_loader = torch.utils.data.DataLoader(datasets.MNIST(\"./data\", train=True, transform=tf),\n", | ||
" batch_size=64, shuffle=True)\n", | ||
"\n", | ||
"test_loader = torch.utils.data.DataLoader(datasets.MNIST(\"./data\", train=False, transform=tf),\n", | ||
" batch_size=1000, shuffle=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"batch = next(iter(train_loader))\n", | ||
"x = batch[0][:10]\n", | ||
"y = batch[1][:10]\n", | ||
"\n", | ||
"fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n", | ||
"\n", | ||
"for i in range(5):\n", | ||
" axs[i].imshow(x[i].squeeze().numpy())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# MLP vs ConvNet" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"class MLP(nn.Module):\n", | ||
" def __init__(self, D, H, C):\n", | ||
" super().__init__()\n", | ||
" self.D = D\n", | ||
" self.net = nn.Sequential(\n", | ||
" nn.Linear(D, H),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(H, H),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(H, C),\n", | ||
" # nn.Softmax(dim=-1)\n", | ||
" )\n", | ||
" \n", | ||
" def forward(self, x):\n", | ||
" x = x.view(-1, self.D)\n", | ||
" return self.net(x)\n", | ||
"\n", | ||
"class ConvNet(nn.Module):\n", | ||
" def __init__(self, D, n_kernels, C):\n", | ||
" super().__init__()\n", | ||
" self.net = nn.Sequential(\n", | ||
" nn.Conv2d(in_channels=1, out_channels=n_kernels, kernel_size=5),\n", | ||
" nn.ReLU(),\n", | ||
" nn.MaxPool2d(kernel_size=2),\n", | ||
" nn.Conv2d(in_channels=n_kernels, out_channels=n_kernels, kernel_size=5),\n", | ||
" nn.ReLU(),\n", | ||
" nn.MaxPool2d(kernel_size=2),\n", | ||
" nn.Flatten(),\n", | ||
" nn.Linear(n_kernels * 4 * 4, 50),\n", | ||
" nn.ReLU(),\n", | ||
" nn.Linear(50, C),\n", | ||
" # nn.Softmax(dim=-1)\n", | ||
" )\n", | ||
" \n", | ||
" def forward(self, x):\n", | ||
" return self.net(x)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def train(model, perm=torch.arange(0, 784).long(), n_epochs=1):\n", | ||
" model.train() \n", | ||
" optimizer = torch.optim.AdamW(model.parameters())\n", | ||
" \n", | ||
" for epoch in range(n_epochs):\n", | ||
" for i, (data, target) in enumerate(train_loader):\n", | ||
" # send to device\n", | ||
" data, targets = data.to(device), target.to(device)\n", | ||
"\n", | ||
" # permute pixels\n", | ||
" data = data.view(-1, 28*28)\n", | ||
" data = data[:, perm]\n", | ||
" data = data.view(-1, 1, 28, 28)\n", | ||
"\n", | ||
" # step\n", | ||
" optimizer.zero_grad()\n", | ||
" logits = model(data)\n", | ||
" \n", | ||
" loss = F.cross_entropy(logits, targets)\n", | ||
" loss.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" if i % 100 == 0:\n", | ||
" print(f\"epoch={epoch}, step={i}: train loss={loss.item():.4f}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def test(model, perm=torch.arange(0, 784).long()):\n", | ||
" model.eval()\n", | ||
" \n", | ||
" test_loss = 0\n", | ||
" correct = 0\n", | ||
" \n", | ||
" for data, targets in test_loader:\n", | ||
" # send to device\n", | ||
" data, targets = data.to(device), targets.to(device)\n", | ||
" \n", | ||
" # permute pixels\n", | ||
" data = data.view(-1, 28*28)\n", | ||
" data = data[:, perm]\n", | ||
" data = data.view(-1, 1, 28, 28)\n", | ||
" \n", | ||
" # metrics\n", | ||
" logits = model(data)\n", | ||
" test_loss += F.cross_entropy(logits, targets, reduction='sum').item()\n", | ||
" preds = torch.argmax(logits, dim=1) \n", | ||
" correct += (preds == targets).sum()\n", | ||
"\n", | ||
" test_loss /= len(test_loader.dataset)\n", | ||
" accuracy = correct / len(test_loader.dataset)\n", | ||
" \n", | ||
" print(f\"test loss={test_loss:.4f}, accuracy={accuracy:.4f}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# MLP\n", | ||
"D = 28*28 \n", | ||
"C = 10 \n", | ||
"H = 8\n", | ||
"\n", | ||
"mlp = MLP(D, H, C)\n", | ||
"mlp.to(device)\n", | ||
"print(f\"Parameters={sum(p.numel() for p in mlp.parameters())/1e3}K\")\n", | ||
"\n", | ||
"train(mlp)\n", | ||
"test(mlp)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# ConvNet, with the same number of parameters\n", | ||
"n_kernels = 6\n", | ||
"\n", | ||
"convnet = ConvNet(D, n_kernels, C)\n", | ||
"convnet.to(device)\n", | ||
"print(f\"Parameters={sum(p.numel() for p in convnet.parameters())/1e3}K\")\n", | ||
"\n", | ||
"train(convnet)\n", | ||
"test(convnet)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The convolutional network performs better with the same number of parameters, thanks to its use of prior knowledge about images:\n", | ||
"\n", | ||
"* Use of convolution: Locality and stationarity in images\n", | ||
"* Pooling: builds in some translation invariance\n", | ||
"\n", | ||
"What if those assumptions are wrong?" | ||
] | ||
}, | ||
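{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Before breaking those assumptions, here is a quick check of what the convolutional prior buys us (a small sketch added for illustration, not part of the original lecture): with zero bias and an object away from the image border, convolution is translation-equivariant, i.e. shifting the input shifts the feature map by the same amount, and max-pooling then reduces sensitivity to small shifts." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Illustration only: translation equivariance of convolution (reuses torch/nn imported above)\n", | ||
"conv = nn.Conv2d(1, 6, kernel_size=5, padding=2, bias=False)\n", | ||
"pool = nn.MaxPool2d(kernel_size=2)\n", | ||
"\n", | ||
"x = torch.zeros(1, 1, 28, 28)\n", | ||
"x[0, 0, 10:15, 10:15] = 1.0  # a small square, away from the border\n", | ||
"x_shifted = torch.roll(x, shifts=(1, 1), dims=(2, 3))  # the same square, shifted by one pixel\n", | ||
"\n", | ||
"with torch.no_grad():\n", | ||
"    y, y_shifted = conv(x), conv(x_shifted)\n", | ||
"\n", | ||
"# shifting the input shifts the conv feature map by the same amount (equivariance)\n", | ||
"print(torch.allclose(torch.roll(y, shifts=(1, 1), dims=(2, 3)), y_shifted, atol=1e-6))\n", | ||
"\n", | ||
"# after max-pooling, compare how much a one-pixel shift changes the representation\n", | ||
"print((pool(y) - pool(y_shifted)).abs().max(), y.abs().max())" | ||
] | ||
}, | ||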
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# MLP vs ConvNet, on shuffled pixels" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"perm = torch.randperm(784)\n", | ||
"\n", | ||
"batch = next(iter(train_loader))\n", | ||
"x = batch[0][:10]\n", | ||
"y = batch[1][:10]\n", | ||
"\n", | ||
"fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n", | ||
"\n", | ||
"for i in range(5):\n", | ||
" axs[i].imshow(x[i].squeeze().numpy())\n", | ||
" \n", | ||
"fig, axs = plt.subplots(1, 5, figsize=(12, 4))\n", | ||
"x = x.view(-1, 28*28)\n", | ||
"x = x[:, perm]\n", | ||
"x = x.view(-1, 1, 28, 28)\n", | ||
"\n", | ||
"for i in range(5):\n", | ||
" axs[i].imshow(x[i].squeeze().numpy())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# ConvNet on shuffled pixels\n", | ||
"n_kernels = 6\n", | ||
"\n", | ||
"convnet = ConvNet(D, n_kernels, C)\n", | ||
"convnet.to(device)\n", | ||
"print(f\"Parameters={sum(p.numel() for p in convnet.parameters())/1e3}K\")\n", | ||
"\n", | ||
"train(convnet, perm=perm)\n", | ||
"test(convnet, perm=perm)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# MLP on shuffled pixels\n", | ||
"D = 28*28 \n", | ||
"C = 10 \n", | ||
"H = 8\n", | ||
"\n", | ||
"mlp = MLP(D, n_hidden, output_size)\n", | ||
"mlp.to(device)\n", | ||
"print(f\"Parameters={sum(p.numel() for p in mlp.parameters())/1e3}K\")\n", | ||
"\n", | ||
"train(mlp, perm=perm)\n", | ||
"test(mlp, perm=perm)" | ||
] | ||
}, | ||
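{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"A quick sanity check before drawing conclusions (a sketch added for illustration, not part of the original lecture): a fixed permutation of the pixels can be absorbed into the first `Linear` layer by permuting its weight columns, so shuffling pixels does not change what the MLP can express." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Illustration only: a fixed input permutation is equivalent to permuting the weight columns\n", | ||
"lin = nn.Linear(784, 8)\n", | ||
"lin_perm = nn.Linear(784, 8)\n", | ||
"with torch.no_grad():\n", | ||
"    lin_perm.weight.copy_(lin.weight[:, perm])  # reorder columns with the same permutation\n", | ||
"    lin_perm.bias.copy_(lin.bias)\n", | ||
"\n", | ||
"x_flat = torch.randn(4, 784)\n", | ||
"print(torch.allclose(lin(x_flat), lin_perm(x_flat[:, perm]), atol=1e-5))" | ||
] | ||
}, | ||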
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"The convolutional network's performance drops when we permute the pixels, but the MLP's performance stays the same.\n", | ||
"* ConvNet makes the assumption that pixels lie on a grid and are stationary/local.\n", | ||
"* It loses performance when this assumption is wrong.\n", | ||
"* The fully-connected network does not make this assumption.\n", | ||
"* It does less well when it is true, since it doesn't take advantage of this prior knowledge.\n", | ||
"* But it doesn't suffer when the assumption is wrong." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "torch-cpu", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |