Skip to content

Instantly share code, notes, and snippets.

@ale64bit
Created October 25, 2016 19:44
Show Gist options
  • Save ale64bit/03e59ffcfd4a5aa9e001f4cd4f4f50b5 to your computer and use it in GitHub Desktop.
Save ale64bit/03e59ffcfd4a5aa9e001f4cd4f4f50b5 to your computer and use it in GitHub Desktop.
Compute the cosine similarity between the Inception bottleneck outputs of two images.
import sys
import os.path
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import gfile
# Tensor names looked up in the imported Inception graph (see the
# tf.import_graph_def call below, which uses name='' so these match exactly).
# Fetched to obtain the bottleneck vector for an image.
BOTTLENECK_TENSOR_NAME = 'pool_3/_reshape:0'
# Fed with the raw bytes of an image file.
JPEG_DATA_TENSOR_NAME = 'DecodeJpeg/contents:0'
# Returned by the graph import but not otherwise used by this script.
RESIZED_INPUT_TENSOR_NAME = 'ResizeBilinear:0'

# Not referenced by this script; presumably the length of the bottleneck
# vector — TODO confirm against the model file.
BOTTLENECK_TENSOR_SIZE = 2048
def eval_image(sess, image_file, bottleneck_tensor, image_data_tensor):
    """Run one image file through the graph and return the fetched bottleneck.

    Args:
        sess: TensorFlow session whose graph contains both tensors.
        image_file: path to the image; its raw bytes are fed unchanged.
            Presumably JPEG-encoded, since the script passes the
            DecodeJpeg contents tensor as ``image_data_tensor`` — confirm.
        bottleneck_tensor: tensor to fetch.
        image_data_tensor: tensor that receives the raw file bytes.

    Returns:
        The fetched value as a NumPy array with all size-1 axes squeezed out.
    """
    # Read inside a context manager so the file handle is closed as soon as
    # the bytes are loaded instead of being left to the garbage collector.
    with gfile.FastGFile(image_file, 'rb') as f:
        image_data = f.read()
    bottleneck_values = sess.run(bottleneck_tensor,
                                 {image_data_tensor: image_data})
    return np.squeeze(bottleneck_values)
def similarity(x, y):
    """Build a graph op that computes the cosine similarity of ``x`` and ``y``.

    Both inputs are L2-normalized along axis 0 and the element-wise product
    is summed over all elements. For 1-D inputs this is the cosine of the
    angle between them. (Inputs are presumably 1-D here, because
    ``eval_image`` squeezes its output — confirm for other models.)

    Args:
        x: array-like or tensor.
        y: array-like or tensor with the same shape as ``x``.

    Returns:
        A scalar tensor; evaluate it with ``sess.run``.
    """
    x_unit = tf.nn.l2_normalize(x, 0)
    y_unit = tf.nn.l2_normalize(y, 0)
    # `tf.mul` was removed in TensorFlow 1.0 (renamed `tf.multiply`). The
    # overloaded `*` operator on tensors works on every release.
    return tf.reduce_sum(x_unit * y_unit)
# Command-line handling: expects exactly <model_file> <image1> <image2>.
if len(sys.argv) != 4:
    print(' Usage: eval <model_file> <image1> <image2>')
    # Use sys.exit rather than the `exit` helper injected by the `site`
    # module: that helper is intended for the interactive interpreter and is
    # absent when Python runs with -S.
    sys.exit(-1)

model_file, img1, img2 = sys.argv[1:4]

# Check every input path before loading the model; stop at the first missing
# one, in the same order as the arguments.
for path in (model_file, img1, img2):
    if not gfile.Exists(path):
        tf.logging.fatal('File does not exist "%s"', path)
        sys.exit(-1)
print('Loading model...')
with gfile.FastGFile(model_file, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into the default graph. name='' means no name-scope prefix is added,
# so the tensors keep the exact names listed in the *_TENSOR_NAME constants.
# resized_input_tensor is returned by the import but not used below.
bottleneck_tensor, jpeg_data_tensor, resized_input_tensor = (
    tf.import_graph_def(graph_def, name='', return_elements=[
        BOTTLENECK_TENSOR_NAME, JPEG_DATA_TENSOR_NAME,
        RESIZED_INPUT_TENSOR_NAME]))

# Use the session as a context manager so it is closed, and its resources
# released, once evaluation finishes.
with tf.Session() as sess:
    print('Evaluating...')
    bottleneck1 = eval_image(sess, img1, bottleneck_tensor, jpeg_data_tensor)
    bottleneck2 = eval_image(sess, img2, bottleneck_tensor, jpeg_data_tensor)
    print(sess.run(similarity(bottleneck1, bottleneck2)))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment