diff --git a/examples/keras/mnist_spiking_cnn.py b/examples/keras/mnist_spiking_cnn.py index 7737a41..ce66ab2 100644 --- a/examples/keras/mnist_spiking_cnn.py +++ b/examples/keras/mnist_spiking_cnn.py @@ -13,6 +13,7 @@ from keras.utils import np_utils import nengo +from nengo_extras.convnet import PresentImages from nengo_extras.keras import ( load_model_pair, save_model_pair, SequentialNetwork, SoftLIF) from nengo_extras.gui import image_display_function @@ -88,19 +89,13 @@ model = nengo.Network() with model: - u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time)) + u = nengo.Node(PresentImages(X_test, presentation_time)) seq = SequentialNetwork(kmodel, synapse=nengo.synapses.Alpha(0.005)) nengo.Connection(u, seq.input, synapse=None) input_p = nengo.Probe(u) output_p = nengo.Probe(seq.output) - # --- image display - image_shape = kmodel.input_shape[1:] - display_f = image_display_function(image_shape) - display_node = nengo.Node(display_f, size_in=u.size_out) - nengo.Connection(u, display_node, synapse=None) - # --- output spa display vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR', 'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE'] diff --git a/nengo_extras/convnet.py b/nengo_extras/convnet.py index d6a1bad..5ba1f45 100644 --- a/nengo_extras/convnet.py +++ b/nengo_extras/convnet.py @@ -1,5 +1,6 @@ import numpy as np +import nengo from nengo.exceptions import ValidationError from nengo.processes import Process from nengo.params import (EnumParam, NdarrayParam, Parameter, TupleParam, @@ -253,6 +254,70 @@ def step_pool2d(t, x): return step_pool2d +class PresentImages(nengo.processes.PresentInput): + image_shape = ShapeParam('image_shape', length=3, low=1) + + def __init__(self, images, presentation_time, **kwargs): + self.image_shape = images.shape[1:] + super(PresentImages, self).__init__( + images, presentation_time, **kwargs) + + def _nengo_html_function_(self, t, x): + import base64 + import PIL + import cStringIO + + values = x.reshape(self.image_shape) + values = values.transpose((1, 2, 0)) # colour channel last + values = values * 255. + values = values.astype('uint8') + + if values.shape[-1] == 1: + values = values[:, :, 0] + + png = PIL.Image.fromarray(values) + buffer = cStringIO.StringIO() + png.save(buffer, format="PNG") + img_str = base64.b64encode(buffer.getvalue()) + + return ''' + + + ''' % (''.join(img_str)) + + # def make_html_function(self, size): + # import base64 + # import PIL + # import cStringIO + # image_shape = self.image_shape + # assert np.prod(image_shape) == size + + # def html_presentimages(t, x): + # values = x.reshape(image_shape) + # values = values.transpose((1, 2, 0)) # colour channel last + # values = values * 255. + # values = values.astype('uint8') + + # if values.shape[-1] == 1: + # values = values[:, :, 0] + + # png = PIL.Image.fromarray(values) + # buffer = cStringIO.StringIO() + # png.save(buffer, format="PNG") + # img_str = base64.b64encode(buffer.getvalue()) + + # return ''' + # + # + # ''' % (''.join(img_str)) + + # return html_presentimages + + def softmax(x, axis=None): """Stable softmax function""" ex = np.exp(x - x.max(axis=axis, keepdims=True))