Skip to content

Instantly share code, notes, and snippets.

@ericterpstra
Created February 25, 2017 22:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericterpstra/6a343e8252d2604e0568f53c5bbaf0b7 to your computer and use it in GitHub Desktop.
Save ericterpstra/6a343e8252d2604e0568f53c5bbaf0b7 to your computer and use it in GitHub Desktop.
Lola Detector
import tensorflow as tf, sys
import simplejson as json
from flask import Flask, jsonify
# Flask application object; the /lola route handler is registered on it via @app.route.
app = Flask(__name__)
def start_tf_session(graph_path="tfmodel/retrained_graph_cats.pb"):
    """Load a frozen retrained graph into the default graph and open a session.

    Args:
        graph_path: Path of the serialized ``GraphDef`` (``.pb``) file to load.
            Defaults to the cat model this script was written for, so
            existing no-argument calls behave as before.

    Returns:
        A new ``tf.Session``. It is created after the graph import, so it
        runs against the default graph that now contains the loaded nodes.
    """
    # Unpersist the graph from file; the context manager closes the handle.
    with tf.gfile.FastGFile(graph_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' imports the nodes without a name prefix, so lookups such as
    # 'final_result:0' and 'DecodeJpeg/contents:0' resolve as written.
    tf.import_graph_def(graph_def, name='')
    print('Tensorflow session (for cats) created.')
    return tf.Session()
# Look up the softmax output tensor in the session's graph
def get_tensor(sess):
    """Fetch the retrained model's output tensor from a session's graph.

    Args:
        sess: A TensorFlow session whose graph contains a ``final_result``
            operation (e.g. the session returned by ``start_tf_session``).

    Returns:
        The tensor named ``final_result:0`` in ``sess.graph``.
    """
    output_name = 'final_result:0'
    print('softmax_tensor (for cats) got gotten.')
    return sess.graph.get_tensor_by_name(output_name)
# Called by the web service: runs the TensorFlow model on an image file.
def image_labels(image_name):
    """Classify an image file and return a score for every label.

    Relies on the module-level globals ``sess``, ``softmax_tensor`` and
    ``label_lines``. The raw file bytes are fed to the graph's
    ``DecodeJpeg/contents:0`` input.

    Args:
        image_name: Path of the image file to classify (read as raw bytes).

    Returns:
        dict mapping each label string to its score rounded to 5 decimal
        places and converted to ``str``. Entries are inserted in descending
        order of score.
    """
    # Load the image from disk; the context manager closes the handle
    # instead of leaving it open until garbage collection.
    print('loading image')
    with tf.gfile.FastGFile(image_name, 'rb') as image_file:
        image_data = image_file.read()

    # Use TensorFlow!
    print('running prediction')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})

    print('sorting and organizing results')
    # First row of the output; presumably the batch holds only this image.
    scores = predictions[0]
    results = {}
    # argsort is ascending, so reverse it to visit the most confident label first.
    for node_id in scores.argsort()[::-1]:
        human_string = label_lines[node_id]
        score = scores[node_id]
        print('%s (score = %.5f)' % (human_string, score))
        results[human_string] = str(round(score, 5))
    return results
# Initialise TensorFlow once at import time; the request handler reuses these globals.
print('Starting Tensorflow (for cats)')
# Start a TensorFlow session with the retrained graph loaded
sess = start_tf_session()
# Get a reference to the softmax tensor with the newly created Session
softmax_tensor = get_tensor(sess)
# Load the labels file from the custom trained model. The `with` closes the
# handle deterministically; the temporary name is then removed from the module.
with tf.gfile.GFile("tfmodel/retrained_labels_cats.txt") as _labels_file:
    label_lines = [line.rstrip() for line in _labels_file]
del _labels_file
# Web service endpoint that analyzes the Raspberry Pi image
@app.route("/lola")
def lola():
    """Classify the fixed image file and respond with its label scores.

    Reads ``rpicam.jpg`` (relative to the working directory; presumably
    written by the Raspberry Pi camera) through ``image_labels`` and returns
    a JSON object mapping each label to its score string.
    """
    print('Analyzing image for Lola.')
    label_scores = image_labels("rpicam.jpg")
    print('returning json')
    return jsonify(**label_scores)
# Serve the Flask app only when executed as a script, not when imported.
if '__main__' == __name__:
    app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment