-
Notifications
You must be signed in to change notification settings - Fork 16
/
hinton.py
58 lines (45 loc) · 1.59 KB
/
hinton.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
Demo of a function to create Hinton diagrams.
Hinton diagrams are useful for visualizing the values of a 2D array (e.g.
a weight matrix): Positive and negative values are represented by white and
black squares, respectively, and the size of each square represents the
magnitude of each value.
Initial idea from David Warde-Farley on the SciPy Cookbook
"""
import vizualizator
import cv2
import numpy as np
import matplotlib.pyplot as plt
def hinton(matrix, max_weight=None, ax=None):
    """Draw a Hinton diagram for visualizing a weight matrix.

    Each entry ``w`` of *matrix* is drawn as a square centred on its index:
    white when ``w > 0``, black otherwise. The side length is
    ``sqrt(|w| / max_weight)``, so the square's area is proportional to the
    magnitude of the value.

    Parameters
    ----------
    matrix : array_like
        2-D array of values to draw. The entry at index ``(i, j)`` is centred
        at ``x = i, y = j`` (first index on the horizontal axis); the y axis
        is inverted afterwards.
    max_weight : float, optional
        Magnitude that maps to a full-size (side length 1) square. If falsy
        (``None`` or ``0``), the smallest power of two that is >= the largest
        absolute value in *matrix* is used. For an empty or all-zero matrix,
        1 is used so that no division by zero occurs.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    ax = ax if ax is not None else plt.gca()

    if not max_weight:
        largest = np.abs(matrix).max() if np.size(matrix) else 0
        # Round up to a power of two. np.log2 is exact for powers of two,
        # whereas log(x) / log(2) can land just above an integer (e.g. for
        # 2**29) and make ceil() double the scale. log2(0) is -inf, which
        # would give max_weight == 0 and NaN sizes, so fall back to 1.
        max_weight = 2 ** np.ceil(np.log2(largest)) if largest > 0 else 1

    ax.patch.set_facecolor('gray')
    ax.set_aspect('equal', 'box')
    # Hide tick marks: positions carry no information beyond the grid layout.
    ax.xaxis.set_major_locator(plt.NullLocator())
    ax.yaxis.set_major_locator(plt.NullLocator())

    for (x, y), w in np.ndenumerate(matrix):
        color = 'white' if w > 0 else 'black'
        size = np.sqrt(np.abs(w) / max_weight)
        # Offset by half the side so the square is centred on (x, y).
        rect = plt.Rectangle([x - size / 2, y - size / 2], size, size,
                             facecolor=color, edgecolor=color)
        ax.add_patch(rect)

    # Patches do not update data limits by themselves, so rescale the view.
    ax.autoscale_view()
    ax.invert_yaxis()
    return ax
# Demo: load the dummy data, split it into colour channels and draw one
# Hinton diagram per channel, side by side in a single figure.
data = vizualizator.load_dummy_data()
# NOTE(review): average_image is never used below; the call is kept in case
# extract_average_color has side effects — confirm before removing.
average_image = vizualizator.extract_average_color(data)
r, g, b = vizualizator.extract_rgb(data)

fig = plt.figure()
fig.suptitle("Weights of first hidden layer", fontsize=16)

# One (values, title) pair per subplot, drawn left to right.
channel_panels = (
    (r, 'Red channel'),
    (g, 'Green channel'),
    (b, 'Blue channel'),
)
for panel_index, (channel, panel_title) in enumerate(channel_panels, start=1):
    plt.subplot(1, 3, panel_index)
    hinton(channel)
    plt.title(panel_title)

plt.show()