Skip to content

Instantly share code, notes, and snippets.

@jamesmcintyre
Created November 22, 2018 07:17
Show Gist options
  • Save jamesmcintyre/ae293a329aaf531bd0d9211c012ffb53 to your computer and use it in GitHub Desktop.
Save jamesmcintyre/ae293a329aaf531bd0d9211c012ffb53 to your computer and use it in GitHub Desktop.
import tensorflow as tf, sys
# The image to classify is passed as the first command-line argument; fail
# with a usage message rather than an IndexError traceback when it is missing.
if len(sys.argv) < 2:
    sys.exit("usage: %s IMAGE_PATH" % sys.argv[0])
image_path = sys.argv[1]
def read_tensor_from_image_file(file_name,
                                input_height=299,
                                input_width=299,
                                input_mean=0,
                                input_std=255):
    """Decode an image file and return it as a normalized float batch.

    The decoder is chosen from the file extension (PNG, GIF, BMP, otherwise
    JPEG). The image is cast to float32, given a leading batch axis of size 1,
    bilinearly resized to ``input_height`` x ``input_width`` and mapped through
    ``(pixel - input_mean) / input_std``.

    Args:
        file_name: Path of the image to read.
        input_height: Target height in pixels after resizing.
        input_width: Target width in pixels after resizing.
        input_mean: Value subtracted from every pixel.
        input_std: Value every mean-shifted pixel is divided by.

    Returns:
        The evaluated normalized tensor as returned by ``Session.run``, with a
        leading batch dimension of 1.
    """
    input_name = "file_reader"
    output_name = "normalized"
    # Build the preprocessing ops in a private graph so they neither pile up
    # in the default graph on repeated calls nor share it with any model graph
    # imported there later.
    graph = tf.Graph()
    with graph.as_default():
        file_reader = tf.read_file(file_name, input_name)
        # Match extensions case-insensitively so "photo.PNG" / "anim.GIF"
        # still reach the right decoder.
        lowered_name = file_name.lower()
        if lowered_name.endswith(".png"):
            image_reader = tf.image.decode_png(
                file_reader, channels=3, name="png_reader")
        elif lowered_name.endswith(".gif"):
            # decode_gif yields [num_frames, height, width, 3]; keep only the
            # first frame. tf.squeeze would leave animated GIFs 4-D and would
            # also drop any other size-1 dimension (e.g. a 1-pixel-wide image).
            image_reader = tf.image.decode_gif(
                file_reader, name="gif_reader")[0]
        elif lowered_name.endswith(".bmp"):
            # Force 3 channels so 32-bit BMPs (with alpha) produce the same
            # channel count as the PNG/JPEG branches.
            image_reader = tf.image.decode_bmp(
                file_reader, channels=3, name="bmp_reader")
        else:
            image_reader = tf.image.decode_jpeg(
                file_reader, channels=3, name="jpeg_reader")
        float_caster = tf.cast(image_reader, tf.float32)
        # resize_bilinear operates on a 4-D batch, so add a batch axis of 1.
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(
            dims_expander, [input_height, input_width])
        normalized = tf.divide(
            tf.subtract(resized, [input_mean]), [input_std], name=output_name)
    # Close the session once evaluated instead of leaking one per call.
    with tf.Session(graph=graph) as sess:
        return sess.run(normalized)
# Decode, resize and normalize the input image into a one-image batch that
# can be fed straight to the model's input tensor.
image_data = read_tensor_from_image_file(image_path)
# Load the class labels, one per line, stripping trailing whitespace such as
# the newline / carriage return. Line i is used as the label for output
# index i of the model. The `with` block closes the file once it is read.
with tf.gfile.GFile("/share_vol/retrained_labels.txt") as label_file:
    label_lines = [line.rstrip() for line in label_file]
# Deserialize the retrained GraphDef and import it into the default graph
# with no name prefix, so its tensors keep their original names.
# (tf.gfile.FastGFile is deprecated; tf.gfile.GFile is its replacement.)
with tf.gfile.GFile("/share_vol/retrained_graph.pb", 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
# Feed the preprocessed image to the model and print every label ranked by
# score. NOTE(review): the tensor names 'Placeholder:0' (input) and
# 'final_result:0' (softmax output) are assumed to match the exported
# retrained graph — confirm against the graph if retraining changes.
with tf.Session() as sess:
    softmax_tensor = sess.graph.get_tensor_by_name('final_result:0')
    predictions = sess.run(softmax_tensor, {'Placeholder:0': image_data})

    # A single image was fed, so row 0 holds its scores. argsort is
    # ascending; reversing it lists the most confident label first.
    top_k = predictions[0].argsort()[::-1]
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = predictions[0][node_id]
        print('%s (score = %.5f)' % (human_string, score))
@ThomasOrlita
Copy link

Works great! Just replace FastGFile with GFile.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment