Skip to content

Instantly share code, notes, and snippets.

@sheerun
Created February 6, 2019 11:56
Show Gist options
  • Save sheerun/4af949802bc245bae64e8a21a3f8e708 to your computer and use it in GitHub Desktop.
Save sheerun/4af949802bc245bae64e8a21a3f8e708 to your computer and use it in GitHub Desktop.
import contextlib
import functools
import os
import tempfile

import requests
import responder
import tensorflow as tf
def download(url, timeout=30):
    """Download *url* into a fresh temporary directory and return the file path.

    The saved file is always named ``image.jpg``, whatever the remote content
    type is. The temporary directory is created with ``tempfile.mkdtemp`` and
    is NOT removed here; the caller owns the file and its directory.

    Example::

        download("http://gajamed.pl/uploads/2016/04/29.jpg")
        # => "/tmp/<random>/image.jpg"

    Args:
        url: URL to fetch (redirects are followed).
            NOTE(review): in this script the URL comes straight from the
            incoming HTTP request, which makes this a server-side request
            forgery vector. Restrict allowed hosts before exposing publicly.
        timeout: Seconds to wait for the remote server before giving up.
            Without it a stalled host could block the worker indefinitely.

    Returns:
        Path of the downloaded file.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
            Previously the error page would have been saved as the "image".
        requests.RequestException: on connection errors or timeouts.
    """
    with requests.get(url, allow_redirects=True, timeout=timeout,
                      stream=True) as response:
        response.raise_for_status()
        # Only create the temp dir once we know there is something to save.
        dest = os.path.join(tempfile.mkdtemp(), 'image.jpg')
        with open(dest, 'wb') as f:
            # Stream to disk in chunks instead of buffering the whole body.
            for chunk in response.iter_content(chunk_size=64 * 1024):
                f.write(chunk)
    return dest
def load_graph(filename):
    """Parse a frozen TensorFlow GraphDef file into a new, standalone tf.Graph.

    Example::

        load_graph("output_graph.pb")  # => tf.Graph

    Every imported op gets TensorFlow's default ``import/`` name prefix, because
    ``tf.import_graph_def`` is called without a ``name`` argument.

    Args:
        filename: Path to the serialized (binary) GraphDef protobuf.

    Returns:
        A new ``tf.Graph`` holding the imported graph.
    """
    with open(filename, "rb") as handle:
        serialized = handle.read()

    definition = tf.GraphDef()
    definition.ParseFromString(serialized)

    loaded = tf.Graph()
    with loaded.as_default():
        tf.import_graph_def(definition)
    return loaded
def read_image(file_name, height=299, width=299, input_mean=0, input_std=255):
    """Load a JPEG file as a normalized float32 batch containing one image.

    Example::

        read_image("/tmp/foo.jpg")  # => array of shape (1, 299, 299, 3)

    The preprocessing ops are built in a private, throwaway graph and run in a
    session that is closed afterwards. The previous version leaked one Session
    per call. It also added fresh ops to the global default graph every time,
    so a long-running server's default graph grew without bound.

    Args:
        file_name: Path to a JPEG image on disk.
        height: Output height in pixels. NOTE(review): 299 is presumably the
            input size the retrained graph expects; confirm against the model.
        width: Output width in pixels (same caveat as ``height``).
        input_mean: Value subtracted from every pixel before scaling.
        input_std: Value every pixel is divided by after subtracting the mean.
            The defaults (0, 255) map 8-bit pixels into roughly [0, 1].

    Returns:
        A float32 numpy array of shape ``(1, height, width, 3)``.
    """
    graph = tf.Graph()
    with graph.as_default():
        file_reader = tf.read_file(file_name, "file_reader")
        image_reader = tf.image.decode_jpeg(
            file_reader, channels=3, name="jpeg_reader")
        float_caster = tf.cast(image_reader, tf.float32)
        # Add a leading batch dimension: the classifier expects a batch.
        dims_expander = tf.expand_dims(float_caster, 0)
        resized = tf.image.resize_bilinear(dims_expander, [height, width])
        normalized = tf.divide(tf.subtract(resized, [input_mean]), [input_std])
    with tf.Session(graph=graph) as sess:
        return sess.run(normalized)
