@zeryx
Created December 3, 2019 22:56
image classification with asynchronous processing
import Algorithmia
from .auxillary import load_model, id_to_string
import tensorflow as tf
import asyncio
import os, re
import numpy as np
# Global variables here
client = Algorithmia.client()
graph, label_index = load_model(client)


def get_image(url):
    """Uses the Smart Image Downloader algorithm to format and download images from the web or other places."""
    input = {'image': str(url)}
    output_url = client.algo("util/SmartImageDownloader/0.2.x").pipe(input).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    # getFile() downloads the data:// object to a local temp file with no extension;
    # rename it so the file keeps the extension of the downloaded path.
    os.rename(temp_file, temp_file + '.' + output_url.split('.')[-1])
    return temp_file + '.' + output_url.split('.')[-1]


def do_work(image):
    """Runs the Inception classification graph on a local image file and returns the top predicted classes."""
    image_data = tf.gfile.FastGFile(image, 'rb').read()
    with tf.Session(graph=graph) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)
    tags = []
    # Take the 5 highest-scoring classes, best first.
    top_k = predictions.argsort()[-5:][::-1]
    for node_id in top_k:
        human_string = id_to_string(label_index, node_id)
        score = predictions[node_id]
        result = {}
        result['class'] = human_string
        result['confidence'] = score.item()
        tags.append(result)
    results = {}
    results['tags'] = tags
    return results
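
# Illustrative only: the dictionary returned by do_work() looks like
#   {'tags': [{'class': 'tabby, tabby cat', 'confidence': 0.83}, ...]}
# (the class name and score here are made-up examples, not real output).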

# We've added a processor function that gets and processes an image, but is prefixed with 'async'.
# We did this because, when running image processing algorithms over a batch, the bottleneck is
# commonly HTTP: getting the images from a remote resource into your system.
# You can read more about 'asyncio' here: https://docs.python.org/3/library/asyncio.html
# Bear in mind that if you're using a version of Python < 3.5, you'll need to install it as a PyPI package.
async def process_url(url):
    image_data = get_image(url)
    result = do_work(image_data)
    return result
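

# A minimal sketch, not part of the original gist: get_image() and do_work() are ordinary blocking
# calls, so awaiting process_url() by itself won't overlap the downloads. One common variant is to
# push the blocking work onto asyncio's default thread pool with run_in_executor.
# 'process_url_threaded' is an illustrative name and is not used elsewhere in this file.
async def process_url_threaded(url):
    loop = asyncio.get_event_loop()
    # Run the blocking download and the TensorFlow inference in worker threads.
    image_data = await loop.run_in_executor(None, get_image, url)
    return await loop.run_in_executor(None, do_work, image_data)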


def apply(input):
    loop = asyncio.get_event_loop()
    # We have a list of inputs that we're going to want to loop over
    if isinstance(input, list):
        future_images = []
        for url in input:
            async_image = asyncio.ensure_future(process_url(url))
            future_images.append(async_image)
        # Now that we have a list of futures, run the event loop until they've all completed
        results = loop.run_until_complete(asyncio.gather(*future_images))
        return results
    elif isinstance(input, str):
        # And if we're only processing one image at a time, keep the old functionality as well
        image_data = get_image(input)
        result = do_work(image_data)
        return result
    else:
        raise Exception("Invalid input, expecting a list of URLs or a single URL string.")
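
# Example calls (illustrative URLs, not from the original gist):
#   apply(["https://example.com/dog.jpg", "https://example.com/cat.png"])
# returns a list of {'tags': [...]} dictionaries, one per URL, while
#   apply("https://example.com/dog.jpg")
# returns a single {'tags': [...]} dictionary.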


# auxillary.py: the helper module imported above as '.auxillary'; it loads the Inception graph
# and builds the node-id -> human-readable-label index.
import tensorflow as tf
import Algorithmia
import os
import re
import numpy as np

TEMP_COLLECTION = 'data://.session/'
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"


# Model loading, handled in global state to ensure easy processing
def load_model(client):
    path_to_labels = client.file(LABEL_FILE).getFile().name
    path_to_model = client.file(MODEL_FILE).getFile().name
    path_to_conversion = client.file(CONVERSION_FILE).getFile().name
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            serialized_graph = fid.read()
            graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(graph_def, name='')
    label_index = load_label_index(path_to_conversion, path_to_labels)
    return detection_graph, label_index
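

# Illustrative only (not verbatim from the data files): the two label files parsed below follow
# the standard formats used by the TensorFlow Inception tutorial, roughly:
#   imagenet_synset_to_human_label_map.txt lines  ->  "n01440764<TAB>tench, Tinca tinca"
#   imagenet_2012_challenge_label_map_proto.pbtxt entries ->
#       entry {
#         target_class: 449
#         target_class_string: "n01440764"
#       }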
def load_label_index(conversion_path, label_path):
    # Map each synset uid (e.g. 'n01440764') to its human-readable label string.
    with open(conversion_path) as f:
        proto_as_ascii_lines = f.read().split('\n')[:-1]
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
        parsed_items = p.findall(line)
        uid = parsed_items[0]
        human_string = parsed_items[2]
        uid_to_human[uid] = human_string

    # Map each integer node id from the graph's output layer to a synset uid.
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_path).readlines()
    for line in proto_as_ascii:
        if line.startswith('  target_class:'):
            target_class = int(line.split(': ')[1])
        if line.startswith('  target_class_string:'):
            target_class_string = line.split(': ')[1]
            node_id_to_uid[target_class] = target_class_string[1:-2]

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
        if val not in uid_to_human:
            tf.logging.fatal('Failed to locate: %s', val)
        name = uid_to_human[val]
        node_id_to_name[key] = name
    return node_id_to_name


def id_to_string(index_file, node_id):
    if node_id not in index_file:
        return ''
    return index_file[node_id]