Skip to content

Instantly share code, notes, and snippets.

@dspmeng
Last active January 8, 2018 05:26
Show Gist options
  • Save dspmeng/c3b48e5d722a68b376b497c5c31a3226 to your computer and use it in GitHub Desktop.
Save dspmeng/c3b48e5d722a68b376b497c5c31a3226 to your computer and use it in GitHub Desktop.
Caffe image-classification sample code
import numpy as np
import matplotlib.pyplot as plt
import argparse, os, sys
# Root of the Caffe checkout whose pycaffe bindings should be imported.
# May be overridden with the CAFFE_ROOT environment variable instead of
# editing this file; os.path.join keeps the path correct whether or not the
# root ends with a slash (plain concatenation would not).
caffe_root = os.environ.get('CAFFE_ROOT', '/path/to/caffe/')
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe
def predict(args):
    """Classify one image with a Caffe model and display the result.

    Loads the network ``args.prototxt`` with weights ``args.weights`` on
    GPU 0 in TEST phase, preprocesses ``args.image`` into the layout expected
    by the ``data`` input blob, runs a forward pass, and prints the index of
    the highest-scoring entry of the ``prob`` output. Finally it shows the
    original image in a matplotlib window titled with that result.

    Args:
        args: object (e.g. an ``argparse.Namespace``) with ``prototxt``,
            ``weights`` and ``image`` path attributes.
    """
    caffe.set_device(0)
    caffe.set_mode_gpu()
    net = caffe.Net(args.prototxt, args.weights, caffe.TEST)

    data_shape = net.blobs['data'].data.shape
    # Spatial (height, width) the network expects. Resizing to this instead of
    # a hard-coded size keeps the script usable with any input resolution.
    input_hw = data_shape[2:]

    # Convert the image returned by caffe.io.load_image (RGB, values in
    # [0, 1], height x width x channels) to Caffe's OpenCV-style layout:
    # planar, BGR, [0, 255], mean-subtracted.
    transformer = caffe.io.Transformer({'data': data_shape})
    transformer.set_transpose('data', (2, 0, 1))  # interleaved -> planar
    # Per-channel mean, given in BGR order because the Transformer subtracts
    # it after the channel swap and raw scaling.
    transformer.set_mean('data', np.array([104, 117, 123]))
    transformer.set_raw_scale('data', 255)  # [0, 1] -> [0, 255]
    transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR

    image = caffe.io.load_image(args.image)
    resized_image = caffe.io.resize_image(image, input_hw)
    transformed_image = transformer.preprocess('data', resized_image)
    # If the deploy batch size is > 1, the image is broadcast into every
    # batch slot; only the first row of the output is read below.
    net.blobs['data'].data[...] = transformed_image

    probs = net.forward()['prob'][0]
    result = '%s Class: %d' % (args.image, probs.argmax())
    print(result)  # function-call form works on both Python 2 and 3

    plt.title(result)
    plt.imshow(image)
    plt.show()
if __name__ == '__main__':
    # Command-line entry point: three positional paths, then classify.
    parser = argparse.ArgumentParser()
    for arg_name, arg_help in (
        ('prototxt', 'model deploy prototxt'),
        ('weights', 'model weights'),
        ('image', 'image to predict'),
    ):
        parser.add_argument(arg_name, help=arg_help)
    args = parser.parse_args()
    predict(args)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment