@zeryx
Created December 3, 2019 23:02
image classification with recursion example
# auxillary.py -- model and label-map loading helpers, imported by the main
# algorithm file below as `from .auxillary import load_model, id_to_string`.
import tensorflow as tf
import Algorithmia
import os
import re
import numpy as np

TEMP_COLLECTION = 'data://.session/'
MODEL_FILE = "data://zeryx/InceptionNetDemo/classify_image_graph_def.pb"
CONVERSION_FILE = "data://zeryx/InceptionNetDemo/imagenet_synset_to_human_label_map.txt"
LABEL_FILE = "data://zeryx/InceptionNetDemo/imagenet_2012_challenge_label_map_proto.pbtxt"

# Model loading, handled in global state to ensure easy processing
def load_model(client):
    path_to_labels = client.file(LABEL_FILE).getFile().name
    path_to_model = client.file(MODEL_FILE).getFile().name
    path_to_conversion = client.file(CONVERSION_FILE).getFile().name
    detection_graph = tf.Graph()
    with detection_graph.as_default():
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            serialized_graph = fid.read()
            graph_def.ParseFromString(serialized_graph)
            tf.import_graph_def(graph_def, name='')
    label_index = load_label_index(path_to_conversion, path_to_labels)
    return detection_graph, label_index


def load_label_index(conversion_path, label_path):
    # First, map each synset uid (e.g. "n01440764") to its human-readable string.
    with open(conversion_path) as f:
        proto_as_ascii_lines = f.read().split('\n')[:-1]
    uid_to_human = {}
    p = re.compile(r'[n\d]*[ \S,]*')
    for line in proto_as_ascii_lines:
        parsed_items = p.findall(line)
        uid = parsed_items[0]
        human_string = parsed_items[2]
        uid_to_human[uid] = human_string
    # Next, map each integer node ID in the label proto to its synset uid.
    node_id_to_uid = {}
    proto_as_ascii = tf.gfile.GFile(label_path).readlines()
    for line in proto_as_ascii:
        if line.startswith('  target_class:'):
            target_class = int(line.split(': ')[1])
        if line.startswith('  target_class_string:'):
            target_class_string = line.split(': ')[1]
            node_id_to_uid[target_class] = target_class_string[1:-2]
    # Loads the final mapping of integer node ID to human-readable string
    node_id_to_name = {}
    for key, val in node_id_to_uid.items():
        if val not in uid_to_human:
            tf.logging.fatal('Failed to locate: %s', val)
        name = uid_to_human[val]
        node_id_to_name[key] = name
    return node_id_to_name


def id_to_string(index_file, node_id):
    if node_id not in index_file:
        return ''
    return index_file[node_id]
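
# A minimal sketch for exercising this module on its own (assumes valid Algorithmia
# credentials and read access to the data:// files above; the node id below is arbitrary):
if __name__ == '__main__':
    client = Algorithmia.client()
    graph, label_index = load_model(client)
    print(id_to_string(label_index, 1))  # prints the human-readable label for node 1, or ''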
# Main algorithm file for the zeryx/recursive_image_example algorithm.
import numpy as np
import tensorflow as tf
from multiprocessing import Manager, Pool
from .auxillary import load_model, id_to_string
import Algorithmia
import os
import re

# This setup is common to most TensorFlow image classification algorithms.
# In this example we look at solving batch processing problems with algorithm recursion and pipelining.
client = Algorithmia.client()
graph, label_index = load_model(client)
#-- IMPORTANT --# Be aware of the algorithm version you're calling: this algorithm calls
# itself, so while you're doing development you may need to replace this variable with a
# version hash.
# TODO: We'll improve this experience in the future
THIS_ALGO = "zeryx/recursive_image_example/0.1.x"
# The number of recursive requests that may be open at any time; this keeps us from
# overwhelming the development environment by constraining our resources to some reasonable maximum.
NUM_PARALLEL_REQUESTS = 10
# The maximum amount of work each algorithm request will handle before recursing.
MAX_WORK_PER_REQUEST = 8
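# Worked example (illustrative): a request containing 80 image URLs keeps the first
# MAX_WORK_PER_REQUEST = 8 for this worker and splits the remaining 72 into 9 chunks of 8,
# each dispatched as a recursive call to THIS_ALGO, with at most NUM_PARALLEL_REQUESTS = 10
# of those calls in flight at once.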
def get_image(url):
    """Uses the Smart Image Downloader algorithm to format and download images from the web or other places."""
    input = {'image': str(url)}
    output_url = client.algo("util/SmartImageDownloader/0.2.x").pipe(input).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    os.rename(temp_file, temp_file + '.' + output_url.split('.')[-1])
    return temp_file + '.' + output_url.split('.')[-1]
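
# For example (illustrative): get_image("https://example.com/cat.jpg") would return a local
# path such as "/tmp/tmpa1b2c3.jpg"; the rename re-appends the original file extension to
# the temp file that the Data API hands back.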
def inference(image):
    image_data = tf.gfile.FastGFile(image, 'rb').read()
    with tf.Session(graph=graph) as sess:
        softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
        predictions = sess.run(softmax_tensor,
                               {'DecodeJpeg/contents:0': image_data})
        predictions = np.squeeze(predictions)
        tags = []
        top_k = predictions.argsort()[-5:][::-1]
        for node_id in top_k:
            human_string = id_to_string(label_index, node_id)
            score = predictions[node_id]
            result = {}
            result['class'] = human_string
            result['confidence'] = score.item()
            tags.append(result)
    results = {}
    results['tags'] = tags
    return results
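
# The returned structure holds the top-5 ImageNet classes and looks like this
# (class name and confidence are illustrative values):
#   {"tags": [{"class": "Samoyed, Samoyede", "confidence": 0.89}, ...]}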
def algorithm_recursion_(input, errorQ):
    """This function creates a process pool, makes parallel calls to _algo, and returns an
    async result handle that we resolve later with .get().
    As you can see, we limit the pool size to ensure we don't overload anything.
    Besides that, we also blend the errorQ object into each chunk of data that we're passing
    into _algo, which is why we use starmap_async rather than map_async."""
    pool = Pool(NUM_PARALLEL_REQUESTS)
    chunks = _chunks(input, MAX_WORK_PER_REQUEST)
    process_data = [(chunk, errorQ) for chunk in chunks]
    async_ops = pool.starmap_async(_algo, process_data)
    return async_ops
def _chunks(l, n):
    """Yield successive n-sized chunks from l."""
    for i in range(0, len(l), n):
        yield l[i:i + n]
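
# For example: list(_chunks([1, 2, 3, 4, 5], 2)) -> [[1, 2], [3, 4], [5]]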
def _algo(algo_data, errorQ):
    """The primary worker function for our parallel requests: makes a recursive call to this
    algorithm, short-circuiting if another worker has already reported an error."""
    try:
        if errorQ.empty():
            print("processing chunk..")
            response = client.algo(THIS_ALGO).pipe(algo_data).result
            print("finished chunk..")
            return response
        else:
            return None
    except Exception as e:
        # Surface the exception to the parent via the shared queue instead of crashing the worker.
        errorQ.put(e)
def batch_apply(input):
    """Simple sequential small-batch processing; can be made parallel if necessary."""
    results = []
    for image in input:
        results.append(apply(image))
    return results
def apply(input):
    if isinstance(input, str):
        image = get_image(input)
        results = {"image": input, "results": inference(image)}
    elif isinstance(input, dict) and "image" in input:
        image = get_image(input['image'])
        results = {"image": input['image'], "results": inference(image)}
    elif isinstance(input, list):
        # If we have a small list, it doesn't make sense to send each request off to a
        # different machine; it's easier to just process it here.
        if len(input) < MAX_WORK_PER_REQUEST:
            results = batch_apply(input)
        else:
            # Let's take some work for this worker to process before we pass the remainder
            # to our recursively called algorithms.
            input_for_this_worker = input[:MAX_WORK_PER_REQUEST]
            remaining_work = input[MAX_WORK_PER_REQUEST:]
            # This Manager queue lets us pass error messages and exceptions between workers,
            # which can be very useful when things don't go as planned.
            manager = Manager()
            errorQ = manager.Queue()
            # We hand the recursive calls off to a worker pool so that this worker's own
            # batch is processed concurrently with the remote requests.
            eventual_remote_results = algorithm_recursion_(remaining_work, errorQ)
            local_results = batch_apply(input_for_this_worker)
            concurrent_results = eventual_remote_results.get()
            # Make sure to check the error queue before returning a result; if it holds an
            # exception, something went wrong and we should raise it.
            if errorQ.empty():
                results = local_results
                # Flatten the per-chunk remote results so callers always get one flat list.
                for remote_batch in concurrent_results:
                    results = results + remote_batch
            else:
                raise errorQ.get()
    else:
        raise Exception("Input format invalid")
    return results
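
Once published, the algorithm can be exercised from any Algorithmia client. Here is a minimal sketch; the API key and image URLs are placeholders, and the version string should match whatever you actually publish:

import Algorithmia

client = Algorithmia.client("YOUR_API_KEY")
images = ["https://example.com/cat.jpg", "https://example.com/dog.jpg"]
# A list longer than MAX_WORK_PER_REQUEST (8) fans out across recursive calls automatically.
response = client.algo("zeryx/recursive_image_example/0.1.x").pipe(images).result
for entry in response:
    print(entry["image"], entry["results"]["tags"][0])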