diff --git a/docs/snntorch.functional.rst b/docs/snntorch.functional.rst
index 6b9f0a40..c60d2242 100644
--- a/docs/snntorch.functional.rst
+++ b/docs/snntorch.functional.rst
@@ -57,6 +57,14 @@ State Quantization
     :members:
     :undoc-members:
     :show-inheritance:
+
+STDP Learner
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. automodule:: snntorch.functional.stdp_learner
+    :members:
+    :undoc-members:
+    :show-inheritance:
 
 Probe
 ^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/snntorch/functional/stdp_learner.py b/snntorch/functional/stdp_learner.py
index a2960542..84bc1e36 100644
--- a/snntorch/functional/stdp_learner.py
+++ b/snntorch/functional/stdp_learner.py
@@ -20,6 +20,11 @@ def stdp_linear_single_step(
     f_pre: Callable = lambda x: x,
     f_post: Callable = lambda x: x,
 ):
+    """
+    Single step of the STDP learning rule for a Linear layer.
+
+    """
+
     if trace_pre is None:
         trace_pre = 0.0
 
@@ -55,6 +60,11 @@ def mstdp_linear_single_step(
     f_pre: Callable = lambda x: x,
     f_post: Callable = lambda x: x,
 ):
+    """
+    Single step of the mSTDP learning rule for a Linear layer.
+
+    """
+
     if trace_pre is None:
         trace_pre = 0.0
 
@@ -88,6 +98,11 @@ def mstdpet_linear_single_step(
     f_pre: Callable = lambda x: x,
     f_post: Callable = lambda x: x,
 ):
+    """
+    Single step of the mSTDP learning rule with an eligibility
+    trace for a Linear layer.
+    """
+
     if trace_pre is None:
         trace_pre = 0.0
 
@@ -115,7 +130,12 @@ def stdp_conv2d_single_step(
     f_pre: Callable = lambda x: x,
     f_post: Callable = lambda x: x,
 ):
+    """
+    Single step of the STDP learning rule for a Conv2d layer.
+
+    """
+
     if conv.dilation != (1, 1):
         raise NotImplementedError(
             "STDP with dilation != 1 for Conv2d has not been implemented!"
         )
@@ -198,7 +218,12 @@ def stdp_conv1d_single_step(
     f_pre: Callable = lambda x: x,
     f_post: Callable = lambda x: x,
 ):
+    """
+    Single step of the STDP learning rule for a Conv1d layer.
+
+    """
+
     if conv.dilation != (1,):
         raise NotImplementedError(
             "STDP with dilation != 1 for Conv1d has not been implemented!"
         )
@@ -285,19 +310,49 @@ def __init__(
         self.trace_post = None
 
     def reset(self):
+        """
+        Reset the recorded data in the monitors.
+
+        """
+
        super(STDPLearner, self).reset()
         self.in_spike_monitor.clear_recorded_data()
         self.out_spike_monitor.clear_recorded_data()
 
     def disable(self):
+        """
+        Disable recording of data in the monitors.
+
+        """
+
         self.in_spike_monitor.disable()
         self.out_spike_monitor.disable()
 
     def enable(self):
+        """
+        Enable recording of data in the monitors.
+
+        """
+
         self.in_spike_monitor.enable()
         self.out_spike_monitor.enable()
 
     def step(self, on_grad: bool = True, scale: float = 1.0):
+        """
+        Perform a single step of the STDP learning rule.
+
+        :param on_grad: If True, delta_w is added to the weight.grad
+            of the synapse. If False, delta_w is returned.
+        :type on_grad: bool
+
+        :param scale: Scaling factor applied to delta_w.
+        :type scale: float
+
+        :return: delta_w if on_grad is False.
+        :rtype: torch.Tensor
+
+        """
+
         length = self.in_spike_monitor.records.__len__()
         delta_w = None
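
As a companion to the docstrings added above, here is a minimal sketch of driving the functional API by hand. It assumes the SpikingJelly-style signature this module is adapted from, ``stdp_linear_single_step(fc, in_spike, out_spike, trace_pre, trace_post, tau_pre, tau_post, f_pre, f_post)``, returning the updated traces together with ``delta_w``; the shapes, time constants, and learning rate below are illustrative, not taken from this patch::

    # Hand-driven single-step STDP on a Linear layer (sketch, assumed API).
    import torch
    import torch.nn as nn

    from snntorch.functional.stdp_learner import stdp_linear_single_step

    fc = nn.Linear(10, 5, bias=False)             # 10 pre-, 5 post-synaptic neurons
    in_spike = (torch.rand(4, 10) > 0.5).float()  # hypothetical batch of input spikes
    out_spike = (torch.rand(4, 5) > 0.5).float()  # hypothetical output spikes

    trace_pre, trace_post = None, None            # traces start empty (per the diff)
    trace_pre, trace_post, delta_w = stdp_linear_single_step(
        fc, in_spike, out_spike,
        trace_pre, trace_post,
        tau_pre=2.0, tau_post=2.0,                # illustrative time constants
    )

    # delta_w is shaped like fc.weight, so the update can be applied manually:
    with torch.no_grad():
        fc.weight += 1e-3 * delta_w

The same pattern should carry over to ``mstdp_linear_single_step`` and the Conv1d/Conv2d variants, each producing a ``delta_w`` shaped like the layer's weight.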
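
And a hedged sketch of the stateful ``STDPLearner`` wrapper whose ``reset``/``enable``/``disable``/``step`` methods are documented in the last hunk. The constructor arguments shown (``synapse``, ``sn``, ``tau_pre``, ``tau_post``) follow the SpikingJelly API this module appears to derive from and are an assumption, as is the detail that the in/out spike monitors record automatically during the forward pass::

    # Stateful STDP in a training loop (sketch; constructor args are assumed).
    import torch
    import torch.nn as nn
    import snntorch as snn

    from snntorch.functional.stdp_learner import STDPLearner

    fc = nn.Linear(10, 5, bias=False)
    lif = snn.Leaky(beta=0.9, init_hidden=True)

    learner = STDPLearner(synapse=fc, sn=lif, tau_pre=2.0, tau_post=2.0)
    optimizer = torch.optim.SGD(fc.parameters(), lr=1e-3)

    for _ in range(8):
        x = (torch.rand(4, 10) > 0.5).float()
        spk = lif(fc(x))             # monitors record in/out spikes (assumed)
        learner.step(on_grad=True)   # folds delta_w into fc.weight.grad
        optimizer.step()
        optimizer.zero_grad()

    learner.reset()                  # clear traces and recorded spikes

Since ``on_grad=True`` adds ``delta_w`` to ``weight.grad`` (per the docstring) and SGD then *subtracts* the gradient, the sign convention of ``delta_w`` relative to the intended potentiation/depression is worth confirming in review.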