@zeryx
Created March 13, 2020 19:33
Example of how to use concurrency to process inputs and work jobs in parallel for batch processes
algorithmia>=1.0.0,<2.0
six
tensorflow-gpu==1.2.0
import numpy as np
import tensorflow as tf
import Algorithmia
import os
import re
from multiprocessing import Queue
from threading import Thread
from time import sleep
# This is code for most tensorflow image classification algorithms
# In this example it's tuned specifically for our open images data example.
## - ORIGINAL TENSORFLOW IMPLEMENTATION -- ##
client = Algorithmia.client()
TEMP_COLLECTION = 'data://.session/'
SIMD_ALGO = "util/SmartImageDownloader/0.2.14"
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"
class AlgorithmError(Exception):
    def __init__(self, value):
        self.value = value

    def __str__(self):
        return repr(self.value)
## Replace with your own loading functions
def load_model():
    path_to_labels = client.file(LABEL_FILE).getFile().name
    path_to_model = client.file(MODEL_FILE).getFile().name
    path_to_conversion = client.file(CONVERSION_FILE).getFile().name
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            serialized_graph = fid.read()
            graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(graph_def, name='')
    label_index = load_label_index(path_to_conversion, path_to_labels)
    return detection_graph, label_index
def load_label_index(conversion_path, label_path):
    # Map UID strings (e.g. n01440764) to human-readable labels
    with open(conversion_path) as f:
        proto_as_ascii_lines = f.read().split('\n')[:-1]
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
        parsed_items = p.findall(line)
        uid = parsed_items[0]
        human_string = parsed_items[2]
        uid_to_human[uid] = human_string
    # Map integer node IDs to UID strings from the label map proto
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_path).readlines()
    for line in proto_as_ascii:
        if line.startswith('  target_class:'):
            target_class = int(line.split(': ')[1])
        if line.startswith('  target_class_string:'):
            target_class_string = line.split(': ')[1]
            node_id_to_uid[target_class] = target_class_string[1:-2]
    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
        if val not in uid_to_human:
            tf.logging.fatal('Failed to locate: %s', val)
        name = uid_to_human[val]
        node_id_to_name[key] = name
    return node_id_to_name
def id_to_string(index_file, node_id):
    if node_id not in index_file:
        return ''
    return index_file[node_id]
def generate_gpu_config(memory_fraction):
    config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    return config
def get_image(url):
    # Download the image via the SmartImageDownloader algorithm, then rename the
    # local temp file so it keeps the original file extension.
    output_url = client.algo(SIMD_ALGO).pipe({'image': str(url)}).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    os.rename(temp_file, temp_file + '.' + output_url.split('.')[-1])
    return temp_file + '.' + output_url.split('.')[-1]
## Replace this with your own inference loop
def inference(image):
    image_data = tf.gfile.FastGFile(image, 'rb').read()
    with tf.Session(graph=graph, config=generate_gpu_config(0.6)) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)
    tags = []
    top_k = predictions.argsort()[-5:][::-1]
    for node_id in top_k:
        human_string = id_to_string(label_index, node_id)
        score = predictions[node_id]
        result = {}
        result['class'] = human_string
        result['confidence'] = np.asscalar(score)
        tags.append(result)
    results = {}
    results['tags'] = tags
    return results
### - END OF ORIGINAL TENSORFLOW IMPLEMENTATION -- ###
### - PARALLEL PROCESSING - ###
### This function first checks whether an error is already in the queue; if not, it iterates over all of the urls provided, downloads each one, and pushes it to image_queue.
# If any download fails, we put the exception into error_queue, which signals both threads to stop so the process can join and surface the error.
def get_image_parallel(urls, image_queue, error_queue):
    for url in urls:
        if not error_queue.empty():
            break
        else:
            try:
                local_file = get_image(url)
                image_queue.put((local_file, url))
            except Exception as e:
                txt = "failed to download file {}".format(url)
                error_queue.put((txt, e))
### This function processes inference jobs until we run out of work or hit an error.
# We wait until data is put into image_queue, then pop the next item.
# We run inference on the local image and append the result, along with the original url, to the output list as a tuple.
# If a failure happens, we put an error message into error_queue so that both threads can terminate quickly.
def inference_parallel(batch_size, image_queue, error_queue, output):
    for _ in range(batch_size):
        while image_queue.empty() and error_queue.empty():
            sleep(0.25)
        if not error_queue.empty():
            break
        else:
            local_file, url = image_queue.get()
            try:
                result = inference(local_file)
                output.append((result, url))
            except Exception as e:
                txt = "failed to process image {}".format(url)
                error_queue.put((txt, e))
### Master orchestrator that spins up two threads, one for downloading and one for inference - and launches them both.
# The downloader thread downloads all images provided and pushes each one onto a multithreading queue, which is read by the inference thread.
# The inference thread works through the queue and appends each result to a regular list object, which the orchestrator formats and returns once both threads have joined.
def process_parallel(input):
    batch_size = len(input)
    image_queue = Queue(batch_size)
    error_queue = Queue()
    output_list = []
    downloader_t = Thread(target=get_image_parallel, args=(input, image_queue, error_queue))
    inference_t = Thread(target=inference_parallel, args=(batch_size, image_queue, error_queue, output_list))
    downloader_t.start()
    inference_t.start()
    downloader_t.join()
    inference_t.join()
    # if we find an error, let's make sure that we surface it
    if not error_queue.empty():
        message, exc = error_queue.get()
        result = ": ".join([message, str(exc)])
        raise AlgorithmError(result)
    # otherwise, let's format the output data into the right style and send it to the user
    for i in range(len(output_list)):
        output_list[i] = {"result": output_list[i][0], "url": output_list[i][1]}
    return output_list
## Notice the dispatch at the bottom: if we aren't given a list, the overhead of the queue and thread logic is unnecessary, so we skip it.
def apply(input):
    if isinstance(input, str):
        image = get_image(input)
        results = inference(image)
    elif isinstance(input, dict) and "image" in input:
        image = get_image(input['image'])
        results = inference(image)
    elif isinstance(input, list):
        results = process_parallel(input)
    else:
        raise AlgorithmError("input format invalid")
    return results
graph, label_index = load_model()
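## Minimal usage sketch (not from the original gist): how the apply() entry point could be
## exercised locally, assuming an Algorithmia API key is configured for the client and the
## image URLs below are placeholders. A single string or {"image": ...} dict takes the serial
## path; a list of urls takes the threaded download + inference path in process_parallel().
if __name__ == "__main__":
    # serial path: one url, downloaded and classified inline
    print(apply("https://example.com/images/cat.jpg"))
    # parallel path: downloads overlap with inference via the queue/thread pair
    print(apply([
        "https://example.com/images/cat.jpg",
        "https://example.com/images/dog.jpg",
    ]))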