forked from WuJie1010/Facial-Expression-Recognition.Pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot_fer2013_confusion_matrix.py
124 lines (99 loc) · 3.81 KB
/
plot_fer2013_confusion_matrix.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
plot confusion_matrix of PublicTest and PrivateTest
"""
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import argparse
from fer import FER2013
from torch.autograd import Variable
import torchvision
import transforms as transforms
from sklearn.metrics import confusion_matrix
from models import *
# Command-line options. --dataset and --model together name the checkpoint
# directory "<dataset>_<model>/"; --split selects both the data split and the
# "<split>_model.t7" checkpoint inside that directory.
parser = argparse.ArgumentParser(description='Plot the FER2013 confusion matrix of a trained CNN')
parser.add_argument('--model', type=str, default='VGG19', choices=['VGG19', 'Resnet18'],
                    help='CNN architecture')
parser.add_argument('--dataset', type=str, default='FER2013',
                    help='dataset name, used to locate the checkpoint directory <dataset>_<model>/')
parser.add_argument('--split', type=str, default='PrivateTest',
                    help='evaluation split (e.g. PublicTest or PrivateTest); also selects <split>_model.t7')
opt = parser.parse_args()

# Test-time augmentation: every image is cut into several cut_size crops that
# are stacked into one (ncrops, C, H, W) tensor; the evaluation loop averages
# the network's outputs over those crops.
cut_size = 44
transform_test = transforms.Compose([
    transforms.TenCrop(cut_size),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    Print and plot a confusion matrix on the current matplotlib figure.

    Args:
        cm: square array whose rows are true labels and whose columns are
            predicted labels.
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum so every cell shows the
            fraction of that true class; a row with no samples is left at 0
            instead of becoming NaN.
        title: figure title.
        cmap: matplotlib colormap passed to ``imshow``.
    """
    if normalize:
        cm = cm.astype('float')
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class absent from the evaluated split would otherwise give 0/0 = NaN.
        cm = np.divide(cm, row_sums, out=np.zeros_like(cm), where=row_sums != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, fontsize=16)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # 'd' only works for integer cells; any floating matrix is printed as '.2f'.
    fmt = '.2f' if normalize or np.issubdtype(cm.dtype, np.floating) else 'd'
    # Cells darker than half the maximum get white text so it stays readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label', fontsize=18)
    plt.xlabel('Predicted label', fontsize=18)
    plt.tight_layout()
# Tick labels for the plot; position i is assumed to be class index i.
class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

# Model
if opt.model == 'VGG19':
    net = VGG('VGG19')
elif opt.model == 'Resnet18':
    net = ResNet18()
else:
    # Without this branch an unknown --model surfaces later as a NameError on `net`.
    raise ValueError('Unsupported --model %r (expected VGG19 or Resnet18)' % opt.model)

# Checkpoint directory is "<dataset>_<model>"; plots are saved there as well.
path = opt.dataset + '_' + opt.model
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(os.path.join(path, opt.split + '_model.t7'))
net.load_state_dict(checkpoint['net'])
net.cuda()
net.eval()
Testset = FER2013(split=opt.split, transform=transform_test)
Testloader = torch.utils.data.DataLoader(Testset, batch_size=128, shuffle=False, num_workers=1)

correct = 0
total = 0
predicted_batches = []
target_batches = []
# Inference only: torch.no_grad() replaces the removed `volatile=True` flag, so no
# autograd graph is kept for the large multi-crop batches.
with torch.no_grad():
    for inputs, targets in Testloader:
        # Each sample is a stack of crops; fold crops into the batch dim for the forward pass.
        bs, ncrops, c, h, w = inputs.shape
        inputs = inputs.view(-1, c, h, w).cuda()
        targets = targets.cuda()
        outputs = net(inputs)
        outputs_avg = outputs.view(bs, ncrops, -1).mean(1)  # avg over crops
        _, predicted = torch.max(outputs_avg, 1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        predicted_batches.append(predicted)
        target_batches.append(targets)

if total == 0:
    raise RuntimeError('split %r produced no test samples' % opt.split)

# Concatenate once at the end instead of re-copying a growing tensor every batch.
all_predicted = torch.cat(predicted_batches, 0)
all_targets = torch.cat(target_batches, 0)
acc = 100. * correct / total
print("accuracy: %0.3f" % acc)
# Compute confusion matrix. Passing `labels` fixes its size to
# len(class_names) x len(class_names) even when a class is never seen or
# predicted in this split, keeping rows/columns aligned with the tick labels.
matrix = confusion_matrix(all_targets.cpu().numpy(), all_predicted.cpu().numpy(),
                          labels=list(range(len(class_names))))
np.set_printoptions(precision=2)

# Plot normalized confusion matrix and save it next to the checkpoint.
plt.figure(figsize=(10, 8))
plot_confusion_matrix(matrix, classes=class_names, normalize=True,
                      title=opt.split + ' Confusion Matrix (Accuracy: %0.3f%%)' % acc)
plt.savefig(os.path.join(path, opt.split + '_cm.png'))
plt.close()