Created
December 3, 2019 23:02
-
-
Save zeryx/5068fc21fb5385d3428c49879910aca3 to your computer and use it in GitHub Desktop.
image classification with recursion example
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
import Algorithmia | |
import os | |
import re | |
import numpy as np | |
# Algorithmia data-API URIs for the InceptionNet demo assets.
TEMP_COLLECTION = "data://.session/"
# Serialized TensorFlow GraphDef of the classifier (parsed in load_model).
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
# One synset uid and its human-readable label per line (parsed by load_label_index).
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
# Text proto pairing integer target_class ids with synset uid strings.
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"
def load_model(client):
    """Fetch the InceptionNet assets and build the classification graph.

    Intended to be called once at import time so the resulting graph and
    label index live in global state and are reused by every request.

    Args:
        client: Algorithmia client used to fetch the model and label files.

    Returns:
        A ``(graph, label_index)`` tuple: a ``tf.Graph`` into which the
        serialized GraphDef from MODEL_FILE has been imported (empty name
        prefix), and the node-id -> label dict built by ``load_label_index``.
    """
    def local_path(data_uri):
        # getFile() yields a local file object; only its filesystem path is needed.
        return client.file(data_uri).getFile().name

    labels_path = local_path(LABEL_FILE)
    model_path = local_path(MODEL_FILE)
    conversion_path = local_path(CONVERSION_FILE)

    graph = tf.Graph()
    with graph.as_default():
        frozen_graph = tf.GraphDef()
        with tf.gfile.GFile(model_path, 'rb') as model_fh:
            frozen_graph.ParseFromString(model_fh.read())
        # name='' keeps the original tensor names (e.g. 'softmax:0').
        tf.import_graph_def(frozen_graph, name='')

    label_index = load_label_index(conversion_path, labels_path)
    return graph, label_index
def load_label_index(conversion_path, label_path):
    """Build a mapping from the model's integer node ids to human-readable labels.

    Args:
        conversion_path: local path to a text file with one
            "<synset uid><TAB><human label>" entry per line.
        label_path: local path to a text-format proto whose entries contain
            a ``target_class: <int>`` line followed by a
            ``target_class_string: "<uid>"`` line.

    Returns:
        dict mapping each ``target_class`` integer to the human label of its uid.

    Raises:
        KeyError: if a uid referenced in ``label_path`` is missing from
            ``conversion_path``.
    """
    uid_to_human = {}
    uid_line = re.compile(r'[n\d]*[ \S,]*')
    with open(conversion_path, encoding='utf-8') as conversion_file:
        # splitlines() keeps the final entry even without a trailing newline
        # (the old split('\n')[:-1] silently dropped it).
        for line in conversion_file.read().splitlines():
            if not line.strip():
                continue  # tolerate blank / trailing lines
            # findall yields [uid, '', human_label, ...]; the empty match sits on the tab.
            parsed_items = uid_line.findall(line)
            uid_to_human[parsed_items[0]] = parsed_items[2]

    node_id_to_uid = {}
    target_class = None
    with open(label_path, encoding='utf-8') as label_file:
        for raw_line in label_file:
            # Strip first so the match does not depend on the exact indentation.
            line = raw_line.strip()
            if line.startswith('target_class:'):
                target_class = int(line.split(':', 1)[1])
            elif line.startswith('target_class_string:'):
                # Value is a quoted uid; strip quotes instead of slicing fixed
                # offsets, which lost a character when the newline was absent.
                node_id_to_uid[target_class] = line.split(':', 1)[1].strip().strip('"')

    # Final mapping of integer node id -> human-readable string.
    node_id_to_name = {}
    for node_id, uid in node_id_to_uid.items():
        if uid not in uid_to_human:
            raise KeyError('Failed to locate: %s' % uid)
        node_id_to_name[node_id] = uid_to_human[uid]
    return node_id_to_name
def id_to_string(index_file, node_id):
    """Return the human-readable label for ``node_id``, or '' if it has none.

    ``index_file`` is the node-id -> label dict produced by load_label_index.
    """
    return index_file.get(node_id, '')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import tensorflow as tf | |
from multiprocessing import Manager, Pool | |
from .auxillary import load_model, id_to_string | |
import Algorithmia | |
import os | |
import re | |
# Scaffolding shared by most TensorFlow image-classification algorithms.
# This example shows how to handle batch processing with algorithm recursion
# and pipelining.

# IMPORTANT: THIS_ALGO refers back to this very algorithm. During development
# the published version may not exist yet, so you may need to replace the
# version below with a specific version hash.
# TODO: We'll improve this experience in the future
THIS_ALGO = "zeryx/recursive_image_example/0.1.x"
# Upper bound on recursive requests open at any one time, so we don't
# overwhelm the development environment.
NUM_PARALLEL_REQUESTS = 10
# Maximum number of items one invocation handles itself before handing the
# remainder off to recursive calls.
MAX_WORK_PER_REQUEST = 8

# Client and model are created once at import time and shared by all requests.
client = Algorithmia.client()
graph, label_index = load_model(client)
def get_image(url):
    """Format and download an image with the Smart Image Downloader algorithm.

    Works for images on the web or other locations the downloader accepts.

    Args:
        url: image location; coerced to ``str`` before being sent.

    Returns:
        Local filesystem path of the downloaded image. When the stored file
        name has an extension, the local copy is renamed to carry it too.
    """
    request = {'image': str(url)}
    output_url = client.algo("util/SmartImageDownloader/0.2.x").pipe(request).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    # Look only at the final path component: splitting the whole URI on '.'
    # picks up dots in directory names (e.g. "data://.session/abc") and builds
    # an invalid rename target when the file name itself has no extension.
    extension = os.path.splitext(output_url)[1]
    if not extension:
        return temp_file
    local_path = temp_file + extension
    os.rename(temp_file, local_path)
    return local_path
def inference(image):
    """Classify a local image file with the globally loaded graph.

    Args:
        image: path to an image file; its raw bytes are fed to the graph's
            'DecodeJpeg/contents:0' input.

    Returns:
        ``{'tags': [...]}`` holding the five highest-scoring classes, best
        first. Each tag is ``{'class': <label or ''>, 'confidence': <float>}``.
    """
    # Use a context manager so the file handle is closed deterministically.
    with tf.gfile.FastGFile(image, 'rb') as image_file:
        image_data = image_file.read()
    with tf.Session(graph=graph) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
    predictions = np.squeeze(predictions)
    # argsort is ascending: take the last five indices, then reverse for best-first.
    top_k = predictions.argsort()[-5:][::-1]
    tags = [
        {
            'class': id_to_string(label_index, node_id),
            # .item() turns the numpy scalar into a JSON-serializable float.
            'confidence': predictions[node_id].item(),
        }
        for node_id in top_k
    ]
    return {'tags': tags}
def algorithm_recursion_(input, errorQ):
    """Asynchronously fan ``input`` out to recursive calls of this algorithm.

    ``input`` is cut into chunks of at most MAX_WORK_PER_REQUEST items. Each
    chunk is handed to ``_algo`` on a ``multiprocessing.Pool`` of
    NUM_PARALLEL_REQUESTS worker processes, which caps how many recursive
    requests are open at once. ``errorQ`` travels with every chunk so workers
    can report exceptions and skip work once one has occurred.

    (A ``pool.starmap()`` call would be the simpler, blocking alternative.)

    Returns:
        The ``AsyncResult`` from ``starmap_async``; its ``get()`` yields one
        ``_algo`` return value per chunk, in chunk order.
    """
    pool = Pool(NUM_PARALLEL_REQUESTS)
    process_data = [(chunk, errorQ) for chunk in _chunks(input, MAX_WORK_PER_REQUEST)]
    async_ops = pool.starmap_async(_algo, process_data)
    # No further tasks will be submitted: let the worker processes exit once
    # the queued chunks finish instead of leaving them alive after this call.
    pool.close()
    return async_ops
def _chunks(l, n): | |
"""Yield successive n-sized chunks from l.""" | |
for i in range(0, len(l), n): | |
yield l[i:i + n] | |
def _algo(algo_data, errorQ): | |
"""The primary working algorithm for our parallel threads, makes parallel requests and checks if errors exist""" | |
try: | |
if errorQ.empty(): | |
print("processing chunk..") | |
response = client.algo(THIS_ALGO).pipe(algo_data).result | |
print("finished chunk..") | |
return response | |
else: | |
return None | |
except Exception as e: | |
errorQ.put(e) | |
def batch_apply(input):
    """Apply ``apply`` to each item in order, sequentially, on this worker.

    Meant for small batches; could be parallelized if necessary.
    """
    return [apply(image) for image in input]
def _classify_image(image_url):
    # Download one image and pair its original reference with the inference output.
    image = get_image(image_url)
    return {"image": image_url, "results": inference(image)}


def _apply_recursively(items):
    # Process the first MAX_WORK_PER_REQUEST items locally while the rest is
    # fanned out to recursive calls; return one flat list in input order.
    input_for_this_worker = items[:MAX_WORK_PER_REQUEST]
    remaining_work = items[MAX_WORK_PER_REQUEST:]
    # The managed queue carries exceptions back from the worker processes.
    # Using the manager as a context manager shuts its server process down
    # afterwards instead of leaking one per call.
    with Manager() as manager:
        error_queue = manager.Queue()
        # Kick off the remote work first so it overlaps with our local share.
        eventual_remote_results = algorithm_recursion_(remaining_work, error_queue)
        local_results = batch_apply(input_for_this_worker)
        concurrent_results = eventual_remote_results.get()
        # Check for reported errors before trusting any results.
        if not error_queue.empty():
            raise error_queue.get()
    # Each recursive call received a list chunk and (assuming THIS_ALGO runs this
    # same code) returned a list of per-image results; splice them in so the
    # output is flat rather than a list with nested per-chunk lists.
    return local_results + [result for chunk in concurrent_results for result in chunk]


def apply(input):
    """Algorithm entry point: classify a single image or a batch of images.

    Accepted inputs:
      * a string: an image reference passed to the Smart Image Downloader;
      * a dict with an "image" key holding such a reference;
      * a list of the above: lists of at most MAX_WORK_PER_REQUEST items are
        processed here sequentially, larger ones are split between this
        worker and recursive calls of this algorithm.

    Returns:
        For a single image, ``{"image": <reference>, "results": <inference output>}``;
        for a list, a flat list of such results in input order.

    Raises:
        Exception: "Input format invalid" for any other input, or the first
            exception reported by a recursive worker.
    """
    if isinstance(input, str):
        return _classify_image(input)
    if isinstance(input, dict) and "image" in input:
        return _classify_image(input["image"])
    if isinstance(input, list):
        # A small batch isn't worth spreading across machines. Use `<=` so a batch
        # of exactly MAX_WORK_PER_REQUEST items doesn't start a manager and a
        # process pool only to recurse on zero remaining items.
        if len(input) <= MAX_WORK_PER_REQUEST:
            return batch_apply(input)
        return _apply_recursively(input)
    raise Exception("Input format invalid")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment