Skip to content

Commit

Permalink
ensure vthr is numpy array, not scalar
Browse files Browse the repository at this point in the history
  • Loading branch information
stevenabreu7 authored Jun 20, 2024
1 parent 2cb5bcb commit 49a4b35
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions snntorch/export_nir.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]:

beta = module.beta.detach().numpy()
vthr = module.threshold.detach().numpy()
vthr = np.array([vthr]) if isinstance(vthr, (int, float)) else vthr
tau_mem = dt / (1 - beta)
r = tau_mem / dt
v_leak = np.zeros_like(beta)
Expand Down Expand Up @@ -56,6 +57,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]:
alpha = module.alpha.detach().numpy()
beta = module.beta.detach().numpy()
vthr = module.threshold.detach().numpy()
vthr = np.array([vthr]) if isinstance(vthr, (int, float)) else vthr

tau_syn = dt / (1 - alpha)
tau_mem = dt / (1 - beta)
Expand Down

0 comments on commit 49a4b35

Please sign in to comment.