Skip to content

Instantly share code, notes, and snippets.

@david90
Created February 14, 2017 09:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save david90/e98e1c41a0ebc580e5a9ce25ff6a972d to your computer and use it in GitHub Desktop.
Save david90/e98e1c41a0ebc580e5a9ce25ff6a972d to your computer and use it in GitHub Desktop.
Code for extracting Inception bottleneck features
import os
import tensorflow as tf
import tensorflow.python.platform
from tensorflow.python.platform import gfile
import numpy as np
def create_graph(model_path):
    """
    create_graph loads the Inception model into the default TensorFlow graph.

    Call this before calling extract_features. extract_features opens a
    tf.Session on the default graph and looks tensors up by name there.

    model_path: path to the Inception model, serialized as a GraphDef protobuf.
    """
    graph_def = tf.GraphDef()
    # gfile.GFile replaces gfile.FastGFile, which is deprecated in TF 1.x.
    with gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name-scope prefix, so tensors keep
    # their original names (e.g. 'pool_3:0') for later lookup.
    tf.import_graph_def(graph_def, name='')
def extract_features(image_paths, verbose=False):
    """
    extract_features computes the Inception bottleneck feature for a list of images.

    create_graph must have been called first, so that the default graph
    contains the model's 'pool_3:0' and 'DecodeJpeg/contents:0' tensors.

    image_paths: iterable of image file paths. Each file's raw bytes are fed
        to 'DecodeJpeg/contents:0', so the images are presumably expected to
        be JPEG-encoded (TODO confirm).
    verbose: if True, print each path before it is processed.
    return: 2-d array of shape (len(image_paths), 2048). Row i holds the
        squeezed 'pool_3:0' output for the i-th path.
    raises: IOError if a path does not exist.
    """
    # Materialize once so generators/iterators are accepted and len() works.
    image_paths = list(image_paths)
    feature_dimension = 2048
    features = np.empty((len(image_paths), feature_dimension))
    if not image_paths:
        # Nothing to compute; skip creating a Session entirely.
        return features

    with tf.Session() as sess:
        flattened_tensor = sess.graph.get_tensor_by_name('pool_3:0')
        for i, image_path in enumerate(image_paths):
            if verbose:
                print('Processing %s...' % (image_path))
            if not gfile.Exists(image_path):
                # tf.logging.fatal only logs and does not stop execution.
                # Raise explicitly instead of continuing with a bad path.
                tf.logging.fatal('File does not exist %s', image_path)
                raise IOError('File does not exist %s' % image_path)
            # Use a context manager so the file handle is always closed.
            with gfile.GFile(image_path, 'rb') as f:
                image_data = f.read()
            feature = sess.run(flattened_tensor, {
                'DecodeJpeg/contents:0': image_data
            })
            features[i, :] = np.squeeze(feature)
    return features
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment