Last active
January 8, 2018 05:26
-
-
Save dspmeng/c3b48e5d722a68b376b497c5c31a3226 to your computer and use it in GitHub Desktop.
caffe classification sample code
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import argparse
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

# Root of the local Caffe checkout; edit this to match your installation.
caffe_root = '/path/to/caffe/'
# Make the pycaffe bindings importable. os.path.join is used so the result is
# a valid path whether or not caffe_root ends with a path separator.
sys.path.insert(0, os.path.join(caffe_root, 'python'))
import caffe
def predict(args):
    """Classify one image with a Caffe model and display the result.

    Loads the network described by ``args.prototxt`` with the weights in
    ``args.weights`` in TEST phase on GPU device 0, runs a single forward
    pass on ``args.image``, prints "<image> Class: <index>" and shows the
    original image in a matplotlib window titled with the same text.

    Args:
        args: object with string attributes ``prototxt``, ``weights`` and
            ``image`` (e.g. the namespace returned by argparse).
    """
    caffe.set_device(0)
    caffe.set_mode_gpu()
    net = caffe.Net(args.prototxt, args.weights, caffe.TEST)

    # Transform an image loaded by caffe.io.load_image to caffe format
    # (ocv style).
    transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
    transformer.set_transpose('data', (2, 0, 1))     # interleave to planar
    # Per-channel mean to subtract; presumably BGR-order dataset means --
    # confirm against the model's training setup.
    transformer.set_mean('data', np.array([104, 117, 123]))
    # [0, 1] to [0, 255]. The original passed 225 here, a typo that
    # contradicted this comment and mis-scaled every input.
    transformer.set_raw_scale('data', 255)
    transformer.set_channel_swap('data', (2, 1, 0))  # RGB to BGR

    image = caffe.io.load_image(args.image)
    # NOTE(review): 448x448 presumably matches the network's input size --
    # confirm against the deploy prototxt.
    resized_image = caffe.io.resize_image(image, (448, 448))
    transformed_image = transformer.preprocess('data', resized_image)
    net.blobs['data'].data[...] = transformed_image

    # 'prob' holds the output for the batch; take the first (only) item.
    probs = net.forward()['prob'][0]
    result = '%s Class: %d' % (args.image, probs.argmax())
    # Parenthesised form prints identically on Python 2 and also parses on
    # Python 3.
    print(result)

    plt.title(result)
    plt.imshow(image)
    plt.show()
if __name__ == '__main__':
    # Command-line entry point: three required positional arguments, parsed
    # into a namespace and handed straight to predict().
    cli = argparse.ArgumentParser()
    for arg_name, arg_help in (
            ('prototxt', 'model deploy prototxt'),
            ('weights', 'model weights'),
            ('image', 'image to predict')):
        cli.add_argument(arg_name, help=arg_help)
    predict(cli.parse_args())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment