Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sexypreneur/3a2a4fcc03081a2095a159a9fdf96836 to your computer and use it in GitHub Desktop.
Save sexypreneur/3a2a4fcc03081a2095a159a9fdf96836 to your computer and use it in GitHub Desktop.
Caffe script for computing the confusion matrix, precision, recall and F1 score!
#in the name of God the most compassionate the most merciful
# Added mean subtraction so that the accuracy is reported exactly as Caffe reports it when mean subtraction is used
# Seyyed Hossein Hasan Pour
# Coderx7@Gmail.com
# 7/3/2016
# Added Recall/Precision/F1-Score as well
# 01/03/2017
# Added batch processing: what used to take a minute or so now takes only several seconds!
# 07/25/2017
#info:
#On Windows, you can put these commands in a batch file for convenience:
#REM Calculating Confusion Matrix
#python confusionMatrix_convnet_test.py --proto cifar10_deploy.prototxt --model cifar10.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
#pause
import sys
import caffe
import numpy as np
import lmdb
import argparse
from collections import defaultdict
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
from sklearn.metrics import roc_curve, auc
import random
def flat_shape(x):
    """Return ``x`` with its singleton (size-1) dimensions removed.

    E.g. an array of shape (1, 28, 28) becomes (28, 28); an array with no
    size-1 axes, such as (3, 32, 32), keeps its shape unchanged.
    """
    return np.squeeze(x)
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """Draw a confusion matrix into the current matplotlib figure.

    Rows are true labels and columns are predicted labels; every cell is
    annotated with its value.

    Args:
        cm: 2-D array-like of counts (e.g. the output of
            ``sklearn.metrics.confusion_matrix``).
        classes: tick labels for both axes, in row/column order.
        normalize: if True, each row is divided by its sum so every cell shows
            the fraction of its true class; rows that sum to 0 are drawn as 0
            instead of producing NaN.
        title: figure title.
        cmap: matplotlib colormap used for the cell colours.
    """
    cm = np.asarray(cm)
    if normalize:
        # Normalize before drawing so the colours and the cell annotations
        # are both derived from the same matrix.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype='float'),
                       where=row_sums != 0)
        print("confusion matrix is normalized!")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # White text on the darker half of the colour range keeps labels readable.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        value = cm[i, j]
        plt.text(j, i, format(value, '.2f') if normalize else value,
                 horizontalalignment="center",
                 color="white" if value > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def db_reader(fpath, type='lmdb'):
    """Return a generator of ``(key, image, label)`` records from a Caffe DB.

    Args:
        fpath: path to the LMDB or LevelDB database.
        type: backend name. ``'lmdb'`` (matched case-insensitively, like the
            ``--db_type`` option of the script entry point) selects LMDB;
            any other value falls back to LevelDB.

    Returns:
        The generator produced by ``lmdb_reader`` or ``leveldb_reader``.
    """
    # ``type`` shadows the builtin, but it is part of the public signature
    # (callers may pass it by keyword), so the name is kept.
    if str(type).lower() == 'lmdb':
        return lmdb_reader(fpath)
    return leveldb_reader(fpath)
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every record in a Caffe LMDB.

    Each value is parsed as a ``caffe_pb2.Datum``; ``image`` is the datum's
    array cast to ``uint8`` and passed through ``flat_shape``, and ``label``
    is the datum's integer label.

    The database is only read, so it is opened with ``readonly=True``; the
    environment is closed when iteration finishes or the generator is closed.
    """
    import lmdb
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum).astype(np.uint8)
                yield (key, flat_shape(image), label)
    finally:
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every record in a Caffe LevelDB.

    Each value is parsed as a ``caffe_pb2.Datum``; ``image`` is the datum's
    array cast to ``uint8`` and passed through ``flat_shape``, and ``label``
    is the datum's integer label.

    ``create_if_missing=False`` makes a wrong path fail loudly instead of
    creating an empty database that yields no records.
    """
    import leveldb
    db = leveldb.LevelDB(fpath, create_if_missing=False)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum).astype(np.uint8)
        yield (key, flat_shape(image), label)
