Skip to content

Instantly share code, notes, and snippets.

@dersmon
Created November 11, 2015 13:03
Show Gist options
  • Save dersmon/c4a2605006e48ce99819 to your computer and use it in GitHub Desktop.
Save dersmon/c4a2605006e48ce99819 to your computer and use it in GitHub Desktop.
Using a trained caffe model (buggy).
import sys
import caffe
import cv2
import Image
import numpy as np
from scipy.misc import imresize
# Classify a single image with a trained Caffe model and print the five most
# probable classes together with their labels.

caffe_root = "/home/simon/caffe/"

# Alternative model (Places205 CNN), kept for quick switching:
#MODEL_FILE = caffe_root + 'models/placesCNN/places205CNN_deploy.prototxt'
#PRETRAINED = caffe_root + 'models/placesCNN/places205CNN_iter_300000.caffemodel'
MODEL_FILE = caffe_root + 'models/customAlex/deploy.prototxt'
PRETRAINED = caffe_root + 'models/customAlex/caffe_alexnet_train_iter_1500.caffemodel'

# Select the compute mode before the network is instantiated.
caffe.set_mode_cpu()
net = caffe.Net(MODEL_FILE, PRETRAINED, caffe.TEST)

# caffe.io.load_image yields an RGB float image in [0, 1] laid out H x W x C.
# The Transformer resizes it to the network's input size and converts it to
# the C x H x W layout the 'data' blob expects. This replaces the previous
# scipy imresize + uint8 cast, which contrast-stretched the pixel values.
# NOTE(review): assumes the model was trained on BGR data in [0, 255], as
# produced by Caffe's standard dataset tooling -- confirm against training.
transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
transformer.set_transpose('data', (2, 0, 1))     # H x W x C -> C x H x W
transformer.set_raw_scale('data', 255)           # [0, 1] -> [0, 255]
transformer.set_channel_swap('data', (2, 1, 0))  # RGB -> BGR
# NOTE(review): if training subtracted a mean image/pixel, the same mean must
# be applied here via transformer.set_mean('data', mean). The mean is not
# available in this script -- confirm and add it.

img = caffe.io.load_image(caffe_root + 'examples/images/library.jpg')
data = transformer.preprocess('data', img)

# forward_all takes a batch, so wrap the single image in a leading axis.
out = net.forward_all(data=np.asarray([data]))
probs = out['prob'][0]
print("Predicted class is #{}.".format(probs.argmax()))

#imagenet_labels_filename = caffe_root + 'models/placesCNN/categoryIndex_places205.csv'
imagenet_labels_filename = caffe_root + 'models/customAlex/indexLabelMapping.txt'
# One label per line. The former np.loadtxt(..., delimiter='\s') did not split
# on whitespace ('\s' is a literal two-character string, not a regex) and is
# rejected by newer NumPy, so the lines are read directly instead.
with open(imagenet_labels_filename) as label_file:
    labels = np.array([line.strip() for line in label_file if line.strip()])

# Indices of the five highest probabilities, best first.
top_k = probs.flatten().argsort()[-1:-6:-1]
print(probs)
print(top_k)
print(labels[top_k])
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment