diff --git a/snntorch/export_nir.py b/snntorch/export_nir.py index 4f0a1614..471b9e92 100644 --- a/snntorch/export_nir.py +++ b/snntorch/export_nir.py @@ -27,6 +27,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: beta = module.beta.detach().numpy() vthr = module.threshold.detach().numpy() + vthr = np.array([vthr]) if isinstance(vthr, (int, float)) else vthr tau_mem = dt / (1 - beta) r = tau_mem / dt v_leak = np.zeros_like(beta) @@ -56,6 +57,7 @@ def _extract_snntorch_module(module: torch.nn.Module) -> Optional[nir.NIRNode]: alpha = module.alpha.detach().numpy() beta = module.beta.detach().numpy() vthr = module.threshold.detach().numpy() + vthr = np.array([vthr]) if isinstance(vthr, (int, float)) else vthr tau_syn = dt / (1 - alpha) tau_mem = dt / (1 - beta)