1 Star 1 Fork 0

shenhao2954988368 / Facial-Expression-Recognition.Pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
plot_fer2013_confusion_matrix.py 3.81 KB
一键复制 编辑 原始数据 按行查看 历史
WuJie 提交于 2018-07-15 14:45 . Add files via upload
"""
plot confusion_matrix of PublicTest and PrivateTest
"""
import itertools
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import argparse
from fer import FER2013
from torch.autograd import Variable
import torchvision
import transforms as transforms
from sklearn.metrics import confusion_matrix
from models import *
parser = argparse.ArgumentParser(description='PyTorch Fer2013 CNN Training')
parser.add_argument('--model', type=str, default='VGG19', help='CNN architecture')
parser.add_argument('--dataset', type=str, default='FER2013', help='CNN architecture')
parser.add_argument('--split', type=str, default='PrivateTest', help='split')
opt = parser.parse_args()
cut_size = 44
transform_test = transforms.Compose([
transforms.TenCrop(cut_size),
transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])
def plot_confusion_matrix(cm, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
print("Normalized confusion matrix")
else:
print('Confusion matrix, without normalization')
print(cm)
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title, fontsize=16)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label', fontsize=18)
plt.xlabel('Predicted label', fontsize=18)
plt.tight_layout()
class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
# Model
if opt.model == 'VGG19':
net = VGG('VGG19')
elif opt.model == 'Resnet18':
net = ResNet18()
path = os.path.join(opt.dataset + '_' + opt.model)
checkpoint = torch.load(os.path.join(path, opt.split + '_model.t7'))
net.load_state_dict(checkpoint['net'])
net.cuda()
net.eval()
Testset = FER2013(split = opt.split, transform=transform_test)
Testloader = torch.utils.data.DataLoader(Testset, batch_size=128, shuffle=False, num_workers=1)
correct = 0
total = 0
all_target = []
for batch_idx, (inputs, targets) in enumerate(Testloader):
bs, ncrops, c, h, w = np.shape(inputs)
inputs = inputs.view(-1, c, h, w)
inputs, targets = inputs.cuda(), targets.cuda()
inputs, targets = Variable(inputs, volatile=True), Variable(targets)
outputs = net(inputs)
outputs_avg = outputs.view(bs, ncrops, -1).mean(1) # avg over crops
_, predicted = torch.max(outputs_avg.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()
if batch_idx == 0:
all_predicted = predicted
all_targets = targets
else:
all_predicted = torch.cat((all_predicted, predicted),0)
all_targets = torch.cat((all_targets, targets),0)
acc = 100. * correct / total
print("accuracy: %0.3f" % acc)
# Compute confusion matrix
matrix = confusion_matrix(all_targets.data.cpu().numpy(), all_predicted.cpu().numpy())
np.set_printoptions(precision=2)
# Plot normalized confusion matrix
plt.figure(figsize=(10, 8))
plot_confusion_matrix(matrix, classes=class_names, normalize=True,
title= opt.split+' Confusion Matrix (Accuracy: %0.3f%%)' %acc)
plt.savefig(os.path.join(path, opt.split + '_cm.png'))
plt.close()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/shenhao2954988368/Facial-Expression-Recognition.Pytorch.git
git@gitee.com:shenhao2954988368/Facial-Expression-Recognition.Pytorch.git
shenhao2954988368
Facial-Expression-Recognition.Pytorch
Facial-Expression-Recognition.Pytorch
master

搜索帮助

344bd9b3 5694891 D2dac590 5694891