Created
April 24, 2017 03:14
-
-
Save grapeot/25d6f15f1d5d548079bdf44622ce135c to your computer and use it in GitHub Desktop.
Use Caffe2 to extract features
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Initial imports | |
import os | |
import sys | |
import logging | |
logging.basicConfig(format='%(asctime)s : %(levelname)s : %(message)s', level=logging.INFO) | |
from caffe2.proto import caffe2_pb2 | |
import numpy as np | |
import skimage.io | |
import skimage.transform | |
from caffe2.python import core, workspace | |
import urllib2 | |
from caffe2.python.models import bvlc_reference_caffenet as mynet | |
logging.info('Required modules imported.') | |
# What model are we using? You should have already converted or downloaded one.
# Format of the MODEL tuple below:
#   (folder, INIT_NET, predict_net, mean file, input image size)
# you can switch the comments on MODEL to try out different model conversions
MODEL = 'bvlc_reference_caffenet', 'init_net.pb', 'predict_net.pb', 'ilsvrc_2012_mean.npy', 227
# some models were trained with different image sizes; this is the side length
# (in pixels) the input image is rescaled and center-cropped to.
INPUT_IMAGE_SIZE = MODEL[4]
# codes - these help decipher the output and map AlexNet's object codes to a
# human-readable result like "tabby cat" or "lemon" depending on what's in the
# picture submitted to the neural network.
# NOTE(review): `codes` appears unused elsewhere in this file.
# The list of output codes for the AlexNet models (also squeezenet)
codes = "https://gist.githubusercontent.com/aaronmarkham/cd3a6b6ac071eca6f7b4a6e40e6038aa/raw/9edb4038a37da6b5a44c3b5bc52e448ff09bfe5b/alexnet_codes"
# default mean to subtract (0 = no subtraction); see initCaffe2() for the
# mean-file logic that is intended to set the real value.
mean = 0
logging.info("Config set.")
# Set up some functions | |
def crop_center(img, cropx, cropy):
    """Return the central cropx-by-cropy window of an HWC image array."""
    height, width, _channels = img.shape
    left = width // 2 - cropx // 2
    top = height // 2 - cropy // 2
    return img[top:top + cropy, left:left + cropx]
def rescale(img, input_height, input_width):
    """Resize img to the target size, preserving its aspect ratio.

    NOTE(review): the height/width roles in the resize calls mirror the
    original Caffe2 tutorial code and look swapped relative to the
    parameter names; kept as-is to preserve behavior.
    """
    aspect = img.shape[1] / float(img.shape[0])
    if aspect > 1:
        # landscape orientation - wide image
        scaled_side = int(aspect * input_height)
        return skimage.transform.resize(img, (input_width, scaled_side))
    elif aspect < 1:
        # portrait orientation - tall image
        scaled_side = int(input_width / aspect)
        return skimage.transform.resize(img, (scaled_side, input_height))
    else:
        # square image: resize directly to the target dimensions
        return skimage.transform.resize(img, (input_width, input_height))
# Initialize Caffe2 and return the workspace object
def initCaffe2():
    """Load the Caffe2 model named by MODEL and return a workspace.Predictor.

    Side effect: computes the image mean (scalar 128, or the per-channel
    values from the model's mean file) and stores it in the module-level
    `mean`, which extractFeatures() subtracts from each image.
    """
    # BUGFIX: `mean` was assigned as a local here, so the module-level
    # `mean = 0` was never updated and extractFeatures() subtracted nothing.
    global mean
    # Configs
    # where you installed caffe2. Probably '~/caffe2' or '~/src/caffe2'.
    CAFFE2_ROOT = "/usr/local/caffe2"
    # assumes being a subdirectory of caffe2
    CAFFE_MODELS = "/usr/local/caffe2/python/models"
    # if you have a mean file, place it in the same dir as the model
    # set paths and variables from model choice and prep image
    CAFFE2_ROOT = os.path.expanduser(CAFFE2_ROOT)
    CAFFE_MODELS = os.path.expanduser(CAFFE_MODELS)
    # mean can be 128 or custom based on the model
    # gives better results to remove the colors found in all of the training images
    MEAN_FILE = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[3])
    if not os.path.exists(MEAN_FILE):
        mean = 128
    else:
        # collapse the (C, H, W) mean image to one value per channel,
        # then reshape to (C, 1, 1) so it broadcasts over H and W
        mean = np.load(MEAN_FILE).mean(1).mean(1)
        mean = mean[:, np.newaxis, np.newaxis]
    logging.info("mean was set to: " + str(mean))
    # make sure all of the files are around...
    if not os.path.exists(CAFFE2_ROOT):
        logging.info("Houston, you may have a problem.")
    INIT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[1])
    logging.info('INIT_NET = ' + INIT_NET)
    PREDICT_NET = os.path.join(CAFFE_MODELS, MODEL[0], MODEL[2])
    logging.info('PREDICT_NET = ' + PREDICT_NET)
    if not os.path.exists(INIT_NET):
        logging.info(INIT_NET + " not found!")
    else:
        logging.info("Found " + INIT_NET + "...Now looking for" + PREDICT_NET)
        if not os.path.exists(PREDICT_NET):
            logging.info("Caffe model file, " + PREDICT_NET + " was not found!")
        else:
            logging.info("All needed files found! Loading the model in the next block.")
    # initialize the neural net
    # BUGFIX: the .pb files are binary protobufs; read them in 'rb' mode so
    # the bytes are not mangled on platforms with text-mode translation.
    with open(INIT_NET, 'rb') as f:
        init_net = f.read()
    with open(PREDICT_NET, 'rb') as f:
        predict_net = f.read()
    p = workspace.Predictor(init_net, predict_net)
    return p
def extractFeatures(predictor, imgfn):
    """Run the network on the image at imgfn and return its output vector."""
    # load the image as float32 and fit it to the network's input size
    image = skimage.img_as_float(skimage.io.imread(imgfn)).astype(np.float32)
    image = rescale(image, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
    image = crop_center(image, INPUT_IMAGE_SIZE, INPUT_IMAGE_SIZE)
    # reorder HWC -> CHW, then RGB -> BGR channel order
    image = image.swapaxes(1, 2).swapaxes(0, 1)
    image = image[(2, 1, 0), :, :]
    # scale to [0, 255] and subtract the training-set mean for better results
    image = image * 255 - mean
    # prepend a batch dimension of size 1
    batch = image[np.newaxis, :, :, :].astype(np.float32)
    # run the net and pull the first batch entry out of the raw result
    raw = predictor.run([batch])
    return np.asarray(raw)[0, 0, :]
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='Extract features from an image or a list of images.')
    parser.add_argument('--img', help='Path of the target image.')
    parser.add_argument('--imglist', help='Path of the image list. One line per image.')
    args = parser.parse_args()
    # exactly one of --img / --imglist must be supplied
    # (BUGFIX: corrected "Exitting" typo and or/nor, and/or wording in the messages)
    if args.img is None and args.imglist is None:
        logging.error('Neither img nor imglist was specified. Exiting...')
        sys.exit(-1)
    if args.img is not None and args.imglist is not None:
        logging.error('Both img and imglist were specified. Exiting...')
        sys.exit(-1)
    p = initCaffe2()
    if args.img is not None:
        # single image: let any exception propagate with a full traceback
        features = extractFeatures(p, args.img)
        logging.info('Processed {0}.'.format(args.img))
        # output: image path, TAB, comma-joined feature values
        print('{0}\t{1}'.format(args.img, ','.join([str(x) for x in features.tolist()])))
    else:
        # BUGFIX: close the list file instead of leaking the handle
        with open(args.imglist) as listfile:
            imgs = [line.strip() for line in listfile]
        for img in imgs:
            # batch mode: log failures and keep going with the remaining images
            try:
                features = extractFeatures(p, img)
                logging.info('Processed {0}.'.format(img))
                print('{0}\t{1}'.format(img, ','.join([str(x) for x in features.tolist()])))
            except Exception as e:
                logging.error('Exception during processing {0}: {1}'.format(img, e))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment