import tensorflow as tf, sys | |
import simplejson as json | |
from flask import Flask, jsonify | |
# Flask application object; the web-service route below is registered on it.
app = Flask(import_name=__name__)
def start_tf_session():
    """Load the retrained cat-classifier graph and open a session on it.

    The serialized GraphDef in ``tfmodel/retrained_graph_cats.pb`` is parsed
    and imported into the default graph with ``name=''`` (no name prefix), so
    tensors keep their original names such as ``final_result:0``.

    Returns:
        tf.Session: a new session on the default graph, which now contains
        the imported model.
    """
    # Unpersist the graph from file; the context manager closes the handle.
    with tf.gfile.FastGFile("tfmodel/retrained_graph_cats.pb", 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
    print('Tensorflow session (for cats) created.')
    return tf.Session()
def get_tensor(sess):
    """Return the model's output tensor from the graph attached to *sess*.

    The tensor is looked up by its name, ``final_result:0``.

    Args:
        sess: a session whose graph contains the retrained model.
    """
    tensor_name = 'final_result:0'
    print('softmax_tensor (for cats) got gotten.')
    graph = sess.graph
    return graph.get_tensor_by_name(tensor_name)
# This function gets called by the web service. It runs the Tensor model.
def image_labels(image_name):
    """Run the loaded model on an image file and return per-label scores.

    Reads the module-level ``sess``, ``softmax_tensor`` and ``label_lines``
    globals, which are created when the module starts up.

    Args:
        image_name: path of the image file, read as raw bytes via tf.gfile.
            The bytes are fed to ``DecodeJpeg/contents:0``, so presumably a
            JPEG is expected.

    Returns:
        dict mapping each label string to ``str(round(score, 5))``. Entries
        are inserted in descending score order.
    """
    # Load the image from disk; the context manager closes the handle.
    print('loading image')
    with tf.gfile.FastGFile(image_name, 'rb') as image_file:
        image_data = image_file.read()

    # Feed the raw image bytes to the graph and fetch the softmax output.
    print('running prediction')
    predictions = sess.run(softmax_tensor,
                           {'DecodeJpeg/contents:0': image_data})

    print('sorting and organizing results')
    scores = predictions[0]
    # Indices of the first prediction's scores, highest confidence first.
    top_k = scores.argsort()[::-1]

    # Save the results as a dictionary keyed by label text.
    results = {}
    for node_id in top_k:
        human_string = label_lines[node_id]
        score = scores[node_id]
        print('%s (score = %.5f)' % (human_string, score))
        results[human_string] = str(round(score, 5))
    return results
# Start the app!
print('Starting Tensorflow (for cats)')
# Start a TensorFlow session, and save the session as a module-level global.
sess = start_tf_session()
# Get a reference to the softmax tensor with the newly created session.
softmax_tensor = get_tensor(sess)
# Load the labels file from the custom trained model: one label per line,
# trailing whitespace (including newline / carriage return) stripped.
# The context manager ensures the file handle is closed after reading.
with tf.gfile.GFile("tfmodel/retrained_labels_cats.txt") as _label_file:
    label_lines = [line.rstrip() for line in _label_file]
@app.route("/lola")
def lola():
    """Web-service endpoint that classifies the Raspberry Pi image ``rpicam.jpg``.

    Responds with a JSON object mapping each label to the score string
    produced by ``image_labels``.
    """
    print('Analyzing image for Lola.')
    label_scores = image_labels("rpicam.jpg")
    print('returning json')
    response = jsonify(**label_scores)
    return response


if __name__ == "__main__":
    # Serve the Flask app only when this file is executed directly.
    app.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment