From d09514c64055f3d314f4f46dbd8f18523e5981fc Mon Sep 17 00:00:00 2001 From: Lakshay Date: Mon, 29 Apr 2024 12:07:12 +0530 Subject: [PATCH] added predict api --- spkeras/models.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/spkeras/models.py b/spkeras/models.py index 1b3b579..a7ade21 100644 --- a/spkeras/models.py +++ b/spkeras/models.py @@ -169,7 +169,7 @@ def convert(self, mdl,x_train,thresholding=0.5,scaling_factor=1,method=0,timeste new_mdl['config']['output_layers'] = [[inbound_nodes, 0, 0]] new_mdl = json.dumps(new_mdl) new_model = model_from_json(new_mdl, - custom_objects={'SpikeActivation':SpikeActivation}) + custom_objects={'SpikeActivation':SpikeActivation}) input_shape = model.layers[0].input_shape @@ -408,6 +408,22 @@ def evaluate(self,x_test,y_test,timesteps=256,thresholding=0.5,scaling_factor=1, _x_test = x_test*fix if fix > 0 else np.floor(x_test*self.timesteps) return self.model.evaluate(_x_test,y_test) + def predict(self, features, timesteps=256, thresholding=0.5, scaling_factor=1, + spike_ext=0, noneloss=False, sf=None, fix=0): + import numpy as np + self.timesteps = timesteps + self.thresholding = thresholding + self.scaling_factor = scaling_factor + self.spike_ext = spike_ext + self.noneloss = noneloss + self.model = self.chts_model(timesteps=timesteps,thresholding=thresholding, + scaling_factor=scaling_factor, + spike_ext=spike_ext,noneloss=noneloss,sf=sf) + + self.get_config() + _features = features*fix if fix > 0 else np.floor(features*self.timesteps) + return self.model.predict(features) + def chts_model(self,timesteps=256,thresholding=0.5,scaling_factor=1,spike_ext=0,noneloss=False,sf=None): #method: 0:threshold norm 1:weight norm from tensorflow.keras.models import Sequential, model_from_json