Skip to content

Commit

Permalink
Merge pull request #317 from gekkom/fix-device
Browse files Browse the repository at this point in the history
Fix model.to(device)
  • Loading branch information
jeshraghian authored Apr 21, 2024
2 parents e985c0c + 3b4f023 commit 202f008
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 34 deletions.
15 changes: 10 additions & 5 deletions snntorch/_neurons/alpha.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,13 @@ def __init__(
self.state_function = self._base_int

def _init_mem(self):
self.syn_exc = torch.zeros(1)
self.syn_inh = torch.zeros(1)
self.mem = torch.zeros(1)
syn_exc = torch.zeros(0)
syn_inh = torch.zeros(0)
mem = torch.zeros(0)

self.register_buffer("syn_exc", syn_exc, False)
self.register_buffer("syn_inh", syn_inh, False)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.syn_exc = torch.zeros_like(
Expand All @@ -142,10 +146,11 @@ def reset_mem(self):
)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)

return self.syn_exc, self.syn_inh, self.mem

def init_alpha(self):
"""Deprecated, use :class:`Alpha.reset_mem` instead"""
self.reset_mem()
return self.syn_exc, self.syn_inh, self.mem
return self.reset_mem()

def forward(self, input_, syn_exc=None, syn_inh=None, mem=None):

Expand Down
7 changes: 4 additions & 3 deletions snntorch/_neurons/lapicque.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,15 +226,16 @@ def __init__(
self.state_function = self._base_int

def _init_mem(self):
self.mem = torch.zeros(1)
mem = torch.zeros(0)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.mem

def init_lapicque(self):
"""Deprecated, use :class:`Lapicque.reset_mem` instead"""
self.reset_mem()
return self.mem
return self.reset_mem()

def forward(self, input_, mem=None):

Expand Down
7 changes: 4 additions & 3 deletions snntorch/_neurons/leaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,16 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
self.mem = torch.zeros(1)
mem = torch.zeros(0)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.mem

def init_leaky(self):
"""Deprecated, use :class:`Leaky.reset_mem` instead"""
self.reset_mem()
return self.mem
return self.reset_mem()

def forward(self, input_, mem=None):

Expand Down
11 changes: 7 additions & 4 deletions snntorch/_neurons/rleaky.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,17 +292,20 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
self.spk = torch.zeros(1)
self.mem = torch.zeros(1)
spk = torch.zeros(0)
mem = torch.zeros(0)

self.register_buffer("spk", spk, False)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.spk = torch.zeros_like(self.spk, device=self.spk.device)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.spk, self.mem

def init_rleaky(self):
"""Deprecated, use :class:`RLeaky.reset_mem` instead"""
self.reset_mem()
return self.spk, self.mem
return self.reset_mem()

def forward(self, input_, spk=None, mem=None):

Expand Down
18 changes: 12 additions & 6 deletions snntorch/_neurons/rsynaptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,19 +307,23 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
self.spk = torch.zeros(1)
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)
spk = torch.zeros(0)
syn = torch.zeros(0)
mem = torch.zeros(0)

self.register_buffer("spk", spk, False)
self.register_buffer("syn", syn, False)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.spk = torch.zeros_like(self.spk, device=self.spk.device)
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.spk, self.syn, self.mem

def init_rsynaptic(self):
"""Deprecated, use :class:`RSynaptic.reset_mem` instead"""
self.reset_mem()
return self.spk, self.syn, self.mem
return self.reset_mem()

def forward(self, input_, spk=None, syn=None, mem=None):
if not spk == None:
Expand All @@ -331,7 +335,9 @@ def forward(self, input_, spk=None, syn=None, mem=None):
if not mem == None:
self.mem = mem

if self.init_hidden and (not spk == None or not syn == None or not mem == None):
if self.init_hidden and (
not spk == None or not syn == None or not mem == None
):
raise TypeError(
"When `init_hidden=True`, RSynaptic expects 1 input argument."
)
Expand Down
13 changes: 8 additions & 5 deletions snntorch/_neurons/sconv2dlstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,17 +273,20 @@ def __init__(
)

def _init_mem(self):
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)
syn = torch.zeros(0)
mem = torch.zeros(0)

self.register_buffer("syn", syn, False)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.syn, self.mem

def init_sconv2dlstm(self):
"""Deprecated, use :class:`SConv2dLSTM.reset_mem` instead"""
self.reset_mem()
return self.syn, self.mem
return self.reset_mem()

def forward(self, input_, syn=None, mem=None):
if not syn == None:
Expand All @@ -296,7 +299,7 @@ def forward(self, input_, syn=None, mem=None):
raise TypeError(
"`mem` or `syn` should not be passed as an argument while `init_hidden=True`"
)

size = input_.size()
correct_shape = (size[0], self.out_channels, size[2], size[3])
if not self.syn.shape == correct_shape:
Expand Down
11 changes: 7 additions & 4 deletions snntorch/_neurons/slstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,20 @@ def __init__(
)

def _init_mem(self):
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)
syn = torch.zeros(0)
mem = torch.zeros(0)

self.register_buffer("syn", syn, False)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.syn, self.mem

def init_slstm(self):
"""Deprecated, use :class:`SLSTM.reset_mem` instead"""
self.reset_mem()
return self.syn, self.mem
return self.reset_mem()

def forward(self, input_, syn=None, mem=None):
if not syn == None:
Expand Down
11 changes: 7 additions & 4 deletions snntorch/_neurons/synaptic.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,20 @@ def __init__(
self.reset_delay = reset_delay

def _init_mem(self):
self.syn = torch.zeros(1)
self.mem = torch.zeros(1)
syn = torch.zeros(0)
mem = torch.zeros(0)

self.register_buffer("syn", syn, False)
self.register_buffer("mem", mem, False)

def reset_mem(self):
self.syn = torch.zeros_like(self.syn, device=self.syn.device)
self.mem = torch.zeros_like(self.mem, device=self.mem.device)
return self.syn, self.mem

def init_synaptic(self):
"""Deprecated, use :class:`Synaptic.reset_mem` instead"""
self.reset_mem()
return self.syn, self.mem
return self.reset_mem()

def forward(self, input_, syn=None, mem=None):

Expand Down

0 comments on commit 202f008

Please sign in to comment.