@functools.lru_cache(maxsize=None)
def _load_model(graph_path, labels_path):
    """Load (graph, labels) once per path pair; reused across requests."""
    with open(labels_path) as f:
        labels = tuple(line.strip() for line in f)
    return load_graph(graph_path), labels


def classify_image(url, graph_path="output_graph.pb",
                   labels_path="output_labels.txt"):
    """Download the image at *url* and score it against every model label.

    Example (values illustrative)::

        classify_image("http://gajamed.pl/uploads/2016/04/29.jpg")
        # => [("daisy", 0.2), ("rose", 0.9)]

    The graph and label list are loaded from disk on the first call and cached
    afterwards. Previously the full graph protobuf was re-read and re-parsed on
    every request, and the labels file handle was never closed. The downloaded
    temp file is removed once it has been decoded.

    Args:
        url: Image URL, passed to ``download``.
        graph_path: Frozen graph exposing ``import/Placeholder`` (input) and
            ``import/final_result`` (output) ops.
        labels_path: Text file with one label per line. Labels are paired
            positionally with the model's output scores.

    Returns:
        A list of ``(label, score)`` tuples in label-file order. If the label
        count and score count differ, the list is truncated to the shorter one.
    """
    graph, labels = _load_model(graph_path, labels_path)
    input_tensor = graph.get_operation_by_name("import/Placeholder").outputs[0]
    output_tensor = graph.get_operation_by_name("import/final_result").outputs[0]

    filepath = download(url)
    try:
        image = read_image(filepath)
    finally:
        # download() puts the file in its own mkdtemp() directory. Remove the
        # file, then the (now empty) dir. rmdir refuses non-empty dirs, so
        # this can never delete anything else.
        with contextlib.suppress(OSError):
            os.remove(filepath)
            os.rmdir(os.path.dirname(filepath))

    with tf.Session(graph=graph) as sess:
        # [0]: the model returns a batch; read_image feeds a batch of one.
        results = sess.run(output_tensor, {input_tensor: image})[0]
    return list(zip(labels, results))
api = responder.API()


@api.route(None, default=True)
def handler(req, resp):
    """Classify the image whose URL forms the request path; reply in plain text.

    Usage (the server prints its address in the terminal on startup)::

        curl http://localhost:<port>/http://gajamed.pl/uploads/2016/04/29.jpg
        # daisy: 20%
        # rose: 90%

    One ``label: NN%`` line is written per result, in the order
    ``classify_image`` returns them. Confidences are truncated (not rounded)
    to whole percent.

    Error responses (previously these all surfaced as unhandled 500s):
      * 400 when no image URL is given (``GET /``),
      * 502 when the image cannot be downloaded,
      * 422 when TensorFlow cannot decode or process the downloaded image.

    NOTE(review): only the path is read, so a query string on the image URL
    (``...jpg?x=1``) is presumably consumed by the outer request URL and
    dropped. Confirm before relying on such URLs.
    """
    image_url = req.url.path[1:]  # drop the leading "/" of the request path
    if not image_url:
        resp.status_code = 400
        resp.text = "Usage: GET /<image url>\n"
        return

    try:
        predictions = classify_image(image_url)
    except requests.RequestException:
        resp.status_code = 502
        resp.text = "Could not download image\n"
        return
    except tf.errors.InvalidArgumentError:
        resp.status_code = 422
        resp.text = "Could not decode image\n"
        return

    resp.text = "".join(
        "{}: {}%\n".format(label, int(confidence * 100))
        for label, confidence in predictions
    )
def main():
    """Start the responder server for the classifier API.

    Run this script from the Python environment that has TensorFlow installed
    (the original notes activate an Anaconda env first), e.g. ``python
    index.py``. The API is then reachable at the address printed in the
    terminal.
    """
    api.run()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment