Skip to content

Commit

Permalink
fix leaky error where init_hidden=True and membrane is explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
jeshraghian committed Jan 25, 2024
1 parent 9ca17c2 commit 8f047ee
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions snntorch/_neurons/leaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,15 @@ def __init__(
)

self._init_mem()
self.init_hidden = init_hidden

if self.reset_mechanism_val == 0: # reset by subtraction
self.state_function = self._base_sub
elif self.reset_mechanism_val == 1: # reset to zero
self.state_function = self._base_zero
elif self.reset_mechanism_val == 2: # no reset, pure integration
self.state_function = self._base_int


def _init_mem(self):
mem = torch.zeros(1)
Expand All @@ -183,6 +185,11 @@ def init_leaky(self):
def forward(self, input_, mem=None):
if not mem == None:
self.mem = mem

if self.init_hidden and not mem == None:
raise TypeError(
"mem should not be passed as an argument while `init_hidden=True`"
)

if not self.mem.shape == input_.shape:
self.mem = torch.zeros_like(input_, device=self.mem.device)
Expand Down

0 comments on commit 8f047ee

Please sign in to comment.