Skip to content

Commit

Permalink
WIP: Added PresentImages process to show images in GUI
Browse files Browse the repository at this point in the history
TODO:
- remove dead code
- merge nengo/nengo-gui#755
  • Loading branch information
hunse committed Jun 7, 2016
1 parent 6a7d6e0 commit ae7ab6f
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 32 deletions.
34 changes: 2 additions & 32 deletions examples/keras/mnist_spiking_cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from keras.utils import np_utils

import nengo
from nengo_extras.convnet import PresentImages
from nengo_extras.keras import (
load_model_pair, save_model_pair, SequentialNetwork, SoftLIF)

Expand Down Expand Up @@ -88,44 +89,13 @@

model = nengo.Network()
with model:
u = nengo.Node(nengo.processes.PresentInput(X_test, presentation_time))
u = nengo.Node(PresentImages(X_test, presentation_time))
seq = SequentialNetwork(kmodel, synapse=nengo.synapses.Alpha(0.005))
nengo.Connection(u, seq.input, synapse=None)

input_p = nengo.Probe(u)
output_p = nengo.Probe(seq.output)

# --- image display
input_shape = kmodel.input_shape[1:]

def display_func(t, x, input_shape=input_shape):
import base64
import PIL
import cStringIO

values = x.reshape(input_shape)
values = values.transpose((1, 2, 0))
values = values * 255.
values = values.astype('uint8')

if values.shape[-1] == 1:
values = values[:, :, 0]

png = PIL.Image.fromarray(values)
buffer = cStringIO.StringIO()
png.save(buffer, format="PNG")
img_str = base64.b64encode(buffer.getvalue())

display_func._nengo_html_ = '''
<svg width="100%%" height="100%%" viewbox="0 0 100 100">
<image width="100%%" height="100%%"
xlink:href="data:image/png;base64,%s"
style="image-rendering: pixelated;">
</svg>''' % (''.join(img_str))

display_node = nengo.Node(display_func, size_in=u.size_out)
nengo.Connection(u, display_node, synapse=None)

# --- output spa display
vocab_names = ['ZERO', 'ONE', 'TWO', 'THREE', 'FOUR',
'FIVE', 'SIX', 'SEVEN', 'EIGHT', 'NINE']
Expand Down
65 changes: 65 additions & 0 deletions nengo_extras/convnet.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np

import nengo
from nengo.exceptions import ValidationError
from nengo.processes import Process
from nengo.params import (EnumParam, NdarrayParam, Parameter, TupleParam,
Expand Down Expand Up @@ -211,6 +212,70 @@ def step_pool2d(t, x):
return step_pool2d


class PresentImages(nengo.processes.PresentInput):
image_shape = ShapeParam('image_shape', length=3, low=1)

def __init__(self, images, presentation_time, **kwargs):
self.image_shape = images.shape[1:]
super(PresentImages, self).__init__(
images, presentation_time, **kwargs)

def _nengo_html_function_(self, t, x):
import base64
import PIL
import cStringIO

values = x.reshape(self.image_shape)
values = values.transpose((1, 2, 0)) # colour channel last
values = values * 255.
values = values.astype('uint8')

if values.shape[-1] == 1:
values = values[:, :, 0]

png = PIL.Image.fromarray(values)
buffer = cStringIO.StringIO()
png.save(buffer, format="PNG")
img_str = base64.b64encode(buffer.getvalue())

return '''
<svg width="100%%" height="100%%" viewbox="0 0 100 100">
<image width="100%%" height="100%%"
xlink:href="data:image/png;base64,%s"
style="image-rendering: pixelated;">
</svg>''' % (''.join(img_str))

# def make_html_function(self, size):
# import base64
# import PIL
# import cStringIO
# image_shape = self.image_shape
# assert np.prod(image_shape) == size

# def html_presentimages(t, x):
# values = x.reshape(image_shape)
# values = values.transpose((1, 2, 0)) # colour channel last
# values = values * 255.
# values = values.astype('uint8')

# if values.shape[-1] == 1:
# values = values[:, :, 0]

# png = PIL.Image.fromarray(values)
# buffer = cStringIO.StringIO()
# png.save(buffer, format="PNG")
# img_str = base64.b64encode(buffer.getvalue())

# return '''
# <svg width="100%%" height="100%%" viewbox="0 0 100 100">
# <image width="100%%" height="100%%"
# xlink:href="data:image/png;base64,%s"
# style="image-rendering: pixelated;">
# </svg>''' % (''.join(img_str))

# return html_presentimages


def softmax(x, axis=None):
"""Stable softmax function"""
ex = np.exp(x - x.max(axis=axis, keepdims=True))
Expand Down

0 comments on commit ae7ab6f

Please sign in to comment.