Created
December 3, 2019 22:56
-
-
Save zeryx/88aad1f558cc23e438876091d2c626c3 to your computer and use it in GitHub Desktop.
image classification with asynchronous processing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import Algorithmia | |
from .auxillary import load_model, id_to_string | |
import tensorflow as tf | |
import asyncio | |
import os, re | |
import numpy as np | |
# Global variables here
# The Algorithmia client and the model are created once at import time so that
# repeated apply() calls on a warm container reuse them instead of re-downloading
# and re-parsing the graph on every request.
client = Algorithmia.client()
graph, label_index = load_model(client)
def get_image(url):
    """Uses the Smart Image Downloader algorithm to format and download images
    from the web or other places.

    Returns the local path of the downloaded file, renamed so that it carries
    the same extension as the remote file (the extension is derived from the
    downloader's 'savePath' result).
    """
    # Named `request` rather than `input` to avoid shadowing the builtin.
    request = {'image': str(url)}
    output_url = client.algo("util/SmartImageDownloader/0.2.x").pipe(request).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    # Compute the extension-bearing path once instead of rebuilding it twice.
    local_path = temp_file + '.' + output_url.split('.')[-1]
    os.rename(temp_file, local_path)
    return local_path
def do_work(image):
    """Does some computer vision work and needs a local image file to function.

    Feeds the raw JPEG bytes into the graph's softmax tensor and returns a
    dict of the form {'tags': [{'class': ..., 'confidence': ...}, ...]} for
    the five highest-scoring labels, best first.
    """
    raw_bytes = tf.gfile.FastGFile(image, 'rb').read()
    with tf.Session(graph=graph) as session:
        softmax = session.graph.get_tensor_by_name('softmax:0')
        scores = session.run(softmax, {'DecodeJpeg/contents:0': raw_bytes})
    scores = np.squeeze(scores)
    # argsort is ascending, so take the last five indices and reverse them
    # to get the best predictions first.
    best_five = scores.argsort()[-5:][::-1]
    tags = [
        {'class': id_to_string(label_index, node_id),
         'confidence': scores[node_id].item()}
        for node_id in best_five
    ]
    return {'tags': tags}
# We've added a processor function that gets and processes an image, but it is prefixed with 'async'.
# We did this because, when dealing with batches in image processing algorithms, the bottleneck is
# commonly HTTP — getting the images from a remote resource into your system.
# You can read more about 'asyncio' here: https://docs.python.org/3/library/asyncio.html
# Bear in mind that if you're using a version of Python < 3.5, you'll need to install it as a PyPI package.
async def process_url(url):
    """Fetch one image and classify it; awaitable so that many requests can
    be scheduled concurrently by the caller."""
    local_file = get_image(url)
    return do_work(local_file)
def apply(input):
    """Algorithm entry point.

    Accepts either a single image URL (str) or a list of URLs. Lists are
    fanned out through the event loop so the HTTP downloads overlap; a single
    string keeps the original synchronous path.

    Raises for any other input type.
    """
    loop = asyncio.get_event_loop()
    if isinstance(input, list):
        # Schedule every URL up front, then gather all results at once.
        pending = [asyncio.ensure_future(process_url(url)) for url in input]
        return loop.run_until_complete(asyncio.gather(*pending))
    if isinstance(input, str):
        # Single-image path: no concurrency needed.
        return do_work(get_image(input))
    raise Exception("Invalid input, expecting a list of Urls or a single URL string.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import Algorithmia | |
import os | |
import re | |
import numpy as np | |
# Algorithmia data-API locations for the model artifacts.
TEMP_COLLECTION = 'data://.session/'
# Frozen InceptionNet graph definition (serialized GraphDef).
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
# Presumably maps WordNet synset uids (e.g. "n01440764") to human-readable
# labels — inferred from the filename; verify against the file contents.
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
# Protobuf-text file pairing the model's integer class ids with synset uids.
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"
# Model loading, handled in global state to ensure easy processing | |
def load_model(client):
    """Download the frozen InceptionNet graph and its label files from
    Algorithmia data storage.

    Returns a (tf.Graph, label_index) pair, where label_index maps integer
    node ids to human-readable class names.
    """
    labels_path = client.file(LABEL_FILE).getFile().name
    model_path = client.file(MODEL_FILE).getFile().name
    conversion_path = client.file(CONVERSION_FILE).getFile().name

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(model_path, 'rb') as fid:
            graph_def.ParseFromString(fid.read())
        tf.import_graph_def(graph_def, name='')

    return detection_graph, load_label_index(conversion_path, labels_path)
def load_label_index(conversion_path, label_path):
    """Build a {node_id: human-readable label} mapping for the model.

    conversion_path: text file mapping synset uids to human label strings,
        one entry per line.
    label_path: protobuf-text file pairing the model's integer class ids
        (target_class) with synset uid strings (target_class_string).
    """
    with open(conversion_path) as f:
        # [:-1] drops the empty element produced by the file's final newline.
        proto_as_ascii_lines = f.read().split('\n')[:-1]
    uid_to_human = {}
    # Splits each line into the uid token (e.g. "n01440764") and the label text.
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
        parsed_items = p.findall(line)
        uid = parsed_items[0]
        # NOTE(review): index 2 presumably skips an empty match the pattern
        # emits between uid and label — confirm against the actual file format.
        human_string = parsed_items[2]
        uid_to_human[uid] = human_string
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_path).readlines()
    for line in proto_as_ascii:
        # Relies on each proto entry listing target_class before
        # target_class_string, so target_class is already bound when the
        # string line is reached.
        if line.startswith(' target_class:'):
            target_class = int(line.split(': ')[1])
        if line.startswith(' target_class_string:'):
            target_class_string = line.split(': ')[1]
            # [1:-2] strips the leading quote and the trailing quote + newline.
            node_id_to_uid[target_class] = target_class_string[1:-2]
    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
        if val not in uid_to_human:
            # Logs fatally but still falls through to the lookup below,
            # which would then raise KeyError.
            tf.logging.fatal('Failed to locate: %s', val)
        name = uid_to_human[val]
        node_id_to_name[key] = name
    return node_id_to_name
def id_to_string(index_file, node_id):
    """Translate an integer node id into its human-readable class name.

    Returns '' for unknown ids so callers never need to handle KeyError.
    """
    # dict.get collapses the membership test + second lookup into one step.
    return index_file.get(node_id, '')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment