Created February 6, 2019 11:56
Save sheerun/4af949802bc245bae64e8a21a3f8e708 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import requests | |
import responder | |
import os, tempfile | |
def download(url, timeout=30):
    """Download *url* into a fresh temporary directory and return the file path.

    Example::

        download("http://gajamed.pl/uploads/2016/04/29.jpg")
        # => "/tmp/tmpab12cd/image.jpg"

    The file is always named ``image.jpg`` regardless of the remote content
    type. Each call creates its own ``tempfile.mkdtemp()`` directory; the
    caller owns the returned file and its directory and should delete them.

    Args:
        url: HTTP(S) URL to fetch. Redirects are followed.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on an unresponsive host.

    Returns:
        Path of the downloaded file.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so
            an error page is never saved and later decoded as an "image".
        requests.RequestException: on connection errors or timeouts.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    # Create the temp directory only after a successful response, so failed
    # requests do not leave empty directories behind.
    dest = os.path.join(tempfile.mkdtemp(), "image.jpg")
    with open(dest, "wb") as f:
        f.write(response.content)
    return dest
def load_graph(filename):
    """Load a serialized TensorFlow GraphDef from *filename* into a new graph.

    Example::

        load_graph("output_graph.pb")
        # => tf.Graph containing the imported operations

    ``tf.import_graph_def`` is called with its default name, so every
    imported operation is prefixed with ``import/``; that is why callers
    look tensors up as e.g. ``"import/Placeholder"``.

    Args:
        filename: Path to a binary protobuf-encoded ``GraphDef``.

    Returns:
        A new ``tf.Graph``; the process-wide default graph is not modified.
    """
    with open(filename, "rb") as pb_file:
        serialized = pb_file.read()

    graph_def = tf.GraphDef()
    graph_def.ParseFromString(serialized)

    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(graph_def)
    return loaded
def read_image(file_name, height=299, width=299, input_mean=0, input_std=255):
    """Decode a JPEG file into a resized, normalized batch of one image.

    Example::

        read_image("/tmp/foo.jpg")
        # => numpy array of shape (1, 299, 299, 3), values in [0, 1]

    Pipeline: read file -> decode JPEG as 3 channels -> cast to float32 ->
    add a leading batch dimension -> bilinear resize to (height, width) ->
    ``(pixel - input_mean) / input_std``.

    The ops are built in a private ``tf.Graph`` and run in a session that is
    closed on exit. The previous version added ops to the process-wide
    default graph and opened an unclosed session on every call. Since this
    runs once per HTTP request, both grew without bound.

    Args:
        file_name: Path to a JPEG file.
        height: Output height in pixels (default 299, as before).
        width: Output width in pixels (default 299, as before).
        input_mean: Value subtracted from every pixel (default 0).
        input_std: Divisor applied after subtraction (default 255, which maps
            8-bit pixel values into [0, 1]).

    Returns:
        The evaluated tensor as a numpy array of shape (1, height, width, 3).
    """
    graph = tf.Graph()
    with graph.as_default():
        file_reader = tf.read_file(file_name, "file_reader")
        image_reader = tf.image.decode_jpeg(
            file_reader, channels=3, name="jpeg_reader"
        )
        float_caster = tf.cast(image_reader, tf.float32)
        # Models expect a batch, so turn HxWxC into 1xHxWxC.
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(dims_expander, [height, width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    with tf.Session(graph=graph) as sess:
        return sess.run(normalized)
def classify_image(url):
    """Download the image at *url* and score it against every model label.

    Example::

        classify_image("http://gajamed.pl/uploads/2016/04/29.jpg")
        # => [("daisy", 0.2), ("rose", 0.9)]

    On every call this reads ``output_labels.txt`` (one label per line) and
    ``output_graph.pb``, both resolved relative to the current working
    directory.

    Args:
        url: HTTP(S) URL of a JPEG image.

    Returns:
        List of ``(label, score)`` pairs, in the order the labels appear in
        ``output_labels.txt``.
    """
    with open("output_labels.txt") as labels_file:
        labels = [line.strip() for line in labels_file]

    graph = load_graph("output_graph.pb")
    # The "import/" prefix comes from tf.import_graph_def's default name.
    # NOTE(review): "Placeholder"/"final_result" presumably match the names
    # written by TF's image-retraining script — confirm against the .pb.
    input_tensor = graph.get_operation_by_name("import/Placeholder").outputs[0]
    output_tensor = graph.get_operation_by_name("import/final_result").outputs[0]

    filepath = download(url)
    try:
        image = read_image(filepath)
    finally:
        # download() writes each file into its own mkdtemp() directory. Remove
        # both so temp space does not grow with every request. os.rmdir only
        # removes an empty directory, so nothing else can be deleted here.
        os.remove(filepath)
        try:
            os.rmdir(os.path.dirname(filepath))
        except OSError:
            pass  # best effort: directory not empty or already gone

    with tf.Session(graph=graph) as sess:
        results = sess.run(output_tensor, {input_tensor: image})[0]
    return list(zip(labels, results))
api = responder.API()


# curl http://localhost:5042/http://gajamed.pl/uploads/2016/04/29.jpg
# rose: 90%
# daisy: 20%
# (5042 is responder's default port; the actual address is printed at startup.)
@api.route(None, default=True)
def handler(req, resp):
    """Classify the image whose URL is given as the request path.

    Registered as the default route, so requests not matched by another
    route land here. The leading ``/`` is stripped and the rest of the path
    is used as the image URL. The response body has one ``label: NN%`` line
    per label, with the score truncated to a whole percent.

    Responds with status 400 when the path is not an http(s) URL, instead of
    failing inside the download.

    SECURITY NOTE(review): the server fetches whatever URL the client sends.
    That is a server-side request forgery (SSRF) risk: internal hosts become
    reachable. Restrict the allowed hosts before exposing this beyond
    localhost.
    """
    image_url = req.url.path[1:]
    if not image_url.lower().startswith(("http://", "https://")):
        resp.status_code = 400
        resp.text = "Expected an http(s) image URL in the path\n"
        return
    resp.text = "".join(
        label + ": " + str(int(confidence * 100)) + "%\n"
        for label, confidence in classify_image(image_url)
    )
def _serve():
    """Start the responder server for this module's ``api``."""
    api.run()


# Usage: activate the Anaconda environment that has TensorFlow and responder
# installed, then run `python index.py`. The API will be available at the
# address printed in the terminal.
if __name__ == '__main__':
    _serve()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.