diff --git a/examples/legacy/Alpha_neuron_training_example.ipynb b/examples/legacy/Alpha_neuron_training_example.ipynb
deleted file mode 100644
index 56141d34..00000000
--- a/examples/legacy/Alpha_neuron_training_example.ipynb
+++ /dev/null
@@ -1,343 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0rklDqluPYZB"
- },
- "source": [
- "## Example of training using the Alpha Neuron"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "ptd12vmR7H8D",
- "outputId": "79499811-cc47-4470-f190-f3a8581df051"
- },
- "outputs": [],
- "source": [
- "# !git clone -b alpha_neuron --single-branch https://github.com/jeshraghian/snntorch.git\n",
- "!pip install snntorch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "5VNVo64L93Zo",
- "outputId": "69255447-8261-4119-f851-066c05ea510c"
- },
- "outputs": [],
- "source": [
- "# %cd snntorch\n",
- "import snntorch as snn"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "CEUipC6UPebM"
- },
- "source": [
- "## Import Packages"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "AQeg7QZ69fSt"
- },
- "outputs": [],
- "source": [
- "import snntorch as snn\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from torch.utils.data import DataLoader\n",
- "from torchvision import datasets, transforms\n",
- "import numpy as np\n",
- "import itertools\n",
- "import matplotlib.pyplot as plt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0OEdE6bEPf20"
- },
- "source": [
- "## Define Network and Parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "MHmCvayV7LJR"
- },
- "outputs": [],
- "source": [
- "# test alpha neuron: can it learn?\n",
- "\n",
- "num_inputs = 28*28\n",
- "num_hidden = 1000\n",
- "num_outputs = 10\n",
- "\n",
- "# Training Parameters\n",
- "batch_size=128\n",
- "data_path='/tmp/data/mnist'\n",
- "\n",
- "# Temporal Dynamics\n",
- "num_steps = 25\n",
- "alpha = 0.9\n",
- "beta = 0.8\n",
- "\n",
- "dtype = torch.float\n",
- "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")\n",
- "\n",
- "# Define Network\n",
- "class Net(nn.Module):\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- " \n",
- " # initialize layers\n",
- " self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
- " self.lif1 = snn.Alpha(alpha=alpha, beta=beta)\n",
- " self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
- " self.lif2 = snn.Alpha(alpha=alpha, beta=beta)\n",
- "\n",
- "\n",
- " def forward(self, x):\n",
- " spk1, syn_exc1, syn_inh1, mem1 = self.lif1.init_alpha(batch_size, num_hidden)\n",
- " spk2, syn_exc2, syn_inh2, mem2 = self.lif2.init_alpha(batch_size, num_outputs)\n",
- "\n",
- " # Record the final layer\n",
- " spk2_rec = []\n",
- " mem2_rec = []\n",
- "\n",
- " for step in range(num_steps):\n",
- "\n",
- " cur1 = self.fc1(x)\n",
- " spk1, syn_exc1, syn_inh1, mem1 = self.lif1(cur1, syn_exc1, syn_inh1, mem1)\n",
- " cur2 = self.fc2(spk1)\n",
- " spk2, syn_exc2, syn_inh2, mem2 = self.lif2(cur2, syn_exc2, syn_inh2, mem2)\n",
- "\n",
- " spk2_rec.append(spk2)\n",
- " mem2_rec.append(mem2)\n",
- "\n",
- " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n",
- " \n",
- "net = Net().to(device)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "2N00-2eDPl2G"
- },
- "source": [
- "## DataLoaders"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/"
- },
- "id": "NblLifHj9qO-",
- "outputId": "e890995e-a9fd-490d-f278-caf6bce9c21e"
- },
- "outputs": [],
- "source": [
- "# Define a transform\n",
- "transform = transforms.Compose([\n",
- " transforms.Resize((28, 28)),\n",
- " transforms.Grayscale(),\n",
- " transforms.ToTensor(),\n",
- " transforms.Normalize((0,), (1,))])\n",
- "\n",
- "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
- "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "kx8nATF69tEk"
- },
- "outputs": [],
- "source": [
- "# Create DataLoaders\n",
- "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n",
- "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gqDmeDJcP_HL"
- },
- "source": [
- "## Print Accuracy Function"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "EJYEQpHk-MzX"
- },
- "outputs": [],
- "source": [
- "def print_batch_accuracy(data, targets, train=False):\n",
- " with torch.no_grad():\n",
- " output, _ = net(data.view(batch_size, -1))\n",
- " _, idx = output.sum(dim=0).max(1)\n",
- " acc = np.mean((targets == idx).detach().cpu().numpy())\n",
- "\n",
- " if train:\n",
- " print(f\"Train Set Accuracy: {acc}\")\n",
- " else:\n",
- " print(f\"Test Set Accuracy: {acc}\")\n",
- "\n",
- "def train_printer():\n",
- " print(f\"Epoch {epoch}, Minibatch {minibatch_counter}\")\n",
- " print(f\"Train Set Loss: {loss_hist[counter]}\")\n",
- " print(f\"Test Set Loss: {test_loss_hist[counter]}\")\n",
- " print_batch_accuracy(data_it, targets_it, train=True)\n",
- " print_batch_accuracy(testdata_it, testtargets_it, train=False)\n",
- " print(\"\\n\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "m2Zz5xesQBBk"
- },
- "source": [
- "## Define Loss & Optimizer"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "4Uwvn33m-OuY"
- },
- "outputs": [],
- "source": [
- "optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))\n",
- "log_softmax_fn = nn.LogSoftmax(dim=-1)\n",
- "loss_fn = nn.NLLLoss()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wY-94aYrQCXE"
- },
- "source": [
- "## Training Loop"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
- },
- "id": "hQvfXduS-QmB",
- "outputId": "7ec1f1dc-ce3f-4834-f601-9b5c34245225"
- },
- "outputs": [],
- "source": [
- "loss_hist = []\n",
- "test_loss_hist = []\n",
- "counter = 0\n",
- "\n",
- "# Outer training loop\n",
- "for epoch in range(10):\n",
- " minibatch_counter = 0\n",
- " train_batch = iter(train_loader)\n",
- "\n",
- " # Minibatch training loop\n",
- " for data_it, targets_it in train_batch:\n",
- " data_it = data_it.to(device)\n",
- " targets_it = targets_it.to(device)\n",
- "\n",
- " spk_rec, mem_rec = net(data_it.view(batch_size, -1))\n",
- " log_p_y = log_softmax_fn(mem_rec)\n",
- " loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
- "\n",
- " # Sum loss over time steps: BPTT\n",
- " for step in range(num_steps):\n",
- " loss_val += loss_fn(log_p_y[step], targets_it)\n",
- "\n",
- " # Gradient calculation\n",
- " optimizer.zero_grad()\n",
- " loss_val.backward()\n",
- "\n",
- " # Weight Update\n",
- " optimizer.step()\n",
- "\n",
- " # Store loss history for future plotting\n",
- " loss_hist.append(loss_val.item())\n",
- "\n",
- " # Test set\n",
- " test_data = itertools.cycle(test_loader)\n",
- " testdata_it, testtargets_it = next(test_data)\n",
- " testdata_it = testdata_it.to(device)\n",
- " testtargets_it = testtargets_it.to(device)\n",
- "\n",
- " # Test set forward pass\n",
- " with torch.no_grad():\n",
- " test_spk, test_mem = net(testdata_it.view(batch_size, -1))\n",
- "\n",
- " # Test set loss\n",
- " log_p_ytest = log_softmax_fn(test_mem)\n",
- " log_p_ytest = log_p_ytest.sum(dim=0)\n",
- " loss_val_test = loss_fn(log_p_ytest, testtargets_it)\n",
- " test_loss_hist.append(loss_val_test.item())\n",
- "\n",
- " # Print test/train loss/accuracy\n",
- " if counter % 50 == 0:\n",
- " train_printer()\n",
- " minibatch_counter += 1\n",
- " counter += 1\n",
- "\n",
- "loss_hist_true_grad = loss_hist\n",
- "test_loss_hist_true_grad = test_loss_hist"
- ]
- }
- ],
- "metadata": {
- "accelerator": "GPU",
- "colab": {
- "collapsed_sections": [],
- "name": "Alpha_code_example.ipynb",
- "provenance": []
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "name": "python"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/examples/legacy/CIFAR_temp.ipynb b/examples/legacy/CIFAR_temp.ipynb
deleted file mode 100644
index e86b4176..00000000
--- a/examples/legacy/CIFAR_temp.ipynb
+++ /dev/null
@@ -1,2857 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "rtrNT4NPRp7r",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "\n",
- "\n",
- "# snnTorch - Gradient-based Learning in Spiking Neural Networks\n",
- "## Tutorial 2\n",
- "### By Jason K. Eshraghian\n",
- "\n",
- "\n",
- " \n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "TKy-qDQdRp73"
- },
- "source": [
- "# Introduction\n",
- "In this tutorial, you will learn how to use snnTorch to:\n",
- "* create a 2-layer fully-connected spiking network;\n",
- "* implement the backpropagation through time (BPTT) algorithm;\n",
- "* classify both the static and spiking MNIST datasets.\n",
- "\n",
- "If running in Google Colab:\n",
- "* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`\n",
- "* Next, install the latest PyPI distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 0
- },
- "id": "BgBRVUtpRp74",
- "outputId": "61c32ebb-e69b-4d44-852b-0bf58481a9d1",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Requirement already satisfied: snntorch in /usr/local/lib/python3.7/dist-packages (0.2.7)\n",
- "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from snntorch) (1.1.5)\n",
- "Requirement already satisfied: torch>=1.2.0 in /usr/local/lib/python3.7/dist-packages (from snntorch) (1.8.0+cu101)\n",
- "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from snntorch) (3.2.2)\n",
- "Requirement already satisfied: celluloid in /usr/local/lib/python3.7/dist-packages (from snntorch) (0.2.0)\n",
- "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.7/dist-packages (from snntorch) (1.19.5)\n",
- "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->snntorch) (2018.9)\n",
- "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas->snntorch) (2.8.1)\n",
- "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.2.0->snntorch) (3.7.4.3)\n",
- "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->snntorch) (1.3.1)\n",
- "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->snntorch) (0.10.0)\n",
- "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->snntorch) (2.4.7)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas->snntorch) (1.15.0)\n"
- ]
- }
- ],
- "source": [
- "!pip install snntorch"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "Zm-D2lthRp75",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "## 1. Setting up the Static CIFAR-10 Dataset\n",
- "### 1.1. Import packages and setup environment"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "sEygpdc8Rp76",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "import snntorch as snn\n",
- "import torch\n",
- "import torch.nn as nn\n",
- "from torch.utils.data import DataLoader\n",
- "from torchvision import datasets, transforms\n",
- "import numpy as np\n",
- "import itertools\n",
- "import matplotlib.pyplot as plt"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "RcX7J9vVRp76"
- },
- "source": [
- "### 1.2 Define network and SNN parameters\n",
- "We will use a 1024-1000-10 FCN architecture for a sequence of 25 time steps.\n",
- "\n",
- "* `alpha` is the decay rate of the synaptic current of a neuron\n",
- "* `beta` is the decay rate of the membrane potential of a neuron"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "id": "bKEj2hucRp77",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "# Network Architecture\n",
- "num_inputs = 32*32\n",
- "num_hidden = 1000\n",
- "num_outputs = 10\n",
- "\n",
- "# Training Parameters\n",
- "batch_size=128\n",
- "data_path='/tmp/data/mnist'\n",
- "\n",
- "# Temporal Dynamics\n",
- "num_steps = 25\n",
- "time_step = 1e-3\n",
- "tau_mem = 3e-3\n",
- "tau_syn = 2.2e-3\n",
- "alpha = float(np.exp(-time_step/tau_syn))\n",
- "beta = float(np.exp(-time_step/tau_mem))\n",
- "\n",
- "dtype = torch.float\n",
- "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"mps\") if torch.backends.mps.is_available() else torch.device(\"cpu\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "id": "agBKaYiIQxbS"
- },
- "outputs": [],
- "source": [
- "alpha = 0.8\n",
- "beta = 0.9"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "gM_hcwDIRp78",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "### 1.3 Download CIFAR-10 Dataset\n",
- "To see how to construct a validation set, refer to Tutorial 1."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 120,
- "referenced_widgets": [
- "feb965fc6abb4ee9bc9dbb5ef3b9723c",
- "1a99d740897c4ea09c6f33ce3a4a3230",
- "82d6d6dfb0c845ec99c25f57e6c5d9d2",
- "c03ab1b2c81149feb0c1551831a6a909",
- "f74b5621630e458aa55562f4ae2f36f8",
- "f999b8f0ade1483c83e5c2f047bad551",
- "e2f3ff56de4e424bb8e2f09a9976bf5b",
- "fe66ac83ad074efbbc802309b16e08a3"
- ]
- },
- "id": "0xwYb15xRp79",
- "outputId": "fd8cb868-a8c3-45a1-a43a-4aad7992884a",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /data/mnist/cifar-10-python.tar.gz\n"
- ]
- },
- {
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "feb965fc6abb4ee9bc9dbb5ef3b9723c",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))"
- ]
- },
- "metadata": {
- "tags": []
- },
- "output_type": "display_data"
- },
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "Extracting /data/mnist/cifar-10-python.tar.gz to /data/mnist\n",
- "Files already downloaded and verified\n"
- ]
- }
- ],
- "source": [
- "# Define a transform\n",
- "transform = transforms.Compose([\n",
- " transforms.Grayscale(),\n",
- " transforms.Resize((32, 32)),\n",
- " transforms.ToTensor(),\n",
- " transforms.Normalize((0,), (1,))])\n",
- "\n",
- "mnist_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)\n",
- "mnist_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "DIg3vLmURp7-",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "### 1.4 Create DataLoaders"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "id": "K4DF-odMRp7_",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n",
- "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "pD4Dw-RoRp7_",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "## 2. Define Network\n",
- "snnTorch treats neurons as activations with recurrent connections. This allows for smooth integration with PyTorch.\n",
- "There are a few useful neuron models and surrogate gradient functions which approximate the gradient of spikes.\n",
- "\n",
- "Our network will use one type of neuron model and one surrogate gradient:\n",
- "1. `snntorch.Stein` is a basic leaky integrate-and-fire (LIF) neuron. Specifically, it assumes instantaneous rise times for synaptic current and membrane potential.\n",
- "2. `snntorch.FastSigmoidSurrogate` defines separate forward and backward functions. The forward function is a Heaviside step function for spike generation. The backward function is the derivative of a fast sigmoid function, to ensure continuous differentiability.\n",
- "The `FastSigmoidSurrogate` function has been adapted from:\n",
- "\n",
- ">Neftci, E. O., Mostafa, H., and Zenke, F. (2019) Surrogate Gradient Learning in Spiking Neural Networks. https://arxiv.org/abs/1901.09948"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "Sf9RdE9jRp8A",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "# from snntorch import surrogate\n",
- "#\n",
- "# spike_grad = surrogate.FastSigmoid.apply\n",
- "# snn.slope = 50"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "Og9e57W0Rp8B",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "The surrogate is passed to `spike_grad` and overrides the default gradient of the Heaviside step function.\n",
- "If we did not override the default gradient (zero everywhere except at $x=1$, where it is technically infinite but clipped to 1 here), then learning would not take place for as long as the neuron was not emitting post-synaptic spikes.\n",
- "\n",
- "`snn.slope` defines the slope of the backward surrogate.\n",
- "\n",
- "TO-DO: Include visualisation."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "Yo-x48mARp8C",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "Now we can define our spiking neural network (SNN).\n",
- "Creating an instance of the `Stein` neuron takes two required arguments and two optional arguments:\n",
- "1. $I_{syn}$ decay rate, $\\alpha$,\n",
- "2. $V_{mem}$ decay rate, $\\beta$,\n",
- "3. the surrogate spiking function, `spike_grad` (*default*: the gradient of the Heaviside function), and\n",
- "4. the threshold for spiking, (*default*: 1.0).\n",
- "\n",
- "snnTorch treats the LIF neuron as a recurrent activation. Therefore, it requires initialization of its internal states.\n",
- "For each layer, we initialize the synaptic current `syn1` and `syn2`, the membrane potential `mem1` and `mem2`, and the post-synaptic spikes `spk1` and `spk2` to zero.\n",
- "A class method `init_stein` will take care of this.\n",
- "\n",
- "For rate coding, the final layer of spikes and membrane potential are used to determine accuracy and loss, respectively.\n",
- "So their historical values are recorded in `spk2_rec` and `mem2_rec`.\n",
- "\n",
- "Keep in mind, the dataset we are using is static CIFAR-10 converted to grayscale, i.e., it is *not* time-varying.\n",
- "Therefore, we pass the same image to the input at each time step.\n",
- "This is handled in the line `cur1 = self.fc1(x)`, where `x` is the same input over the whole for-loop."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "id": "P6RHCnXMRp8D",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "# Define Network\n",
- "class Net(nn.Module):\n",
- " def __init__(self):\n",
- " super().__init__()\n",
- "\n",
- " # initialize layers\n",
- " self.fc1 = nn.Linear(num_inputs, num_hidden)\n",
- " self.lif1 = snn.Stein(alpha=alpha, beta=beta)\n",
- " self.fc2 = nn.Linear(num_hidden, num_outputs)\n",
- " self.lif2 = snn.Stein(alpha=alpha, beta=beta)\n",
- "\n",
- " def forward(self, x):\n",
- " spk1, syn1, mem1 = self.lif1.init_stein(batch_size, num_hidden)\n",
- " spk2, syn2, mem2 = self.lif2.init_stein(batch_size, num_outputs)\n",
- "\n",
- " spk2_rec = []\n",
- " mem2_rec = []\n",
- "\n",
- " for step in range(num_steps):\n",
- " cur1 = self.fc1(x)\n",
- " spk1, syn1, mem1 = self.lif1(cur1, syn1, mem1)\n",
- " cur2 = self.fc2(spk1)\n",
- " spk2, syn2, mem2 = self.lif2(cur2, syn2, mem2)\n",
- "\n",
- " spk2_rec.append(spk2)\n",
- " mem2_rec.append(mem2)\n",
- "\n",
- " return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)\n",
- "\n",
- "net = Net().to(device)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "386KNHG7Rp8E",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "## 3. Training\n",
- "Time for training! Let's first define a couple of functions to print out test/train accuracy."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "id": "cOKKbUnDRp8F",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "def print_batch_accuracy(data, targets, train=False):\n",
- " output, _ = net(data.view(batch_size, -1))\n",
- " _, idx = output.sum(dim=0).max(1)\n",
- " acc = np.mean((targets == idx).detach().cpu().numpy())\n",
- "\n",
- " if train:\n",
- " print(f\"Train Set Accuracy: {acc}\")\n",
- " else:\n",
- " print(f\"Test Set Accuracy: {acc}\")\n",
- "\n",
- "def train_printer():\n",
- " print(f\"Epoch {epoch}, Minibatch {minibatch_counter}\")\n",
- " print(f\"Train Set Loss: {loss_hist[counter]}\")\n",
- " print(f\"Test Set Loss: {test_loss_hist[counter]}\")\n",
- " print_batch_accuracy(data_it, targets_it, train=True)\n",
- " print_batch_accuracy(testdata_it, testtargets_it, train=False)\n",
- " print(\"\\n\")"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "mZqfCe0KRp8J",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "### 3.1 Optimizer & Loss\n",
- "* *Output Activation*: We'll apply the softmax function to the membrane potentials of the output layer, rather than the spikes.\n",
- "* *Loss*: This will then be used to calculate the negative log-likelihood loss.\n",
- "By encouraging the membrane of the correct neuron class to reach the threshold, we expect that neuron will fire more frequently.\n",
- "The loss could be applied to the spike count as well, but the membrane is continuous whereas spike count is discrete.\n",
- "* *Optimizer*: The Adam optimizer is used for weight updates.\n",
- "* *Accuracy*: Accuracy is measured by counting the spikes of the output neurons. The neuron that fires the most frequently will be our predicted class."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "id": "UMZ-4uGlRp8K",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [],
- "source": [
- "optimizer = torch.optim.Adam(net.parameters(), lr=2e-4, betas=(0.9, 0.999))\n",
- "log_softmax_fn = nn.LogSoftmax(dim=-1)\n",
- "loss_fn = nn.NLLLoss()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "UerqLkd1Rp8K",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "### 3.2 Training Loop\n",
- "Now just sit back, relax, and wait for convergence."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 1000
- },
- "id": "76CEHI2xRp8L",
- "outputId": "e29c7e0e-4504-4d08-fdcf-70ab371b0996",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Epoch 2, Minibatch 70\n",
- "Train Set Loss: 53.006282806396484\n",
- "Test Set Loss: 55.81739807128906\n",
- "Train Set Accuracy: 0.2265625\n",
- "Test Set Accuracy: 0.21875\n",
- "\n",
- "\n",
- "Epoch 2, Minibatch 120\n",
- "Train Set Loss: 50.85497283935547\n",
- "Test Set Loss: 50.894222259521484\n",
- "Train Set Accuracy: 0.265625\n",
- "Test Set Accuracy: 0.3125\n",
- "\n",
- "\n",
- "Epoch 2, Minibatch 170\n",
- "Train Set Loss: 49.45563507080078\n",
- "Test Set Loss: 50.29939651489258\n",
- "Train Set Accuracy: 0.3671875\n",
- "Test Set Accuracy: 0.2578125\n",
- "\n",
- "\n",
- "Epoch 2, Minibatch 220\n",
- "Train Set Loss: 52.987083435058594\n",
- "Test Set Loss: 51.679054260253906\n",
- "Train Set Accuracy: 0.2109375\n",
- "Test Set Accuracy: 0.234375\n",
- "\n",
- "\n",
- "Epoch 2, Minibatch 270\n",
- "Train Set Loss: 49.68134307861328\n",
- "Test Set Loss: 52.56451416015625\n",
- "Train Set Accuracy: 0.3359375\n",
- "Test Set Accuracy: 0.234375\n",
- "\n",
- "\n",
- "Epoch 2, Minibatch 320\n",
- "Train Set Loss: 54.24823760986328\n",
- "Test Set Loss: 51.5257682800293\n",
- "Train Set Accuracy: 0.25\n",
- "Test Set Accuracy: 0.25\n",
- "\n",
- "\n",
- "Epoch 2, Minibatch 370\n",
- "Train Set Loss: 50.897029876708984\n",
- "Test Set Loss: 49.811622619628906\n",
- "Train Set Accuracy: 0.2890625\n",
- "Test Set Accuracy: 0.34375\n",
- "\n",
- "\n"
- ]
- },
- {
- "ename": "KeyboardInterrupt",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;31m# Gradient calculation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0mloss_val\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mretain_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 26\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[0;31m# Weight Update\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 243\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 244\u001b[0m inputs=inputs)\n\u001b[0;32m--> 245\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 246\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 247\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 145\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 146\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 147\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
- ]
- }
- ],
- "source": [
- "loss_hist = []\n",
- "test_loss_hist = []\n",
- "counter = 0\n",
- "\n",
- "# Outer training loop\n",
- "for epoch in range(3):\n",
- " minibatch_counter = 0\n",
- " train_batch = iter(train_loader)\n",
- "\n",
- " # Minibatch training loop\n",
- " for data_it, targets_it in train_batch:\n",
- " data_it = data_it.to(device)\n",
- " targets_it = targets_it.to(device)\n",
- "\n",
- " output, mem_rec = net(data_it.view(batch_size, -1))\n",
- " log_p_y = log_softmax_fn(mem_rec)\n",
- " loss_val = torch.zeros((1), dtype=dtype, device=device)\n",
- "\n",
- " # Sum loss over time steps: BPTT\n",
- " for step in range(num_steps):\n",
- " loss_val += loss_fn(log_p_y[step], targets_it)\n",
- "\n",
- " # Gradient calculation\n",
- " optimizer.zero_grad()\n",
- " loss_val.backward(retain_graph=True)\n",
- "\n",
- " # Weight Update\n",
- " nn.utils.clip_grad_norm_(net.parameters(), 1) # gradient clipping\n",
- " optimizer.step()\n",
- "\n",
- " # Store loss history for future plotting\n",
- " loss_hist.append(loss_val.item())\n",
- "\n",
- " # Test set\n",
- " test_data = itertools.cycle(test_loader)\n",
- " testdata_it, testtargets_it = next(test_data)\n",
- " testdata_it = testdata_it.to(device)\n",
- " testtargets_it = testtargets_it.to(device)\n",
- "\n",
- " # Test set forward pass\n",
- " test_output, test_mem_rec = net(testdata_it.view(batch_size, -1))\n",
- "\n",
- " # Test set loss\n",
- " log_p_ytest = log_softmax_fn(test_mem_rec)\n",
- " log_p_ytest = log_p_ytest.sum(dim=0)\n",
- " loss_val_test = loss_fn(log_p_ytest, testtargets_it)\n",
- " test_loss_hist.append(loss_val_test.item())\n",
- "\n",
- " # Print test/train loss/accuracy\n",
- " if counter % 50 == 0:\n",
- " train_printer()\n",
- " minibatch_counter += 1\n",
- " counter += 1\n",
- "\n",
- "loss_hist_true_grad = loss_hist\n",
- "test_loss_hist_true_grad = test_loss_hist"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": false,
- "id": "fdOuwCWHRp8L",
- "pycharm": {
- "name": "#%% md\n"
- }
- },
- "source": [
- "## 4. Results\n",
- "### 4.1 Plot Training/Test Loss"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 334
- },
- "id": "WJGSBq6zRp8M",
- "outputId": "9c224143-2579-4708-88a8-e645e32ce289",
- "pycharm": {
- "name": "#%%\n"
- }
- },
- "outputs": [
- {
- "data": {
- "image/png": "",
- "text/plain": [
- "