diff --git a/examples/keras/mnist_spiking_cnn.py b/examples/keras/mnist_spiking_cnn.py
index 7737a41..ce66ab2 100644
--- a/examples/keras/mnist_spiking_cnn.py
+++ b/examples/keras/mnist_spiking_cnn.py
@@ -13,6 +13,7 @@
from keras.utils import np_utils
import nengo
+from nengo_extras.convnet import PresentImages
from nengo_extras.keras import (
load_model_pair, save_model_pair, SequentialNetwork, SoftLIF)
from nengo_extras.gui import image_display_function
@@ -88,19 +89,13 @@
model = nengo.Network()
with model:
- u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time))
+ u = nengo.Node(PresentImages(X_test, presentation_time))
seq = SequentialNetwork(kmodel, synapse=nengo.synapses.Alpha(0.005))
nengo.Connection(u, seq.input, synapse=None)
input_p = nengo.Probe(u)
output_p = nengo.Probe(seq.output)
- # --- image display
- image_shape = kmodel.input_shape[1:]
- display_f = image_display_function(image_shape)
- display_node = nengo.Node(display_f, size_in=u.size_out)
- nengo.Connection(u, display_node, synapse=None)
-
# --- output spa display
vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR',
'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE']
diff --git a/nengo_extras/convnet.py b/nengo_extras/convnet.py
index d6a1bad..5ba1f45 100644
--- a/nengo_extras/convnet.py
+++ b/nengo_extras/convnet.py
@@ -1,5 +1,6 @@
import numpy as np
+import nengo
from nengo.exceptions import ValidationError
from nengo.processes import Process
from nengo.params import (EnumParam, NdarrayParam, Parameter, TupleParam,
@@ -253,6 +254,70 @@ def step_pool2d(t, x):
return step_pool2d
+class PresentImages(nengo.processes.PresentInput):
+ image_shape = ShapeParam('image_shape', length=3, low=1)
+
+ def __init__(self, images, presentation_time, **kwargs):
+ self.image_shape = images.shape[1:]
+ super(PresentImages, self).__init__(
+ images, presentation_time, **kwargs)
+
+ def _nengo_html_function_(self, t, x):
+ import base64
+ import PIL
+ import cStringIO
+
+ values = x.reshape(self.image_shape)
+ values = values.transpose((1, 2, 0)) # colour channel last
+ values = values * 255.
+ values = values.astype('uint8')
+
+ if values.shape[-1] == 1:
+ values = values[:, :, 0]
+
+ png = PIL.Image.fromarray(values)
+ buffer = cStringIO.StringIO()
+ png.save(buffer, format="PNG")
+ img_str = base64.b64encode(buffer.getvalue())
+
+ return '''
+ ''' % (''.join(img_str))
+
+ # def make_html_function(self, size):
+ # import base64
+ # import PIL
+ # import cStringIO
+ # image_shape = self.image_shape
+ # assert np.prod(image_shape) == size
+
+ # def html_presentimages(t, x):
+ # values = x.reshape(image_shape)
+ # values = values.transpose((1, 2, 0)) # colour channel last
+ # values = values * 255.
+ # values = values.astype('uint8')
+
+ # if values.shape[-1] == 1:
+ # values = values[:, :, 0]
+
+ # png = PIL.Image.fromarray(values)
+ # buffer = cStringIO.StringIO()
+ # png.save(buffer, format="PNG")
+ # img_str = base64.b64encode(buffer.getvalue())
+
+ # return '''
+ # ''' % (''.join(img_str))
+
+ # return html_presentimages
+
+
def softmax(x, axis=None):
"""Stable softmax function"""
ex = np.exp(x - x.max(axis=axis, keepdims=True))