-
Notifications
You must be signed in to change notification settings - Fork 0
/
plotting.py
52 lines (43 loc) · 1.58 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import matplotlib.pyplot as plt
import numpy as np
def plot_images(images, cls_true, class_names, cls_pred=None, smooth=True,
                *, nrows=3, ncols=3):
    """Plot images in a grid, labelling each with its true (and predicted) class.

    At most ``nrows * ncols`` images are drawn, in order. Extra images are
    ignored, and unused grid cells are left empty. Every cell has its x/y
    ticks removed.

    Args:
        images: Sequence of images; each item is passed to ``Axes.imshow``.
        cls_true: Class index for each image, used to index ``class_names``.
            Must have the same length as ``images``.
        class_names: Indexable mapping from class index to display name.
        cls_pred: Optional class index per image. When given, the label
            shows both the true and the predicted class name.
        smooth: If True, draw with 'spline16' interpolation; otherwise
            'nearest'.
        nrows: Number of grid rows (keyword-only, default 3).
        ncols: Number of grid columns (keyword-only, default 3).

    Raises:
        ValueError: If ``images`` and ``cls_true`` differ in length.
    """
    # Explicit check instead of ``assert``: asserts are stripped under -O,
    # which would turn a mismatch into a later IndexError or bad labels.
    if len(images) != len(cls_true):
        raise ValueError(
            "images and cls_true must have the same length "
            "(got {0} and {1})".format(len(images), len(cls_true)))

    # squeeze=False always returns a 2-D array of Axes, so ``axes.flat`` is
    # valid for any grid shape (a 1x1 grid would otherwise be a bare Axes).
    fig, axes = plt.subplots(nrows, ncols, squeeze=False)

    # Two-line labels (true + predicted) need more vertical spacing.
    hspace = 0.3 if cls_pred is None else 0.6
    fig.subplots_adjust(hspace=hspace, wspace=0.3)

    # Interpolation type.
    interpolation = 'spline16' if smooth else 'nearest'

    for i, ax in enumerate(axes.flat):
        # There may be fewer images than grid cells; leave the rest empty.
        if i < len(images):
            ax.imshow(images[i], interpolation=interpolation)

            # Name of the true class.
            cls_true_name = class_names[cls_true[i]]

            # Show true and predicted classes.
            if cls_pred is None:
                xlabel = "True: {0}".format(cls_true_name)
            else:
                # Name of the predicted class.
                cls_pred_name = class_names[cls_pred[i]]
                xlabel = "True: {0}\nPred: {1}".format(cls_true_name, cls_pred_name)

            # Show the classes as the label on the x-axis.
            ax.set_xlabel(xlabel)

        # Remove ticks from every cell, including unused ones.
        ax.set_xticks([])
        ax.set_yticks([])

    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()
def load_images(image_paths):
    """Read each path in *image_paths* with ``plt.imread`` and return them as an array.

    The decoded images, in input order, are passed to ``np.asarray``.

    NOTE(review): images presumably need to share one shape to stack into a
    regular ndarray. Mixed shapes are not handled here; confirm with callers.
    """
    decoded = list(map(plt.imread, image_paths))
    return np.asarray(decoded)