This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import shutil
import tempfile

import requests
import responder
import tensorflow as tf
def download(url, timeout=30):
    """Download ``url`` into a fresh temporary directory and return the file path.

    Example::

        download("http://gajamed.pl/uploads/2016/04/29.jpg")
        # => "/tmp/tmpXXXXXX/image.jpg"

    The content is always saved as ``image.jpg`` inside a new directory
    created with ``tempfile.mkdtemp()``. The caller owns that directory and
    is responsible for removing it.

    Args:
        url: URL of the image to fetch; redirects are followed.
        timeout: Seconds to wait for the server (passed to ``requests.get``),
            so a stalled upstream cannot block the caller indefinitely.

    Returns:
        Path of the downloaded file.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status (via ``raise_for_status``).
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    # Without this check, an HTML error page (404/500) would be written to
    # disk as "image.jpg" and only fail later inside the JPEG decoder.
    response.raise_for_status()
    # Create the temp dir only after a successful fetch, so failed requests
    # do not leave empty directories behind.
    dest = os.path.join(tempfile.mkdtemp(), 'image.jpg')
    with open(dest, 'wb') as f:
        f.write(response.content)
    return dest
def load_graph(filename):
    """Load a frozen TensorFlow graph from a serialized ``GraphDef`` file.

    Example::

        load_graph("output_graph.pb")  # => tf.Graph

    ``tf.import_graph_def`` is called without an explicit ``name=``, so the
    imported node names get TensorFlow's default ``import/`` prefix.

    Args:
        filename: Path to the binary protobuf file.

    Returns:
        A new ``tf.Graph`` containing the imported nodes.
    """
    with open(filename, "rb") as pb_file:
        serialized = pb_file.read()

    definition = tf.GraphDef()
    definition.ParseFromString(serialized)

    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(definition)
    return loaded
def read_image(file_name, height=299, width=299, input_mean=0, input_std=255):
    """Decode a JPEG file into a normalized float batch for the classifier.

    Example::

        read_image("/tmp/foo.jpg")  # => array of shape (1, 299, 299, 3)

    Pipeline:
    1. Decode as a 3-channel JPEG and cast to float32.
    2. Add a leading batch dimension.
    3. Bilinearly resize to ``height`` x ``width``.
    4. Compute ``(pixel - input_mean) / input_std``.

    The defaults reproduce the original hard-coded preprocessing: 299x299
    pixels, mean 0, and division by 255.

    The ops are built in a private ``tf.Graph`` and evaluated in a session
    that is closed before returning. Building into the default graph with an
    unclosed session leaks a session and grows the default graph on every
    call, which accumulates in a long-running server.

    Args:
        file_name: Path to a JPEG file.
        height: Output height in pixels.
        width: Output width in pixels.
        input_mean: Value subtracted from every pixel.
        input_std: Value every pixel is divided by after subtraction.

    Returns:
        The evaluated tensor (shape ``(1, height, width, 3)``).
    """
    graph = tf.Graph()
    with graph.as_default():
        file_reader = tf.read_file(file_name, "file_reader")
        image_reader = tf.image.decode_jpeg(
            file_reader, channels=3, name="jpeg_reader")
        float_caster = tf.cast(image_reader, tf.float32)
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(dims_expander, [height, width])
        normalized = tf.divide(
            tf.subtract(resized, [input_mean]), [input_std])
    with tf.Session(graph=graph) as sess:
        return sess.run(normalized)
def classify_image(url):
    """Download the image at ``url`` and score it against every label.

    Example (illustrative values)::

        classify_image("http://gajamed.pl/uploads/2016/04/29.jpg")
        # => [("daisy", 0.2), ("rose", 0.9)]

    Reads ``output_labels.txt`` (one label per line, whitespace stripped) and
    the frozen graph ``output_graph.pb`` from the current working directory.
    It feeds the preprocessed image into ``import/Placeholder`` and reads
    ``import/final_result``.

    Returns:
        A list of ``(label, score)`` tuples in label-file order. Labels and
        scores are paired positionally by ``zip``, so any surplus on either
        side is dropped.
    """
    with open("output_labels.txt", encoding="utf-8") as labels_file:
        labels = [line.strip() for line in labels_file]

    graph = load_graph("output_graph.pb")
    input_tensor = graph.get_operation_by_name("import/Placeholder").outputs[0]
    output_tensor = graph.get_operation_by_name("import/final_result").outputs[0]

    filepath = download(url)
    try:
        image = read_image(filepath)
    finally:
        # download() saves each file in its own mkdtemp() directory; remove it
        # so every request does not leave a stray directory on disk.
        shutil.rmtree(os.path.dirname(filepath), ignore_errors=True)

    with tf.Session(graph=graph) as sess:
        results = sess.run(output_tensor, {input_tensor: image})[0]
    return list(zip(labels, results))
api = responder.API()


@api.route(None, default=True)
def handler(req, resp):
    """Classify the image whose URL is given as the request path.

    Every path is routed here (``default=True``). The leading "/" is dropped
    and the remainder is used as the image URL. For example (use the
    host/port that ``api.run()`` prints at startup)::

        curl http://localhost:5432/http://gajamed.pl/uploads/2016/04/29.jpg
        # daisy: 20%
        # rose: 90%

    Each output line is ``"<label>: <confidence * 100, truncated to int>%"``,
    in label-file order.

    Responds with status 400 and a short message when no URL is given or the
    image cannot be downloaded.
    """
    image_url = req.url.path[1:]
    if not image_url:
        resp.status_code = 400
        resp.text = "Usage: GET /<image URL>\n"
        return

    try:
        predictions = classify_image(image_url)
    except requests.RequestException as err:
        # Bad or unreachable upstream URL: this is a client error, not a crash.
        resp.status_code = 400
        resp.text = "Could not download {}: {}\n".format(image_url, err)
        return

    resp.text = "".join(
        "{}: {}%\n".format(label, int(confidence * 100))
        for label, confidence in predictions
    )
def _serve():
    """Start the responder web server for ``api``.

    Run this script with ``python index.py``, after first activating the
    Anaconda environment. The address the API is served on is shown in the
    terminal.
    """
    api.run()


if __name__ == '__main__':
    _serve()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment