Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
classify images using caffe's bvlc_reference_caffenet model
#!/usr/bin/python
# modified from https://nbviewer.jupyter.org/github/BVLC/caffe/blob/master/examples/00-classification.ipynb
import numpy as np
import argparse, sys, caffe, os, pprint
def build_parser():
    """Return the command-line parser for this script.

    The Caffe root defaults to the CAFFE_ROOT environment variable when it is
    set.  It is read with ``os.environ.get`` so that an unset variable does
    not raise ``KeyError`` before ``-c/--caffe`` can even be parsed.
    """
    parser = argparse.ArgumentParser(
        description='classify images using bvlc_reference_caffenet',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-c', '--caffe', type=str,
                        default=os.environ.get('CAFFE_ROOT'),
                        help='root directory of caffe')
    parser.add_argument('-v', '--verbose', type=int, default=1,
                        help='verbosity')
    parser.add_argument('-t', '--top', type=int, default=5,
                        help='top-most how many guesses')
    parser.add_argument('-w', '--width', type=int, default=100,
                        help='line width of pretty print')
    parser.add_argument('-m', '--mean', type=str,
                        default='python/caffe/imagenet/ilsvrc_2012_mean.npy',
                        help='mean image npy file')
    parser.add_argument('image_files', nargs='*')
    return parser


def resolve_mean_path(caffe_root, mean):
    """Return the path of the mean-image file, or None when *mean* is empty.

    A relative *mean* is taken relative to *caffe_root*; an absolute one is
    returned unchanged (``os.path.join`` discards earlier components when a
    later one is absolute).
    """
    if not mean:
        return None
    return os.path.join(caffe_root, mean)


def top_indices(prob, k):
    """Return the indices of the *k* largest entries of *prob*, best first."""
    return np.argsort(prob)[::-1][:k]


def load_network(caffe_root):
    """Load the bvlc_reference_caffenet model found under *caffe_root*."""
    model_dir = os.path.join(caffe_root, 'models', 'bvlc_reference_caffenet')
    model_def = os.path.join(model_dir, 'deploy.prototxt')
    model_weights = os.path.join(model_dir,
                                 'bvlc_reference_caffenet.caffemodel')
    return caffe.Net(model_def, model_weights, caffe.TEST)


def build_transformer(net, mean_file):
    """Return a caffe.io.Transformer configured for the net's 'data' blob.

    *mean_file* may be None, in which case no mean is subtracted.
    """
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))
    if mean_file:
        # Collapse the mean image to one value per channel.
        transformer.set_mean('data', np.load(mean_file).mean(1).mean(1))
    transformer.set_raw_scale('data', 255)
    transformer.set_channel_swap('data', (2, 1, 0))
    return transformer


def main(argv=None):
    """Classify the image files named on the command line and print guesses.

    *argv* is the argument list to parse; None means ``sys.argv[1:]``.
    """
    parser = build_parser()
    args = parser.parse_args(argv)
    if not args.caffe:
        parser.error('caffe root not given: pass -c/--caffe or set CAFFE_ROOT')
    if not args.image_files:
        # Nothing to classify: skip loading the network but keep the single
        # blank line the script has always printed.
        print('')
        return

    caffe.set_mode_cpu()
    net = load_network(args.caffe)
    # The transformer is built before the blob is reshaped, as before; it
    # receives the data blob's shape as loaded from the model definition.
    transformer = build_transformer(
        net, resolve_mean_path(args.caffe, args.mean))

    images = [caffe.io.load_image(img_f) for img_f in args.image_files]
    transformed_images = [transformer.preprocess('data', img)
                          for img in images]

    labels_file = os.path.join(args.caffe, 'data', 'ilsvrc12',
                               'synset_words.txt')
    labels = np.loadtxt(labels_file, str, delimiter='\t')

    # Size the batch to the number of images instead of a fixed 10, so that
    # more than ten images no longer index past the end of the data blob.
    net.blobs['data'].reshape(len(images), 3, 227, 227)
    for i, t_img in enumerate(transformed_images):
        net.blobs['data'].data[i, ...] = t_img
    output = net.forward()

    print('')
    pp = pprint.PrettyPrinter(indent=4, width=args.width)
    rows = zip(args.image_files, images, transformed_images, output['prob'])
    for fn, img, t_img, prob in rows:
        top = top_indices(prob, args.top)
        if args.verbose > 0:
            print(fn)
        if args.verbose > 1:
            print('# {} => {}'.format(img.shape, t_img.shape))
            # np.histogram computes over the flattened array, so no manual
            # flattening is needed.
            print('# hist of orig {}'.format(np.histogram(img, bins=5)))
            print('# hist of trans {}'.format(np.histogram(t_img, bins=5)))
        # list() keeps the output a list of pairs on Python 3 as well.
        pp.pprint(list(zip(labels[top], prob[top])))
        print('')


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment