Skip to content

Instantly share code, notes, and snippets.

@jcjohnson
Created July 22, 2016 15:37
Show Gist options
  • Save jcjohnson/564c30b82e4211b917d800a1c34a6a22 to your computer and use it in GitHub Desktop.
import argparse, os, glob, tempfile
import h5py
import numpy as np
import matplotlib.pyplot as plt
from scipy.misc import imread, imresize
# Workaround for truncated image files in some datasets: let PIL load them anyway instead of raising an error.
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
import caffe
def write_temp_deploy(source_prototxt, batch_size):
    """
    Copy a deploy prototxt to a temporary file, overriding its batch size.

    Only the old-style top-level ``input_dim:`` syntax is recognized: the first
    line starting with ``input_dim:`` is treated as the batch dimension and is
    replaced with ``batch_size``; every other line is copied unchanged. If no
    such line exists, the copy is identical to the source.

    Inputs:
    - source_prototxt: Path to a deploy.prototxt that will be modified
    - batch_size: Desired batch size for the network

    Returns:
    - Path to the temporary file containing the modified prototxt. The caller
      owns this file and is responsible for deleting it.
    """
    # Read the source first so that a missing/unreadable source does not leave
    # an orphaned temporary file behind.
    with open(source_prototxt, 'r') as f:
        lines = f.readlines()

    # mkstemp hands back an already-open OS-level descriptor; wrap it with
    # fdopen so it is closed after writing instead of leaking one per call.
    fd, target = tempfile.mkstemp()
    found_batch_size_line = False
    with os.fdopen(fd, 'w') as f:
        for line in lines:
            if not found_batch_size_line and line.startswith('input_dim:'):
                found_batch_size_line = True
                line = 'input_dim: %d\n' % batch_size
            f.write(line)
    return target
def resize_mean_image(mean_image, height, width):
    """
    Resize the (ImageNet) mean image to a given spatial size, keeping float precision.

    - A mean of spatial size 1x1 (a single per-channel mean pixel, such as the
      VGG mean built for --vgg_mean) is tiled exactly; no interpolation needed.
    - A mean that already has the requested size is returned as a float copy.
    - Otherwise each channel is bilinearly resized as a 32-bit float image.

    Inputs:
    - mean_image: numpy float array of shape (3, H, W), in BGR order.
      This is the format of the mean ImageNet image provided by Caffe.
    - height, width: Desired height and width

    Return:
    A numpy float64 array of shape (3, height, width) in BGR order.
    """
    mean_image = np.asarray(mean_image, dtype='float')
    _, mean_height, mean_width = mean_image.shape
    if mean_height == 1 and mean_width == 1:
        # Per-channel mean pixel: replicate it exactly over the output grid.
        return np.tile(mean_image, (1, height, width))
    if (mean_height, mean_width) == (height, width):
        return mean_image.copy()
    # Casting to uint8 (as done previously, to stop imresize from byte-scaling
    # the data into [0, 255]) truncated every mean value. Resizing each channel
    # as an 'F'-mode (32-bit float) image avoids both truncation and rescaling.
    channels = [imresize(channel, (height, width), mode='F')
                for channel in mean_image]
    return np.array(channels, dtype='float')
def load_image(image_filename, height, width, mean_image):
    """
    Read an image off disk and prepare it for caffe. We do the following:
    (1) Resize to (height, width)
    (2) Replicate a grayscale image into three channels
    (3) Swap color channels from RGB to BGR (a 4th, e.g. alpha, channel is dropped)
    (4) Transpose from (H, W, C) to (C, H, W)
    (5) Convert from uint8 to float
    (6) Subtract mean image (which is already BGR)

    Inputs:
    - image_filename: Path to the image file to read
    - height, width: Input size of the network; we'll resize the image to this size
    - mean_image: Numpy float array of shape (3, height, width) in BGR format giving
      mean image to be subtracted.

    Returns:
    - Numpy float array of shape (3, height, width): the BGR, mean-subtracted image.

    Raises:
    - ValueError: if the image cannot be resized; the message names the file
      and its shape.
    """
    img = imread(image_filename)
    try:
        img = imresize(img, (height, width))
    except ValueError as e:
        # Fail loudly with the offending file instead of a ZeroDivisionError hack.
        raise ValueError('Could not resize image %s with shape %r: %s'
                         % (image_filename, img.shape, e))
    if img.ndim == 2:
        # Grayscale: add a channel axis and replicate it three times.
        img = img[:, :, None][:, :, [0, 0, 0]]
    # [2, 1, 0] reorders RGB -> BGR. transpose(2, 0, 1) moves channels first
    # while keeping rows (H) before columns (W); transpose(2, 1, 0) would yield
    # (C, W, H), a spatially transposed image that does not line up with the
    # (C, H, W) mean.
    bgr = img[:, :, [2, 1, 0]].transpose(2, 0, 1).astype('float')
    return bgr - mean_image
if __name__ == '__main__':
    # Extract features from one blob of a Caffe net for every image listed in
    # --image_list and store them in an HDF5 dataset named 'features'.
    CAFFENET = '$CAFFE_ROOT/models/bvlc_reference_caffenet'
    CAFFENET_DEPLOY = os.path.join(CAFFENET, 'deploy.prototxt')
    CAFFENET_CAFFEMODEL = os.path.join(CAFFENET, 'bvlc_reference_caffenet.caffemodel')

    parser = argparse.ArgumentParser()
    parser.add_argument('--image_list', required=True,
                        help='Text file with one image path per line')
    parser.add_argument('--deploy_txt', default=CAFFENET_DEPLOY)
    parser.add_argument('--caffemodel', default=CAFFENET_CAFFEMODEL)
    parser.add_argument('--mean_file',
                        default='$CAFFE_ROOT/python/caffe/imagenet/ilsvrc_2012_mean.npy')
    parser.add_argument('--vgg_mean', action='store_true',
                        help='Subtract the VGG mean pixel instead of --mean_file')
    parser.add_argument('--gpu', type=int, default=0,
                        help='GPU device id; a negative value runs on the CPU')
    parser.add_argument('--blob_name', default='fc7')
    parser.add_argument('--batch_size', default=100, type=int)
    parser.add_argument('--output_h5_file', default='features.h5')
    args = parser.parse_args()
    if args.batch_size <= 0:
        parser.error('--batch_size must be positive')

    if args.gpu < 0:
        caffe.set_mode_cpu()
    else:
        caffe.set_mode_gpu()
        caffe.set_device(args.gpu)

    deploy_file = os.path.expandvars(args.deploy_txt)
    caffemodel_file = os.path.expandvars(args.caffemodel)
    temp_deploy = write_temp_deploy(deploy_file, args.batch_size)
    try:
        net = caffe.Net(temp_deploy, caffemodel_file, caffe.TEST)
    finally:
        # The rewritten prototxt is parsed when the Net is constructed;
        # remove it rather than leaving one temp file per run.
        os.remove(temp_deploy)

    # Take the batch size from the net itself: write_temp_deploy only rewrites
    # old-style 'input_dim:' prototxts, so it can differ from --batch_size.
    data_shape = net.blobs['data'].data.shape
    batch_size, net_height, net_width = data_shape[0], data_shape[2], data_shape[3]

    # Read image filenames from the txt file, ignoring blank lines.
    with open(args.image_list, 'r') as f:
        image_filenames = [line.strip() for line in f if line.strip()]

    print(net.blobs[args.blob_name].data.shape)

    mean_image_file = os.path.expandvars(args.mean_file)
    if args.vgg_mean:
        print('using vgg mean')
        # VGG was trained by subtracting the mean pixel, not the mean image.
        # The mean BGR pixel value is given at
        # https://gist.github.com/ksimonyan/3785162f95cd2d5fee77
        pixel = [103.939, 116.779, 123.68]
        mean_image = np.asarray(pixel).reshape(3, 1, 1)
    else:
        mean_image = np.load(mean_image_file)
    mean_image_resized = resize_mean_image(mean_image, net_height, net_width)

    num_images = len(image_filenames)
    feature_shape = (num_images,) + net.blobs[args.blob_name].data.shape[1:]
    batch_data = np.zeros_like(net.blobs['data'].data)
    with h5py.File(args.output_h5_file, 'w') as h5_f:
        dset = h5_f.create_dataset('features', feature_shape, dtype='f4')
        dset.attrs['blob_name'] = args.blob_name
        dset.attrs['deploy_txt'] = deploy_file
        dset.attrs['caffemodel'] = caffemodel_file
        dset.attrs['mean_file'] = mean_image_file
        dset.attrs['vgg_mean'] = args.vgg_mean

        next_batch_idx = 0
        next_dset_idx = 0
        for i, image_filename in enumerate(image_filenames):
            batch_data[next_batch_idx] = load_image(
                image_filename, net_height, net_width, mean_image_resized)
            next_batch_idx += 1
            if next_batch_idx == batch_size:
                net.forward(data=batch_data)
                # Assigning into the h5py dataset writes (copies) the data now.
                dset[next_dset_idx:next_dset_idx + batch_size] = \
                    net.blobs[args.blob_name].data
                next_dset_idx += batch_size
                next_batch_idx = 0
                print('done with %d / %d images' % (i + 1, num_images))
        if next_batch_idx > 0:
            # Final partial batch: slots at and after next_batch_idx still hold
            # zeros or images from the previous batch; their outputs are dropped.
            net.forward(data=batch_data)
            dset[next_dset_idx:] = net.blobs[args.blob_name].data[:next_batch_idx]
            print('done with %d / %d images' % (num_images, num_images))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment