Skip to content

Instantly share code, notes, and snippets.

@stardust2602
Created July 14, 2015 03:08
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save stardust2602/79c818add4f7100397dd to your computer and use it in GitHub Desktop.
Save stardust2602/79c818add4f7100397dd to your computer and use it in GitHub Desktop.
Caffe script to compute accuracy and confusion matrix based on training input ( .txt file )
#!/usr/bin/python
# -*- coding: utf-8 -*-
import sys
import caffe
import numpy as np
import argparse
from collections import defaultdict
TRAIN_DATA_ROOT='/path/to/training/images/'
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--proto', type=str, required=True)
parser.add_argument('--model', type=str, required=True)
parser.add_argument('--meanfile', type=str, required=True)
parser.add_argument('--labelfile', type=str, required=True)
args = parser.parse_args()
proto_data = open(args.meanfile, 'rb').read()
a = caffe.io.caffe_pb2.BlobProto.FromString(proto_data)
mean = caffe.io.blobproto_to_array(a)[0]
net = caffe.Classifier(args.proto, args.model,
mean=mean,
channel_swap=(2,1,0),
raw_scale=255,
image_dims=(256, 256))
caffe.set_mode_gpu()
count = 0
correct = 0
matrix = defaultdict(int) # (real,pred) -> int
labels_set = set()
net = caffe.Net(args.proto, args.model, caffe.TEST)
f = open(args.labelfile, "r")
for line in f.readlines():
parts = line.split()
example_image = parts[0]
label = int(parts[1])
input_image = caffe.io.load_image(TRAIN_DATA_ROOT + example_image)
prediction = net.predict([input_image])
plabel = int(prediction[0].argmax())
count += 1
iscorrect = label == plabel
correct += (1 if iscorrect else 0)
matrix[(label, plabel)] += 1
labels_set.update([label, plabel])
if not iscorrect:
print("\rError: expected %i but predicted %i" \
% (label, plabel))
sys.stdout.write("\rAccuracy: %.1f%%" % (100.*correct/count))
sys.stdout.flush()
print(", %i/%i corrects" % (correct, count))
print ""
print "Confusion matrix:"
print "(r , p) | count"
for l in labels_set:
for pl in labels_set:
print "(%i , %i) | %i" % (l, pl, matrix[(l,pl)])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment