diff --git a/docs/_static/img/examples/regression2/reg2-1.jpg b/docs/_static/img/examples/regression2/reg2-1.jpg
deleted file mode 100644
index d1eca330..00000000
Binary files a/docs/_static/img/examples/regression2/reg2-1.jpg and /dev/null differ
diff --git a/docs/_static/img/examples/regression2/reg2-2.jpg b/docs/_static/img/examples/regression2/reg2-2.jpg
deleted file mode 100644
index b3868422..00000000
Binary files a/docs/_static/img/examples/regression2/reg2-2.jpg and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_1_static.png b/docs/_static/img/examples/tutorial1/1_2_1_static.png
deleted file mode 100644
index bf820e9c..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_1_static.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_2_spikeinput.png b/docs/_static/img/examples/tutorial1/1_2_2_spikeinput.png
deleted file mode 100644
index 49b44889..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_2_spikeinput.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_3_spikeconv.png b/docs/_static/img/examples/tutorial1/1_2_3_spikeconv.png
deleted file mode 100644
index 1c23fc78..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_3_spikeconv.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_4_latencyrc.png b/docs/_static/img/examples/tutorial1/1_2_4_latencyrc.png
deleted file mode 100644
index 1806b95e..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_4_latencyrc.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_5_latencyraster.png b/docs/_static/img/examples/tutorial1/1_2_5_latencyraster.png
deleted file mode 100644
index 8dbcb4d0..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_5_latencyraster.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_6_latencylinear.png b/docs/_static/img/examples/tutorial1/1_2_6_latencylinear.png
deleted file mode 100644
index 70fd2d78..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_6_latencylinear.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/1_2_7_delta.png b/docs/_static/img/examples/tutorial1/1_2_7_delta.png
deleted file mode 100644
index 8034edac..00000000
Binary files a/docs/_static/img/examples/tutorial1/1_2_7_delta.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/3s.png b/docs/_static/img/examples/tutorial1/3s.png
deleted file mode 100644
index 8c6679b2..00000000
Binary files a/docs/_static/img/examples/tutorial1/3s.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial1/fig1_p2.png b/docs/_static/img/examples/tutorial1/fig1_p2.png
deleted file mode 100644
index 4c0a5c26..00000000
Binary files a/docs/_static/img/examples/tutorial1/fig1_p2.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_1_recurrent.png b/docs/_static/img/examples/tutorial3/3_1_recurrent.png
deleted file mode 100644
index 3ebcc030..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_1_recurrent.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_1_stein_decomp.png b/docs/_static/img/examples/tutorial3/3_1_stein_decomp.png
deleted file mode 100644
index dcdc8a1c..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_1_stein_decomp.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_2_spike_descrip.png b/docs/_static/img/examples/tutorial3/3_2_spike_descrip.png
deleted file mode 100644
index 19d78ec4..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_2_spike_descrip.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_2_unrolled.png b/docs/_static/img/examples/tutorial3/3_2_unrolled.png
deleted file mode 100644
index 8e2c72b3..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_2_unrolled.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_3_delay.png b/docs/_static/img/examples/tutorial3/3_3_delay.png
deleted file mode 100644
index 5e1124d7..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_3_delay.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_4_graphdelay.png b/docs/_static/img/examples/tutorial3/3_4_graphdelay.png
deleted file mode 100644
index 37b025a5..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_4_graphdelay.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_5_targets.png b/docs/_static/img/examples/tutorial3/3_5_targets.png
deleted file mode 100644
index 30b97fac..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_5_targets.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_6_bptt.png b/docs/_static/img/examples/tutorial3/3_6_bptt.png
deleted file mode 100644
index 4b98c52c..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_6_bptt.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_7_stein_bptt.png b/docs/_static/img/examples/tutorial3/3_7_stein_bptt.png
deleted file mode 100644
index 80c6c707..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_7_stein_bptt.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_8_timevarying.png b/docs/_static/img/examples/tutorial3/3_8_timevarying.png
deleted file mode 100644
index f5f13cf1..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_8_timevarying.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial3/3_9_spike_grad.png b/docs/_static/img/examples/tutorial3/3_9_spike_grad.png
deleted file mode 100644
index 72453b7a..00000000
Binary files a/docs/_static/img/examples/tutorial3/3_9_spike_grad.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial5/bptt.png b/docs/_static/img/examples/tutorial5/bptt.png
deleted file mode 100644
index 28cb70ca..00000000
Binary files a/docs/_static/img/examples/tutorial5/bptt.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial5/loss.png b/docs/_static/img/examples/tutorial5/loss.png
deleted file mode 100644
index 1f3bb05d..00000000
Binary files a/docs/_static/img/examples/tutorial5/loss.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial5/non-diff.png b/docs/_static/img/examples/tutorial5/non-diff.png
deleted file mode 100644
index 6468c2c0..00000000
Binary files a/docs/_static/img/examples/tutorial5/non-diff.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial5/unrolled.png b/docs/_static/img/examples/tutorial5/unrolled.png
deleted file mode 100644
index 0495ac84..00000000
Binary files a/docs/_static/img/examples/tutorial5/unrolled.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial5/unrolled_2.png b/docs/_static/img/examples/tutorial5/unrolled_2.png
deleted file mode 100644
index b7627680..00000000
Binary files a/docs/_static/img/examples/tutorial5/unrolled_2.png and /dev/null differ
diff --git a/docs/_static/img/examples/tutorial_pop/pop.png b/docs/_static/img/examples/tutorial_pop/pop.png
deleted file mode 100644
index 4803e18f..00000000
Binary files a/docs/_static/img/examples/tutorial_pop/pop.png and /dev/null differ
diff --git a/docs/tutorials/tutorials.rst b/docs/tutorials/tutorials.rst
index ee7d5dc8..41425396 100644
--- a/docs/tutorials/tutorials.rst
+++ b/docs/tutorials/tutorials.rst
@@ -78,16 +78,35 @@ The tutorial consists of a series of Google Colab notebooks. Static non-editable
-   * - `Example: Exoplanet_Hunter `_
+   * - `Binarized Spiking Neural Networks: Erik Mercado `_
      - .. image:: https://colab.research.google.com/assets/colab-badge.svg
+             :alt: Open In Colab
+             :target: https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_BSNN.ipynb
+   * - `Accelerating snnTorch on IPUs `_
      - —
-   * - `Binarized Spiking Neural Networks: Erik Mercado `_
+
+
+
+
+.. list-table::
+   :widths: 70 32
+   :header-rows: 1
+
+   * - Dataset Tutorials
+     - Colab Link
+
+
+   * - `Exoplanet Hunter: Finding Planets Using Light Intensity `_
+     - .. image:: https://colab.research.google.com/assets/colab-badge.svg
+             :alt: Open In Colab
+             :target: https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_exoplanet_hunter.ipynb
+
+
 Future tutorials on spiking neurons and training are under development.
\ No newline at end of file
diff --git a/setup.py b/setup.py
index e51e58b9..e78d41dd 100644
--- a/setup.py
+++ b/setup.py
@@ -46,6 +46,7 @@
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
         "Programming Language :: Python :: 3.10",
+        "Programming Language :: Python :: 3.11",
         "Operating System :: OS Independent",
     ],
     description="Deep learning with spiking neural networks.",
diff --git a/snntorch/_neurons/leakyunroll.py b/snntorch/_neurons/leakyunroll.py
new file mode 100644
index 00000000..3178d136
--- /dev/null
+++ b/snntorch/_neurons/leakyunroll.py
@@ -0,0 +1,358 @@
+import torch
+import torch.nn as nn
+
+
+class LeakyParallel(nn.Module):
+    """
+    A parallel implementation of the Leaky neuron intended to handle
+    arbitrary input dimensions. All time steps are passed to the input at
+    once. This implementation uses `torch.nn.RNN` to accelerate computation.
+
+    First-order leaky integrate-and-fire neuron model.
+    Input is assumed to be a current injection.
+    Membrane potential decays exponentially with rate beta.
+    For :math:`U[T] > U_{\\rm thr} \\Rightarrow S[T+1] = 1`.
+
+    .. math::
+
+            U[t+1] = \\beta U[t] + I_{\\rm in}[t+1]
+
+    * :math:`I_{\\rm in}` - Input current
+    * :math:`U` - Membrane potential
+    * :math:`U_{\\rm thr}` - Membrane threshold
+    * :math:`\\beta` - Membrane potential decay rate
+
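+    A one-step worked example of the update rule, using illustrative
+    values: with :math:`\\beta = 0.5`, :math:`U[t] = 1.0` and
+    :math:`I_{\\rm in}[t+1] = 0.2`, the next membrane potential is
+    :math:`U[t+1] = 0.5 \\times 1.0 + 0.2 = 0.7`, which is below the
+    default threshold of 1, so no spike is emitted at the next step.
+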
+    Several differences between `LeakyParallel` and `Leaky` include:
+
+    * Negative hidden states are clipped due to the forced ReLU operation in RNN
+    * Linear input weights (`weight_ih_l`) exist in the underlying RNN, but are frozen to the identity so the input current passes through unchanged
+    * `beta` is clipped between [0, 1] and cloned to `weight_hh_l` only upon layer initialization; it is unused otherwise
+    * There is no explicit reset mechanism
+    * Several functions such as `init_hidden`, `output`, `inhibition`, and `state_quant` are unavailable in `LeakyParallel`
+    * Only the output spike is returned; membrane potential is not accessible by default
+    * RNN uses a hidden matrix of size (hidden_size, hidden_size) to transform the hidden state vector. This would 'leak' the membrane potential between LIF neurons, and so the hidden matrix is forced to a diagonal matrix by default. This can be disabled by setting `weight_hh_enable=True`
+
+    Example::
+
+        import torch
+        import torch.nn as nn
+        import snntorch as snn
+
+        beta = 0.5
+        num_inputs = 784
+        num_hidden = 128
+        num_outputs = 10
+        batch_size = 128
+        num_steps = 100
+        x = torch.rand((num_steps, batch_size, num_inputs))
+
+        # Define Network
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+
+                # initialize layers; each LeakyParallel layer preserves its
+                # feature dimension, so Linear layers change the size between them
+                self.fc1 = nn.Linear(num_inputs, num_hidden)
+                # randomly initialized recurrent (decay) weights
+                self.lif1 = snn.LeakyParallel(input_size=num_hidden)
+                self.fc2 = nn.Linear(num_hidden, num_outputs)
+                # learnable recurrent weights initialized at beta
+                self.lif2 = snn.LeakyParallel(input_size=num_outputs, beta=beta, learn_beta=True)
+
+            def forward(self, x):
+                spk1 = self.lif1(self.fc1(x))
+                spk2 = self.lif2(self.fc2(spk1))
+                return spk2
+
+
+    :param input_size: The number of expected features in the input `x`.
+        The output of a linear layer should be an int, whereas the output
+        of a 2D-convolution should be a tuple with 3 int values. The hidden
+        state size is set equal to the flattened input size
+    :type input_size: int or tuple
+
+    :param beta: membrane potential decay rate. Clipped between 0 and 1
+        at initialization. May be a single-valued tensor (i.e., equal
+        decay rate for all neurons in a layer), or multi-valued (one weight per
+        neuron). If left unspecified, then the decay rates will be randomly
+        initialized based on PyTorch's initialization for RNN. Defaults to None
+    :type beta: float or torch.tensor, optional
+
+    :param bias: If `False`, then the layer does not use bias weights
+        `b_ih` and `b_hh`. Defaults to True
+    :type bias: bool, optional
+
+    :param threshold: Threshold for :math:`mem` to reach in order to
+        generate a spike `S=1`. Defaults to 1
+    :type threshold: float, optional
+
+    :param dropout: If non-zero, introduces a Dropout layer on the RNN
+        output with dropout probability equal to dropout. Defaults to 0
+    :type dropout: float, optional
+
+    :param spike_grad: Surrogate gradient for the term dS/dU. Defaults to
+        None (corresponds to ATan surrogate gradient. See
+        `snntorch.surrogate` for more options)
+    :type spike_grad: surrogate gradient function from snntorch.surrogate,
+        optional
+
+    :param surrogate_disable: Disables surrogate gradients regardless of
+        `spike_grad` argument. Useful for ONNX compatibility. Defaults
+        to False
+    :type surrogate_disable: bool, optional
+
+    :param learn_beta: Option to enable learnable beta. Defaults to False
+    :type learn_beta: bool, optional
+
+    :param learn_threshold: Option to enable learnable threshold. Defaults
+        to False
+    :type learn_threshold: bool, optional
+
+    :param graded_spikes_factor: Output spikes are scaled by this factor.
+        Defaults to 1.0
+    :type graded_spikes_factor: float or torch.tensor, optional
+
+    :param learn_graded_spikes_factor: Option to enable a learnable graded
+        spikes factor. Defaults to False
+    :type learn_graded_spikes_factor: bool, optional
+
+    :param weight_hh_enable: Option to set the hidden matrix to be dense or
+        diagonal. Diagonal (i.e., False) adheres to how a LIF neuron works.
+        Dense (True) would allow the membrane potential of one LIF neuron to
+        influence all others, and follow the RNN default implementation.
+        Defaults to False
+    :type weight_hh_enable: bool, optional
+
+
+    Inputs: \\input_
+        - **input_** of shape `(L, H_{in})` for unbatched input,
+          or `(L, N, H_{in})` containing the features of the input sequence
+
+    Outputs: spk
+        - **spk** of shape `(L, batch, input_size)`: tensor containing the
+          output spikes
+
+    where:
+
+    `L = sequence length`
+
+    `N = batch size`
+
+    `H_{in} = input_size`
+
+    `H_{out} = hidden_size`
+
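+    A brief shape sketch with illustrative values, assuming the class is
+    exported as ``snn.LeakyParallel`` (as in the example above)::
+
+        lif = snn.LeakyParallel(input_size=(2, 32, 32))  # conv-style input
+        x = torch.rand(50, 16, 2, 32, 32)                # (L, N, C, X, Y)
+        spk = lif(x)                                     # (50, 16, 2, 32, 32)
+
+    Internally, dimensions from dim=2 onward are flattened to
+    2 * 32 * 32 = 2048 before being passed to the RNN, and unflattened
+    again on the way out.
+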
+    Learnable Parameters:
+        - **rnn.weight_ih_l** (torch.Tensor) - the input-hidden weights of
+          shape (hidden_size, input_size); frozen to the identity in this
+          implementation
+        - **rnn.weight_hh_l** (torch.Tensor) - the learnable hidden-hidden
+          weights, which are sampled from `beta`, of shape
+          (hidden_size, hidden_size)
+        - **bias_ih_l** - the learnable input-hidden bias, of shape
+          (hidden_size)
+        - **bias_hh_l** - the learnable hidden-hidden bias, of shape
+          (hidden_size)
+        - **threshold** (torch.Tensor) - optional learnable thresholds
+          must be manually passed in, of shape `1` or `(input_size)`
+        - **graded_spikes_factor** (torch.Tensor) - optional learnable
+          graded spike factor
+
+    """
+
+    def __init__(
+        self,
+        input_size,
+        beta=None,
+        bias=True,
+        threshold=1.0,
+        dropout=0.0,
+        spike_grad=None,
+        surrogate_disable=False,
+        learn_beta=False,
+        learn_threshold=False,
+        graded_spikes_factor=1.0,
+        learn_graded_spikes_factor=False,
+        weight_hh_enable=False,
+        device=None,
+        dtype=None,
+    ):
+        super(LeakyParallel, self).__init__()
+
+        # if data is T x B x D, then input size is assumed to be D.
+        # input_size is therefore either Linear: (D) or Conv: (C x X x Y).
+        # everything from dim=2 onward is flattened before being passed
+        # into the RNN, and then rolled back up before being output
+
+        self.input_size = input_size
+        unrolled_input_size = self._process_input()
+
+        # initialize weights: the input linear layer is set to a diagonal
+        # of ones so that the input current passes through unchanged
+        self.rnn = nn.RNN(
+            unrolled_input_size,
+            unrolled_input_size,
+            num_layers=1,
+            nonlinearity="relu",
+            bias=bias,
+            batch_first=False,
+            dropout=dropout,
+            device=device,
+            dtype=dtype,
+        )
+
+        self._beta_buffer(beta, learn_beta)
+        self.hidden_size = unrolled_input_size
+
+        if self.beta is not None:
+            self.beta = self.beta.clamp(0, 1)
+
+        if spike_grad is None:
+            self.spike_grad = self.ATan.apply
+        else:
+            self.spike_grad = spike_grad
+
+        self.weight_ih_disable()  # disable the input linear layer
+        self._beta_to_weight_hh()
+        if weight_hh_enable is False:
+            # initial gradient and weights of w_hh are made diagonal
+            self.weight_hh_enable()
+            # register a gradient hook to clamp out non-diagonal entries
+            # in the backward pass
+            if learn_beta:
+                self.rnn.weight_hh_l0.register_hook(self.grad_hook)
+
+        if not learn_beta:
+            # make the weights non-learnable
+            self.rnn.weight_hh_l0.requires_grad_(False)
+
+        self._threshold_buffer(threshold, learn_threshold)
+        self._graded_spikes_buffer(
+            graded_spikes_factor, learn_graded_spikes_factor
+        )
+
+        self.surrogate_disable = surrogate_disable
+        if self.surrogate_disable:
+            self.spike_grad = self._surrogate_bypass
+
+    def forward(self, input_):
+        input_ = self.process_tensor(input_)
+        mem = self.rnn(input_)
+        # mem[0] contains the ReLU'd outputs at every step;
+        # mem[1] contains the final hidden state
+        mem_shift = mem[0] - self.threshold
+        spk = self.spike_grad(mem_shift)
+        spk = spk * self.graded_spikes_factor
+        spk = self.unprocess_tensor(spk)  # restore the original input shape
+        return spk
+
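+    # Worked illustration of the thresholding in `forward` (illustrative
+    # numbers): if one step of mem[0] is [0.7, 1.3] with threshold = 1.0,
+    # then mem_shift = [-0.3, 0.3], and the Heaviside forward pass of the
+    # surrogate gives spk = [0., 1.] for that step.
+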
+    @staticmethod
+    def _surrogate_bypass(input_):
+        return (input_ > 0).float()
+
+    class ATan(torch.autograd.Function):
+        """
+        Surrogate gradient of the Heaviside step function.
+
+        **Forward pass:** shifted Heaviside step function.
+
+        .. math::
+
+            S=\\begin{cases} 1 & \\text{if } U \\geq U_{\\rm thr} \\\\
+            0 & \\text{if } U < U_{\\rm thr}
+            \\end{cases}
+
+        **Backward pass:** gradient of the shifted arctan function.
+
+        .. math::
+
+                S&\\approx\\frac{1}{\\pi}\\text{arctan}(\\pi U \\frac{\\alpha}{2}) \\\\
+                \\frac{\\partial S}{\\partial U}&=\\frac{1}{\\pi}\\frac{1}{(1+(\\pi U\\frac{\\alpha}{2})^2)}
+
+        :math:`\\alpha` defaults to 2, and can be modified by calling
+        ``surrogate.atan(alpha=2)``.
+
+        Adapted from:
+
+        *W. Fang, Z. Yu, Y. Chen, T. Masquelier, T. Huang, Y. Tian (2021)
+        Incorporating Learnable Membrane Time Constants to Enhance Learning
+        of Spiking Neural Networks. Proc. IEEE/CVF Int. Conf. Computer
+        Vision (ICCV), pp. 2661-2671.*
+        """
+
+        @staticmethod
+        def forward(ctx, input_, alpha=2.0):
+            ctx.save_for_backward(input_)
+            ctx.alpha = alpha
+            out = (input_ > 0).float()
+            return out
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            (input_,) = ctx.saved_tensors
+            grad_input = grad_output.clone()
+            grad = (
+                ctx.alpha
+                / 2
+                / (1 + (torch.pi / 2 * ctx.alpha * input_).pow_(2))
+                * grad_input
+            )
+            return grad, None
+
+    def _process_input(self):
+        # an int is returned as-is; a tuple (e.g., C x X x Y) is reduced
+        # to the product of its entries
+        if isinstance(self.input_size, int):
+            return self.input_size
+        elif isinstance(self.input_size, tuple):
+            product = 1
+            for item in self.input_size:
+                # ensure each item in the tuple is an integer
+                if not isinstance(item, int):
+                    raise ValueError(
+                        "All elements in the tuple must be integers"
+                    )
+                product *= item
+            return product
+        else:
+            raise TypeError("Input must be an integer or a tuple of integers")
+
+    def weight_hh_enable(self):
+        # force the recurrent weight matrix to be diagonal so that each
+        # neuron's membrane potential only decays into itself
+        mask = torch.eye(self.hidden_size, self.hidden_size)
+        self.rnn.weight_hh_l0.data = self.rnn.weight_hh_l0.data * mask
+
+    def weight_ih_disable(self):
+        # freeze the input-hidden weights to the identity so the input
+        # current is passed through unchanged; hidden_size is used here
+        # because input_size may be a tuple
+        with torch.no_grad():
+            mask = torch.eye(self.hidden_size, self.hidden_size)
+            self.rnn.weight_ih_l0.data = mask
+            self.rnn.weight_ih_l0.requires_grad_(False)
+
+    def process_tensor(self, input_):
+        # flatten conv-style inputs from (L, N, C, X, Y) to (L, N, C*X*Y)
+        if isinstance(self.input_size, int):
+            return input_
+        elif isinstance(self.input_size, tuple):
+            return input_.flatten(2)
+        else:
+            raise ValueError("input_size must be either an int or a tuple")
+
+    def unprocess_tensor(self, input_):
+        # undo process_tensor: restore (L, N, C, X, Y) from (L, N, C*X*Y)
+        if isinstance(self.input_size, int):
+            return input_
+        elif isinstance(self.input_size, tuple):
+            return input_.unflatten(2, self.input_size)
+        else:
+            raise ValueError("input_size must be either an int or a tuple")
+
+    def grad_hook(self, grad):
+        device = grad.device
+        # create a mask that is 1 on the diagonal and 0 elsewhere
+        mask = torch.eye(self.hidden_size, self.hidden_size, device=device)
+        # use the mask to zero out non-diagonal elements of the gradient
+        return grad * mask
+
+    def _beta_to_weight_hh(self):
+        with torch.no_grad():
+            if self.beta is not None:
+                # set all weights to the scalar value of self.beta
+                if isinstance(self.beta, (float, int)):
+                    self.rnn.weight_hh_l0.fill_(self.beta)
+                elif isinstance(self.beta, torch.Tensor):
+                    if len(self.beta) == 1:
+                        self.rnn.weight_hh_l0.fill_(self.beta[0])
+                    elif len(self.beta) == self.hidden_size:
+                        # replace each row with its corresponding beta value
+                        for i in range(self.hidden_size):
+                            self.rnn.weight_hh_l0.data[i].fill_(self.beta[i])
+                    else:
+                        raise ValueError(
+                            "Beta must be either a single value or of "
+                            "length 'hidden_size'."
+                        )
+
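+    # Worked illustration (illustrative values): with hidden_size=3 and
+    # beta=0.9, _beta_to_weight_hh() fills weight_hh_l0 with 0.9
+    # everywhere, and weight_hh_enable() then masks it down to
+    # diag(0.9, 0.9, 0.9), so each neuron's membrane decays independently.
+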
+    def _beta_buffer(self, beta, learn_beta):
+        if not isinstance(beta, torch.Tensor):
+            if beta is not None:
+                beta = torch.as_tensor([beta])  # TODO: or .tensor() if no copy
+        self.register_buffer("beta", beta)
+
+    def _graded_spikes_buffer(
+        self, graded_spikes_factor, learn_graded_spikes_factor
+    ):
+        if not isinstance(graded_spikes_factor, torch.Tensor):
+            graded_spikes_factor = torch.as_tensor(graded_spikes_factor)
+        if learn_graded_spikes_factor:
+            self.graded_spikes_factor = nn.Parameter(graded_spikes_factor)
+        else:
+            self.register_buffer("graded_spikes_factor", graded_spikes_factor)
+
+    def _threshold_buffer(self, threshold, learn_threshold):
+        if not isinstance(threshold, torch.Tensor):
+            threshold = torch.as_tensor(threshold)
+        if learn_threshold:
+            self.threshold = nn.Parameter(threshold)
+        else:
+            self.register_buffer("threshold", threshold)
\ No newline at end of file
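
A minimal end-to-end sketch of the new module (hypothetical shapes and
values; assumes `LeakyParallel` is exported as `snn.LeakyParallel`, as the
docstring example does)::

    import torch
    import snntorch as snn

    num_steps, batch_size, num_features = 100, 16, 784

    # all time steps are passed at once: (L, N, H_in)
    x = torch.rand((num_steps, batch_size, num_features))

    # a scalar beta is clamped to [0, 1] and cloned onto the diagonal of
    # the recurrent weight matrix at initialization
    lif = snn.LeakyParallel(input_size=num_features, beta=0.9)

    spk = lif(x)
    print(spk.shape)  # torch.Size([100, 16, 784])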