def ShowInfo(correct, count, true_labels, predicted_lables, class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC',
             title_CM='Confusion matrix, without normalization',
             title_CM_N='Normalized confusion matrix'):
    """Report evaluation results: accuracy, per-class metrics and confusion matrices.

    Prints the accuracy and sklearn's classification report (per-class
    precision, recall and F1), prints the confusion matrix, writes the keys of
    the misclassified samples to ``filename`` (one per line), and shows the
    raw and row-normalized confusion matrices with matplotlib.

    Args:
        correct: number of correctly classified samples.
        count: number of evaluated samples (denominator of the accuracy);
            0 reports an accuracy of 0 instead of raising.
        true_labels: ground-truth label of each evaluated sample.
        predicted_lables: predicted label of each evaluated sample.
        class_names: display name of each class. Labels are assumed to be the
            integer indices 0..len(class_names)-1, in the same order.
        misclassified: database keys of the misclassified samples.
        filename: output file for the misclassified keys.
        title: currently unused (no ROC curve is plotted); kept so callers
            passing it keep working.
        title_CM: title of the raw-count confusion matrix figure.
        title_CM_N: title of the normalized confusion matrix figure.
    """
    accuracy = 100. * correct / count if count else 0.
    sys.stdout.write("\rAccuracy: %.1f%%" % accuracy)
    sys.stdout.flush()
    print(", %i/%i corrects" % (correct, count))

    # DB keys come back as bytes on Python 3; decode them so the file holds
    # plain key text rather than "b'...'" reprs.
    keys = [k.decode('utf-8', 'replace') if isinstance(k, bytes) else k
            for k in misclassified]
    np.savetxt(filename, keys, fmt="%s")

    # Pin the label set to every class index so the report and the matrix
    # always have one row per class name, even if some class never occurs in
    # either label list (otherwise target_names no longer matches the number
    # of classes found and the matrix no longer matches the tick labels).
    labels = list(range(len(class_names)))
    print(classification_report(y_true=true_labels,
                                y_pred=predicted_lables,
                                labels=labels,
                                target_names=class_names))
    cm = confusion_matrix(y_true=true_labels,
                          y_pred=predicted_lables,
                          labels=labels)
    print(cm)

    np.set_printoptions(precision=2)
    # Plot non-normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, title=title_CM)
    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title=title_CM_N)
    plt.show()
def _load_mean(mean_path):
    """Return the mean image stored in a Caffe ``.binaryproto`` file as an array."""
    mean_blobproto = caffe.proto.caffe_pb2.BlobProto()
    with open(mean_path, 'rb') as f:
        mean_blobproto.ParseFromString(f.read())
    return caffe.io.blobproto_to_array(mean_blobproto)


def _iter_batches(records, batch_size):
    """Yield consecutive lists of up to ``batch_size`` records; the last may be shorter."""
    batch = []
    for record in records:
        batch.append(record)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch


def main():
    """Evaluate a Caffe classifier on an LMDB/LevelDB dataset and report metrics.

    Every record is mean-subtracted and classified in batches through the
    network's ``data`` input and ``prob`` output blobs. The predictions are
    compared with the stored labels, and the results are handed to
    ``ShowInfo``, which reports accuracy, per-class precision/recall/F1 and
    confusion matrices and saves the misclassified keys.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    parser.add_argument('--batch_size', help='number of images per forward pass', type=int, default=50)
    parser.add_argument('--class_names', help='class names, in label-index order', nargs='+',
                        default=['airplane', 'automobile', 'bird', 'cat', 'deer',
                                 'dog', 'frog', 'horse', 'ship', 'truck'])
    args = parser.parse_args()
    if args.batch_size < 1:
        parser.error('--batch_size must be at least 1')

    # mean_image[0] is subtracted from every image, matching Caffe's own
    # mean subtraction.
    mean_image = _load_mean(args.mean)
    caffe.set_mode_gpu()
    net1 = caffe.Net(args.proto, args.model, caffe.TEST)

    predicted_lables = []
    true_labels = []
    misclassified = []
    count = 0
    correct = 0

    # Anything other than "lmdb" (case-insensitive) is read as LevelDB.
    db_type = 'lmdb' if args.db_type.lower() == 'lmdb' else 'leveldb'
    records = db_reader(args.db_path, db_type)
    for batch in _iter_batches(records, args.batch_size):
        # Mean subtraction by hand: caffe.io.Transformer.preprocess is much
        # slower for this loop.
        data = np.stack([image - mean_image[0] for _, image, _ in batch])
        # Size the input blob to this batch (the last batch may be smaller),
        # so no record is dropped from the evaluation.
        net1.blobs['data'].reshape(*data.shape)
        net1.blobs['data'].data[...] = data
        out_1 = net1.forward()
        plbl = np.asarray(out_1['prob']).argmax(axis=1)
        for (key, _, label), pred in zip(batch, plbl):
            count += 1
            if count % 2000 == 0:
                print('count: ', count)
            if pred == label:
                correct += 1
            else:
                misclassified.append(key)
            predicted_lables.append(pred)
            true_labels.append(label)

    if count == 0:
        sys.exit('No records found in %s' % args.db_path)

    ShowInfo(correct, count, true_labels, predicted_lables, args.class_names, misclassified,
             filename='misclassifieds.txt',
             title='Receiver Operating Characteristic_ROC')


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment