Skip to content

Instantly share code, notes, and snippets.

Created March 13, 2020 19:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zeryx/fc32fb143cd8740189c18046cb4cbc4f to your computer and use it in GitHub Desktop.
Save zeryx/fc32fb143cd8740189c18046cb4cbc4f to your computer and use it in GitHub Desktop.
Example of how to use concurrency to process inputs and work jobs in parallel for batch processes
import numpy as np
import tensorflow as tf
import Algorithmia
import os
import re
from multiprocessing import Queue
from threading import Thread
from time import sleep
# This is code for most tensorflow image classification algorithms
# In this example it's tuned specifically for our open images data example.

# Shared Algorithmia client; used below to fetch data files and to call other algorithms.
client = Algorithmia.client()
# Data URI prefix, presumably for temporary session files; not referenced by the code visible here.
TEMP_COLLECTION = 'data://.session/'
# Algorithm that get_image() pipes image URLs to; its result's 'savePath' is read back.
SIMD_ALGO = "util/SmartImageDownloader/0.2.14"
# Serialized graph definition imported by load_model().
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
# Text file handed to load_label_index() as the uid -> human-readable-name table.
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
# pbtxt file handed to load_label_index() as the node-id -> uid table.
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"
class AlgorithmError(Exception):
    """Error raised when the algorithm cannot satisfy a request.

    The payload passed to the constructor is kept on ``value``; ``str(err)``
    returns ``repr(value)``.
    """

    def __init__(self, value):
        # Forward the payload to Exception so ``args``, pickling and default
        # traceback rendering all carry it as well.
        super().__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value)
## Replace with your own loading functions
def load_model():
    """Fetch the model artefacts and build the TensorFlow graph and label index.

    Downloads ``LABEL_FILE``, ``MODEL_FILE`` and ``CONVERSION_FILE`` through the
    module-level Algorithmia ``client``, imports the serialized graph into a
    fresh ``tf.Graph`` and builds the node-id -> label mapping.

    Returns:
        A ``(detection_graph, label_index)`` tuple, where ``label_index`` is
        the dict produced by ``load_label_index``.
    """
    path_to_labels = client.file(LABEL_FILE).getFile().name
    path_to_model = client.file(MODEL_FILE).getFile().name
    path_to_conversion = client.file(CONVERSION_FILE).getFile().name

    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            serialized_graph = fid.read()
        # The bytes must be parsed into graph_def before it can be imported;
        # importing an empty GraphDef would leave the graph without any ops.
        graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(graph_def, name='')

    label_index = load_label_index(path_to_conversion, path_to_labels)
    return detection_graph, label_index
def load_label_index(conversion_path, label_path):
    """Build the mapping from integer node id to human-readable label.

    Args:
        conversion_path: Local path of the text file whose lines carry a uid
            followed by its human-readable name.
        label_path: Path of the file whose ``target_class`` /
            ``target_class_string`` lines pair an integer node id with a uid.

    Returns:
        dict mapping node id (int) to human-readable label (str).

    Raises:
        KeyError: if a uid found in ``label_path`` has no entry in
            ``conversion_path`` (a fatal-level log line is emitted first).
    """
    with open(conversion_path) as f:
        # [:-1] drops the element that follows the final newline.
        proto_as_ascii_lines = f.read().split('\n')[:-1]

    # uid -> human readable name
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
        # assumes "uid<TAB>name" lines, for which findall yields
        # [uid, '', name, ''] -- TODO confirm against the data file
        parsed_items = p.findall(line)
        uid = parsed_items[0]
        human_string = parsed_items[2]
        uid_to_human[uid] = human_string

    # node id -> uid
    node_id_to_uid = {}
    with tf.gfile.GFile(label_path) as label_file:
        proto_as_ascii = label_file.readlines()
    for line in proto_as_ascii:
        # Match on the stripped line so the check does not depend on how deeply
        # the fields are indented or on the trailing newline.
        stripped = line.strip()
        if stripped.startswith('target_class:'):
            target_class = int(stripped.split(': ')[1])
        elif stripped.startswith('target_class_string:'):
            target_class_string = stripped.split(': ')[1]
            # The uid is wrapped in double quotes in the file.
            node_id_to_uid[target_class] = target_class_string.strip('"')

    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
        if val not in uid_to_human:
            tf.logging.fatal('Failed to locate: %s', val)
        name = uid_to_human[val]
        node_id_to_name[key] = name
    return node_id_to_name
def id_to_string(index_file, node_id):
    """Return the label stored for ``node_id``, or ``''`` when the id is unknown.

    Args:
        index_file: dict mapping node ids to labels.
        node_id: key to look up.
    """
    return index_file.get(node_id, '')
def generate_gpu_config(memory_fraction):
    """Build a ``tf.ConfigProto`` with the per-process GPU memory fraction set.

    Args:
        memory_fraction: Value assigned to
            ``gpu_options.per_process_gpu_memory_fraction``.

    Returns:
        The configured ``tf.ConfigProto``.
    """
    config = tf.ConfigProto()
    # NOTE: gpu_options.allow_growth is left at its default; only the
    # memory fraction is set here.
    config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    return config
def get_image(url):
    """Download the image at ``url`` and return the path of a local copy.

    The URL is piped to ``SIMD_ALGO``; the first entry of the result's
    ``savePath`` is fetched through the data API, and the fetched file is
    renamed so that it carries the same extension as that saved path.

    Args:
        url: Image location; converted with ``str()`` before being sent.

    Returns:
        Local file path ending in ``.<extension of the saved path>``.
    """
    output_url = client.algo(SIMD_ALGO).pipe({'image': str(url)}).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    # Compute the target name once so the rename and the return value agree.
    local_path = temp_file + '.' + output_url.split('.')[-1]
    os.rename(temp_file, local_path)
    return local_path
## Replace this with your own inference loop
def inference(image):
    """Classify one local image file and return its five highest-scoring tags.

    Uses the module-level ``graph`` and ``label_index``.

    Args:
        image: Path of the image file; its raw bytes are fed to the
            ``DecodeJpeg/contents:0`` tensor.

    Returns:
        ``{'tags': [{'class': <label>, 'confidence': <score>}, ...]}`` with the
        entries ordered from highest to lowest score.
    """
    # Read inside a context manager so the file handle is always released.
    with tf.gfile.GFile(image, 'rb') as image_file:
        image_data = image_file.read()

    with tf.Session(graph=graph, config=generate_gpu_config(0.6)) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)

    # Indices of the five largest scores, best first.
    top_k = predictions.argsort()[-5:][::-1]
    tags = []
    for node_id in top_k:
        score = predictions[node_id]
        tags.append({
            'class': id_to_string(label_index, node_id),
            # .item() converts the NumPy scalar to a plain Python number
            # (np.asscalar, its former alias, is gone from current NumPy).
            'confidence': score.item(),
        })
    return {'tags': tags}
### This function checks to see if an error is in the queue, if not we iterate over all urls provided, download them and push to image_queue.
# If we run into any kind of error during downloading, we'll throw an exception into the error_queue, which will terminate both threads and join the process.
def get_image_parallel(urls, image_queue, error_queue):
    """Download every URL in order and hand each local file to the consumer.

    Args:
        urls: Iterable of image URLs.
        image_queue: Queue receiving ``(local_file, url)`` tuples.
        error_queue: Queue shared with the consumer; a non-empty queue means
            some stage has failed and downloading stops.

    Any exception raised while downloading is reported on ``error_queue`` as
    ``(message, exception)`` instead of propagating, after which the function
    returns.
    """
    for url in urls:
        if not error_queue.empty():
            # Another stage already failed; producing more work is pointless.
            return
        try:
            local_file = get_image(url)
            image_queue.put((local_file, url))
        except Exception as e:
            txt = "failed to download file {}".format(url)
            error_queue.put((txt, e))
            # Stop right away rather than relying on the next empty() check.
            return
### This function processes inference jobs unless we run out of work, or we get an error.
# We wait until data is put into the image_queue, and after we get the next item.
# We then run inference on the local image, and return the results in a tuple, along with the original url.
# If a failure happens, we make sure that we put an error message in the error_queue so that we can terminate the thread quickly.
def inference_parallel(batch_size, image_queue, error_queue, output):
    """Run inference on up to ``batch_size`` queued images.

    Args:
        batch_size: Number of images expected on ``image_queue``.
        image_queue: Queue of ``(local_file, url)`` tuples to consume.
        error_queue: Queue shared with the producer; a non-empty queue means
            some stage has failed and processing stops.
        output: List that receives one ``(result, url)`` tuple per image.

    Any exception raised while processing is reported on ``error_queue`` as
    ``(message, exception)`` instead of propagating, after which the function
    returns.
    """
    for _ in range(batch_size):
        # Poll until there is either work to do or a failure to react to.
        while image_queue.empty() and error_queue.empty():
            sleep(0.05)
        if not error_queue.empty():
            return
        # Bound before the try block so the handler can always format it.
        url = None
        try:
            local_file, url = image_queue.get()
            result = inference(local_file)
            output.append((result, url))
        except Exception as e:
            txt = "failed to process image {}".format(url)
            error_queue.put((txt, e))
            return
### Master orchestrator that spins up two threads, one for downloading and one for inference - and launches them both.
# The downloader thread will download all images provided, and for each it'll push to a multithreading queue that will then be read by the inference thread.
# Work will get finished by the inference thread, and then pushed to a regular list object, which is then formatted and returned to the caller.
def process_parallel(input):
    """Download and classify a batch of image URLs on two cooperating threads.

    One thread runs ``get_image_parallel`` and the other runs
    ``inference_parallel``; they communicate through an image queue and a
    shared error queue.

    Args:
        input: List of image URLs.

    Returns:
        A list of ``{"result": <inference result>, "url": <url>}`` dicts.

    Raises:
        AlgorithmError: if either thread reported a failure; the message is
            ``"<description>: <original exception>"``.
    """
    # Both workers are threads in this process. queue.Queue makes a put()
    # visible to empty() immediately, whereas multiprocessing.Queue documents
    # a short delay that could hide an error reported just before join().
    from queue import Queue as ThreadQueue

    batch_size = len(input)
    image_queue = ThreadQueue(batch_size)
    error_queue = ThreadQueue()
    output_list = []
    downloader_t = Thread(target=get_image_parallel,
                          args=(input, image_queue, error_queue))
    inference_t = Thread(target=inference_parallel,
                         args=(batch_size, image_queue, error_queue, output_list))
    downloader_t.start()
    inference_t.start()
    # Wait for both workers so the output and error state are final.
    downloader_t.join()
    inference_t.join()

    # if we find an error, lets make sure that we surface it
    if not error_queue.empty():
        message, exc = error_queue.get()
        result = ": ".join([message, str(exc)])
        raise AlgorithmError(result)

    # otherwise, let's format the output data to the right style, and send it to the user.
    return [{"result": result, "url": url} for result, url in output_list]
## Notice the switch at the bottom, if we aren't presented with a list - the overhead of the queue and thread logic is unnecessary and we can skip that.
def apply(input):
    """Entry point: classify a single image or a batch of images.

    Args:
        input: One of
            * ``str`` -- a single image URL,
            * ``dict`` containing an ``"image"`` key -- a single image URL,
            * ``list`` -- several image URLs, processed via ``process_parallel``.

    Returns:
        The ``inference`` result for a single image, or the list returned by
        ``process_parallel`` for a batch.

    Raises:
        AlgorithmError: if ``input`` matches none of the accepted shapes.
    """
    if isinstance(input, str):
        results = inference(get_image(input))
    elif isinstance(input, dict) and "image" in input:
        results = inference(get_image(input['image']))
    elif isinstance(input, list):
        results = process_parallel(input)
    else:
        raise AlgorithmError("input format invalid")
    return results
# Load the graph and label index once at import time; inference() reads these module-level names.
graph, label_index = load_model()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment