Created
February 25, 2017 22:56
-
-
Save ericterpstra/6a343e8252d2604e0568f53c5bbaf0b7 to your computer and use it in GitHub Desktop.
Lola Detector
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf, sys | |
import simplejson as json | |
from flask import Flask, jsonify | |
# The Flask application object; the web-service routes below register on it.
app = Flask(import_name=__name__)
def start_tf_session():
    """Import the retrained cat-classifier graph and open a session on it.

    Reads the serialized ``GraphDef`` from ``tfmodel/retrained_graph_cats.pb``
    and imports it into the default graph with an empty name prefix, so
    tensors keep their original names (e.g. ``final_result:0``).

    Labels are not loaded here; the module-level startup code reads them
    into ``label_lines``.

    Returns:
        tf.Session: a new session, created after the import so that it runs
        against the default graph that now contains the model.
    """
    # Unpersist the graph from file; the context manager closes the handle.
    with tf.gfile.FastGFile("tfmodel/retrained_graph_cats.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' suppresses the default "import/" prefix on every tensor name.
    tf.import_graph_def(graph_def, name='')
    print('Tensorflow session (for cats) created.')
    return tf.Session()
# Look up the model's output tensor in the session's graph.
def get_tensor(sess):
    """Return the ``final_result:0`` tensor from ``sess``'s graph.

    Prints a progress message before performing the lookup.

    Args:
        sess: the TensorFlow session whose ``graph`` holds the model.

    Returns:
        The tensor named ``final_result:0``.
    """
    print('softmax_tensor (for cats) got gotten.')
    model_graph = sess.graph
    output_tensor = model_graph.get_tensor_by_name('final_result:0')
    return output_tensor
# This function gets called by the web service. It runs the TensorFlow model.
def image_labels(image_name):
    """Classify the image at ``image_name`` and return a label -> score map.

    Relies on the module-level globals ``sess``, ``softmax_tensor`` and
    ``label_lines``, which are set up when the module is imported.

    Args:
        image_name: path of the image file. Its raw bytes are fed to the
            graph's ``'DecodeJpeg/contents:0'`` input (presumably it must be
            a JPEG -- TODO confirm against the graph).

    Returns:
        dict mapping each label string to its score, rounded to 5 decimal
        places and converted to ``str``. Entries are inserted from highest
        to lowest score.
    """
    # Load the image from disk; the context manager closes the file handle
    # instead of leaving it open on every request.
    print('loading image')
    with tf.gfile.FastGFile(image_name, 'rb') as image_file:
        image_data = image_file.read()

    print('running prediction')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})

    print('sorting and organizing results')
    # Scores from the first row of the predictions.
    scores = predictions[0]
    # argsort is ascending; reverse it so the highest-confidence label is first.
    top_k = scores.argsort()[::-1]

    results = {}
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = scores[node_id]
        print('%s (score = %.5f)' % (human_string, score))
        results[human_string] = str(round(score, 5))
    return results
# Start the app!
print('Starting Tensorflow (for cats)')
# Start a TensorFlow session (the model graph is imported into the default
# graph), and keep it for the lifetime of the process.
sess = start_tf_session()
# Get a reference to the softmax output tensor from the session's graph.
softmax_tensor = get_tensor(sess)
# Load the labels file from the custom-trained model, one label per line.
# The context manager closes the file once it has been read.
with tf.gfile.GFile("tfmodel/retrained_labels_cats.txt") as label_file:
    label_lines = [line.rstrip() for line in label_file]
# Create a web service to analyze the Raspberry Pi image.
@app.route("/lola")
def lola():
    """Classify ``rpicam.jpg`` and return the label scores as JSON.

    Returns:
        A Flask JSON response whose object maps each label to its score
        string, as produced by ``image_labels``.
    """
    print('Analyzing image for Lola.')
    # Get the labels associated with the input image.
    labels = image_labels("rpicam.jpg")
    print('returning json')
    # Pass the dict positionally instead of unpacking it into keyword
    # arguments. This is the documented form; also, on Flask >= 2.2 an empty
    # ``jsonify()`` call serializes as ``null`` rather than ``{}``.
    return jsonify(labels)
# Launch the Flask development server only when executed as a script.
if __name__ == '__main__':
    app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment