diff --git a/snntorch/_neurons/leaky.py b/snntorch/_neurons/leaky.py index f07a20ce..00380fcb 100644 --- a/snntorch/_neurons/leaky.py +++ b/snntorch/_neurons/leaky.py @@ -160,6 +160,7 @@ def __init__( ) self._init_mem() + self.init_hidden = init_hidden if self.reset_mechanism_val == 0: # reset by subtraction self.state_function = self._base_sub @@ -167,6 +168,7 @@ def __init__( self.state_function = self._base_zero elif self.reset_mechanism_val == 2: # no reset, pure integration self.state_function = self._base_int + def _init_mem(self): mem = torch.zeros(1) @@ -183,6 +185,11 @@ def init_leaky(self): def forward(self, input_, mem=None): if not mem == None: self.mem = mem + + if self.init_hidden and not mem == None: + raise TypeError( + "mem should not be passed as an argument while `init_hidden=True`" + ) if not self.mem.shape == input_.shape: self.mem = torch.zeros_like(input_, device=self.mem.device)