Created
March 13, 2020 19:33
-
-
Save zeryx/fc32fb143cd8740189c18046cb4cbc4f to your computer and use it in GitHub Desktop.
Example of how to use concurrency to process inputs and work jobs in parallel for batch processes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
algorithmia>=1.0.0,<2.0 | |
six | |
tensorflow-gpu==1.2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
import Algorithmia | |
import os | |
import re | |
from multiprocessing import Queue | |
from threading import Thread | |
from time import sleep | |
# This is code for most tensorflow image classification algorithms | |
# In this example it's tuned specifically for our open images data example. | |
## - ORIGINAL TENSORFLOW IMPLEMENTATION -- ## | |
# Algorithmia client shared by every file download and algorithm call below.
client = Algorithmia.client()
# Session-scoped data collection URI (not referenced by the code visible in this file).
TEMP_COLLECTION = 'data://.session/'
# Algorithm invoked by get_image() to fetch a remote image into hosted storage.
SIMD_ALGO = "util/SmartImageDownloader/0.2.14"
# Serialized GraphDef that load_model() imports into the TensorFlow graph.
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
# Text file that load_label_index() parses into a UID -> human-readable name table.
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
# Protobuf-text file that load_label_index() parses into a node ID -> UID table.
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"
class AlgorithmError(Exception):
    """Error surfaced to the caller of this algorithm.

    The payload passed to the constructor is kept on ``value`` and rendered
    through ``repr`` when the exception is converted to a string.
    """

    def __init__(self, value):
        # Keep the payload so __str__ can render it later.
        self.value = value

    def __str__(self):
        # "{!r}" produces exactly repr(self.value); strings come back quoted.
        return "{!r}".format(self.value)
## Replace with your own loading functions | |
def load_model():
    """Download the model artifacts and build the TensorFlow graph.

    Returns:
        A ``(graph, label_index)`` tuple: the ``tf.Graph`` holding the
        imported GraphDef, and the mapping built by ``load_label_index``.
    """
    # Fetch the three artifacts in the same order as before: labels, model,
    # conversion table.
    labels_path = client.file(LABEL_FILE).getFile().name
    model_path = client.file(MODEL_FILE).getFile().name
    conversion_path = client.file(CONVERSION_FILE).getFile().name

    inception_graph = tf.Graph()
    with inception_graph.as_default():
        with tf.gfile.GFile(model_path, 'rb') as model_file:
            serialized = model_file.read()
        parsed_def = tf.GraphDef()
        parsed_def.ParseFromString(serialized)
        # name='' imports the nodes without a prefix, so tensors keep their
        # original names (e.g. 'softmax:0').
        tf.import_graph_def(parsed_def, name='')

    labels = load_label_index(conversion_path, labels_path)
    return inception_graph, labels
def load_label_index(conversion_path, label_path):
    """Build the mapping from integer node ID to human-readable label.

    Args:
        conversion_path: Local path of the text file that maps UIDs to
            human-readable names, one entry per line.
        label_path: Local path of the protobuf-text file whose
            ``target_class`` / ``target_class_string`` fields map integer
            node IDs to UIDs.

    Returns:
        A dict mapping each integer node ID to its human-readable string.

    Raises:
        KeyError: If a UID found in ``label_path`` has no entry in
            ``conversion_path`` (a fatal message is logged first).
    """
    # UID -> human-readable name.
    entry_pattern = re.compile(r'[n\d]*[ \S,]*')
    uid_to_human = {}
    with open(conversion_path) as conversion_file:
        for raw_line in conversion_file:
            line = raw_line.rstrip('\n')
            # Skip blank lines; iterating the file (instead of dropping the
            # last split element) keeps a final entry that lacks a newline.
            if not line:
                continue
            parsed_items = entry_pattern.findall(line)
            # Item 0 is the UID and item 2 the name; item 1 is presumably the
            # empty match produced at the separator between them.
            uid_to_human[parsed_items[0]] = parsed_items[2]

    # Integer node ID -> UID.  The context manager closes the label file.
    with tf.gfile.GFile(label_path) as label_file:
        proto_lines = label_file.readlines()
    node_id_to_uid = {}
    for line in proto_lines:
        # Match on the field name regardless of how deeply it is indented.
        field = line.lstrip()
        if field.startswith('target_class:'):
            target_class = int(field.split(': ')[1])
        if field.startswith('target_class_string:'):
            # Drop surrounding whitespace and the enclosing double quotes.
            uid = field.split(': ')[1].strip().strip('"')
            node_id_to_uid[target_class] = uid

    # Final mapping of integer node ID to human-readable string.
    node_id_to_name = {}
    for node_id, uid in node_id_to_uid.items():
        if uid not in uid_to_human:
            tf.logging.fatal('Failed to locate: %s', uid)
        node_id_to_name[node_id] = uid_to_human[uid]
    return node_id_to_name
def id_to_string(index_file, node_id):
    """Return the label stored for *node_id* in *index_file*, or '' if absent."""
    return index_file[node_id] if node_id in index_file else ''
def generate_gpu_config(memory_fraction):
    """Return a ``tf.ConfigProto`` with the per-process GPU memory fraction set.

    Args:
        memory_fraction: Value assigned to
            ``gpu_options.per_process_gpu_memory_fraction``.
    """
    gpu_config = tf.ConfigProto()
    # allow_growth is left at its default; the alternative kept for reference:
    # gpu_config.gpu_options.allow_growth = True
    gpu_config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    return gpu_config
def get_image(url):
    """Fetch *url* through the image-downloader algorithm into a local file.

    The downloaded temp file is renamed so that it carries the same extension
    as the path reported by the downloader.

    Returns:
        The path of the renamed local file.
    """
    response = client.algo(SIMD_ALGO).pipe({'image': str(url)})
    remote_path = response.result['savePath'][0]
    local_path = client.file(remote_path).getFile().name
    # Everything after the last '.' in the remote path is the extension.
    extension = remote_path.split('.')[-1]
    renamed_path = local_path + '.' + extension
    os.rename(local_path, renamed_path)
    return renamed_path
## Replace this with your own inference loop | |
def inference(image):
    """Classify a local image file and return its five highest-scoring labels.

    Args:
        image: Path of the image file to classify.

    Returns:
        ``{'tags': [...]}`` where each tag is
        ``{'class': <label string>, 'confidence': <Python scalar>}``, ordered
        from highest to lowest score.
    """
    # Read the raw bytes and close the handle right away.
    with tf.gfile.FastGFile(image, 'rb') as image_file:
        image_data = image_file.read()

    with tf.Session(graph=graph, config=generate_gpu_config(0.6)) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # argsort is ascending: take the last five indices, highest first.
    top_k = predictions.argsort()[-5:][::-1]
    tags = [
        {
            'class': id_to_string(label_index, node_id),
            # .item() converts the NumPy scalar to a native Python number;
            # it replaces np.asscalar, which newer NumPy no longer provides.
            'confidence': predictions[node_id].item(),
        }
        for node_id in top_k
    ]
    return {'tags': tags}
### - END OF ORIGINAL TENSORFLOW IMPLEMENTATION -- ### | |
### - PARALLEL PROCESSING - ### | |
### This function checks to see if an error is in the queue, if not we iterate over all urls provided, download them and push to image_queue. | |
# If we run into any kind of error during downloading, we'll throw an exception into the error_queue, which will terminate both threads and join the process. | |
def get_image_parallel(urls, image_queue, error_queue):
    """Download each URL in turn and hand the local file to the inference worker.

    Args:
        urls: Iterable of image URLs to download.
        image_queue: Queue receiving ``(local_file, url)`` tuples.
        error_queue: Queue receiving ``(message, exception)`` tuples; a
            non-empty error queue stops further downloads.
    """
    for url in urls:
        # Stop as soon as any worker has reported a failure.
        if not error_queue.empty():
            return
        try:
            downloaded = get_image(url)
            image_queue.put((downloaded, url))
        except Exception as exc:
            message = "failed to download file {}".format(url)
            error_queue.put((message, exc))
### This function processes inference jobs unless we run out of work, or we get an error. | |
# We wait until data is put into the image_queue, and then we get the next item.
# We then run inference on the local image, and return the results in a tuple, along with the original url. | |
# If a failure happens, we make sure that we put an error message in the error_queue so that we can terminate the thread quickly. | |
def inference_parallel(batch_size, image_queue, error_queue, output):
    """Run inference on up to *batch_size* downloaded images.

    Args:
        batch_size: Maximum number of images to take from ``image_queue``.
        image_queue: Queue of ``(local_file, url)`` tuples to process.
        error_queue: Queue of ``(message, exception)`` tuples; any entry in
            it stops this worker, and this worker adds one if inference fails.
        output: List that receives a ``(result, url)`` tuple per processed image.
    """
    for _ in range(batch_size):
        # Poll until there is an image to process or an error was reported.
        while image_queue.empty() and error_queue.empty():
            sleep(0.25)
        # empty() must be *called*: the bare method object is always truthy,
        # which made this check a no-op and left the worker blocked on get()
        # whenever a download failed.
        if not error_queue.empty():
            return
        local_file, url = image_queue.get()
        try:
            result = inference(local_file)
            output.append((result, url))
        except Exception as e:
            txt = "failed to process image {}".format(url)
            error_queue.put((txt, e))
            # Stop here instead of relying on the next iteration's check.
            return
### Master orchestrator that spins up two threads, one for downloading and one for inference - and launches them both.
# The downloader thread will download all images provided, and for each it'll push to a multithreading queue that will then be read by the inference thread
# Work will get finished by the work thread, and then pushed to a regular list object, which is formatted and returned to the caller once both threads have joined.
def process_parallel(input):
    """Download and classify a batch of image URLs using two worker threads.

    Args:
        input: Sized collection of image URLs.

    Returns:
        A list of ``{"result": ..., "url": ...}`` dicts, one per processed image.

    Raises:
        AlgorithmError: If either worker reported a failure.
    """
    batch_size = len(input)
    image_queue = Queue(batch_size)
    error_queue = Queue()
    collected = []

    # One thread downloads, the other runs inference on whatever is ready.
    workers = [
        Thread(target=get_image_parallel,
               args=(input, image_queue, error_queue)),
        Thread(target=inference_parallel,
               args=(batch_size, image_queue, error_queue, collected)),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    # Surface the first reported error, if any.
    if not error_queue.empty():
        message, exc = error_queue.get()
        raise AlgorithmError(": ".join([message, str(exc)]))

    # Otherwise shape each (result, url) pair into the response format.
    return [{"result": result, "url": url} for result, url in collected]
## Notice the switch at the bottom, if we aren't presented with a list - the overhead of the queue and thread logic is unnecessary and we can skip that. | |
def apply(input):
    """Entry point: classify one image or a batch of images.

    Args:
        input: An image URL string, a dict with an ``"image"`` key, or a
            list of image URLs.

    Returns:
        The ``inference`` result for a single image, or the
        ``process_parallel`` result list for a batch.

    Raises:
        AlgorithmError: If ``input`` matches none of the accepted shapes.
    """
    # Single images skip the queue/thread machinery entirely.
    if isinstance(input, str):
        return inference(get_image(input))
    if isinstance(input, dict) and "image" in input:
        return inference(get_image(input['image']))
    if isinstance(input, list):
        return process_parallel(input)
    raise AlgorithmError("input format invalid")
# Load the model once at import time; inference() reads `graph` and
# `label_index` as module-level globals.
graph, label_index = load_model()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment