Skip to content

Instantly share code, notes, and snippets.

Last active February 24, 2021 13:48
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 2 You must be signed in to fork a gist
  • Save Coderx7/205651853a248a512256aa21f1d3bec0 to your computer and use it in GitHub Desktop.
Save Coderx7/205651853a248a512256aa21f1d3bec0 to your computer and use it in GitHub Desktop.
Confusion Matrix with Recall, Precision and F1-Score for Caffe
# Author: SeyyedHossein Hasanpour copyright 2017, license GPLv3.
# Seyyed Hossein Hasan Pour:
# Changelog:
# 2015:
# initial code to calculate the confusion matrix, by Axel Angel
# 7/3/2016 (new features by Hossein):
# added mean subtraction so that the accuracy is reported exactly as Caffe reports it when mean subtraction is used
# 01/03/2017:
# removed old code and added Recall/Precision/F1-Score as well
# 03/05/2017:
# added the confusion matrix, which had mistakenly been omitted before.
# On Windows, you can put the following commands in a batch file for convenience
# (replace <this_script>.py with the name of this script):
#REM Calculating Confusion Matrix
#python <this_script>.py --proto cifar10_deploy_94_68.prototxt --model cifar10_deploy_94_68.caffemodel --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb
import sys
import caffe
import numpy as np
import lmdb
import argparse
from collections import defaultdict
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import itertools
def flat_shape(x):
    """Return ``x`` with every singleton (size-1) dimension removed.

    Examples: (1, 28, 28) -> (28, 28); (3, 1, 32, 32) -> (3, 32, 32).

    The new shape is built as a concrete tuple: on Python 3 ``filter`` returns
    a lazy iterator that ``ndarray.reshape`` cannot interpret (TypeError).
    Only dimensions of size exactly 1 are dropped, so zero-length dimensions
    survive (dropping them, as ``s > 1`` did, makes the reshape of an empty
    array fail).
    """
    return x.reshape(tuple(s for s in x.shape if s != 1))
def db_reader(fpath, type='lmdb'):
    """Return an iterator over the ``(key, image, label)`` records of a dataset.

    :param fpath: path to the LMDB or LevelDB database.
    :param type: backend name, ``'lmdb'`` (default) or ``'leveldb'``,
        compared case-insensitively. The name shadows the ``type`` builtin
        but is kept so existing keyword callers keep working.
    :raises ValueError: if ``type`` is not a supported backend. Previously any
        unrecognized value silently fell through to LevelDB, so a typo such
        as ``'lmbd'`` surfaced as a confusing LevelDB error.
    """
    backend = str(type).lower()
    if backend == 'lmdb':
        return lmdb_reader(fpath)
    if backend == 'leveldb':
        return leveldb_reader(fpath)
    raise ValueError(
        "unsupported db type %r: expected 'lmdb' or 'leveldb'" % (type,))
def plot_confusion_matrix(cm,  # confusion matrix
                          classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=None):
    """Print (optionally) and plot a confusion matrix into the current figure.

    The caller creates the figure (``plt.figure()``) and displays it
    (``plt.show()``).

    :param cm: square 2-D array; in this script it comes from
        ``sklearn.metrics.confusion_matrix``, where ``cm[i, j]`` counts samples
        of true class ``i`` predicted as class ``j``.
    :param classes: sequence of class names, one per row/column of ``cm``.
    :param normalize: when True, divide each row by its sum so each cell
        shows the fraction of that true class; rows with no samples stay 0.
    :param title: title of the plot.
    :param cmap: matplotlib colormap; ``None`` means ``plt.cm.Blues``.
        NOTE(review): Blues is assumed as the original default; the
        truncated signature did not show it.
    """
    if cmap is None:
        # Resolved here rather than in the signature so defining the function
        # does not touch matplotlib.
        cmap = plt.cm.Blues
    cm = np.asarray(cm)

    if normalize:
        # Normalize BEFORE drawing. The previous version called imshow first,
        # so the colours showed raw counts while the text showed fractions.
        row_sums = cm.sum(axis=1, keepdims=True).astype(float)
        # Avoid 0/0 -> NaN for classes that have no samples.
        cm = np.divide(cm, row_sums,
                       out=np.zeros(cm.shape, dtype=float),
                       where=row_sums != 0)
        print("confusion matrix is normalized!")

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Integer counts print as-is; fractions get two decimals.
    fmt = 'd' if np.issubdtype(cm.dtype, np.integer) else '.2f'
    # Use white text on dark cells and black text on light cells.
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
def lmdb_reader(fpath):
    """Yield ``(key, image, label)`` for every Caffe ``Datum`` stored in an LMDB.

    ``image`` is the decoded datum array with singleton dimensions removed
    (via ``flat_shape``); ``label`` is the datum's label as an ``int``.
    The database is opened read-only, so a mistyped path raises instead of
    silently creating a new, empty database.
    """
    import lmdb
    lmdb_env = lmdb.open(fpath, readonly=True)
    try:
        with lmdb_env.begin() as lmdb_txn:
            for key, value in lmdb_txn.cursor():
                datum = caffe.proto.caffe_pb2.Datum()
                # Without parsing, label/image would be read from an empty Datum.
                datum.ParseFromString(value)
                label = int(datum.label)
                image = caffe.io.datum_to_array(datum)
                yield (key, flat_shape(image), label)
    finally:
        # Release the environment when iteration ends or the generator is closed.
        lmdb_env.close()
def leveldb_reader(fpath):
    """Yield ``(key, image, label)`` for every Caffe ``Datum`` stored in a LevelDB.

    ``image`` is the decoded datum array with singleton dimensions removed
    (via ``flat_shape``); ``label`` is the datum's label as an ``int``.
    """
    import leveldb
    # py-leveldb creates a missing database by default; a mistyped path would
    # then silently yield nothing, so require the database to exist.
    db = leveldb.LevelDB(fpath, create_if_missing=False)
    for key, value in db.RangeIter():
        datum = caffe.proto.caffe_pb2.Datum()
        # Without parsing, label/image would be read from an empty Datum.
        datum.ParseFromString(value)
        label = int(datum.label)
        image = caffe.io.datum_to_array(datum)
        yield (key, flat_shape(image), label)
if __name__ == "__main__":
    # Evaluate a trained Caffe classifier on an LMDB/LevelDB dataset: print
    # the per-class precision/recall/F1 report and the confusion matrix, then
    # plot the raw and row-normalized confusion matrices.
    parser = argparse.ArgumentParser()
    parser.add_argument('--proto', help='path to the network prototxt file(deploy)', type=str, required=True)
    parser.add_argument('--model', help='path to your caffemodel file', type=str, required=True)
    parser.add_argument('--mean', help='path to the mean file(.binaryproto)', type=str, required=True)
    parser.add_argument('--db_type', help='lmdb or leveldb', type=str, required=True)
    parser.add_argument('--db_path', help='path to your lmdb/leveldb dataset', type=str, required=True)
    # Optional: name of the softmax output blob. A deploy prototxt whose
    # Softmax layer uses `top: "loss"` needs `--output_blob loss`; otherwise
    # the lookup fails with KeyError: 'prob'.
    parser.add_argument('--output_blob', help='name of the network output (softmax) blob', type=str, default='prob')
    parser.add_argument('--cpu', help='run Caffe on the CPU (use if Caffe was built without GPU support)', action='store_true')
    args = parser.parse_args()

    # Extract mean from the mean image file
    mean_blobproto_new = caffe.proto.caffe_pb2.BlobProto()
    with open(args.mean, 'rb') as f:
        mean_blobproto_new.ParseFromString(f.read())
    mean_image = caffe.io.blobproto_to_array(mean_blobproto_new)

    # Pick the device before building the net. GPU stays the default,
    # assumed from the original note suggesting set_mode_cpu() as the
    # alternative.
    if args.cpu:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
    # CNN reconstruction and loading the trained weights
    net = caffe.Net(args.proto, args.model, caffe.TEST)
    print("args", vars(args))
    reader = db_reader(args.db_path, args.db_type.lower())

    class_names = ['airplane','automobile','bird','cat','deer','dog','frog','horse','ship','truck']
    # Fix the label set so the report and matrix always have one row/column
    # per class name, even if a class never occurs in the data or predictions.
    label_ids = list(range(len(class_names)))

    predicted_labels = []
    true_labels = []
    for i, image, label in reader:
        image_caffe = image.reshape(1, *image.shape)
        out = net.forward_all(data=np.asarray([image_caffe]) - mean_image)
        plabel = int(out[args.output_blob][0].argmax(axis=0))
        predicted_labels.append(plabel)
        true_labels.append(label)
        print(i, ' processed!')

    if not true_labels:
        sys.exit("no records were read from %s" % args.db_path)

    print(classification_report(y_true=true_labels, y_pred=predicted_labels,
                                labels=label_ids, target_names=class_names))
    cm = confusion_matrix(y_true=true_labels, y_pred=predicted_labels,
                          labels=label_ids)
    print(cm)

    # Plot non-normalized confusion matrix (each plot on its own figure)
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names,
                          title='Confusion matrix, without normalization')
    # Plot normalized confusion matrix
    plt.figure()
    plot_confusion_matrix(cm, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')
    plt.show()
Copy link

flrndttrch commented May 19, 2017

It seems that your command instructions are not correct.

python --proto cifar10_deploy.prototxt --model cifar10_SimpleNet_xavier_95.26.caffemodel.h5 --mean mean.binaryproto --lmdb cifar10_test_lmdb

should be:

python --proto cifar10_deploy.prototxt --model cifar10_SimpleNet_xavier_95.26.caffemodel.h5 --mean mean.binaryproto --db_type lmdb --db_path cifar10_test_lmdb

Also I'm getting a KeyError for 'prob' in out:

Traceback (most recent call last):
  File "", line 137, in <module>
    plabel = int(out['prob'][0].argmax(axis=0))
KeyError: 'prob'

Maybe it has to be 'loss' and argmax()?

EDIT: Yes that seems to work fine
EDIT2: OK, got it. The key depends on the `top` name of the Softmax layer in the deploy prototxt, e.g.:
layer { name: "loss" type: "Softmax" bottom: "ip2" top: "loss" }

Copy link

Coderx7 commented May 30, 2017

@flrndttrch: Thanks for the note; when I updated the script I forgot to edit that part.
However, I see I had already added the needed information as a comment in the script.
Anyway, I have updated the first post now.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment