Performance improvements with TensorFlow & Algorithmia
# This code example shows how to pre-load a TensorFlow graph file
# into an Algorithmia container and load the graph into memory.
# This approach preserves the graph in system memory between API calls,
# improving overall performance.
# We also show how to evict the TensorFlow GPU memory context between API requests by
# running inference in a separate process, and how to define the amount of GPU memory an algorithm uses.
# These GPU tweaks significantly improve performance on Algorithmia's infrastructure.
import Algorithmia
import multiprocessing
import os
import tarfile
import numpy as np
import tensorflow as tf
from PIL import Image
# label_map_util comes from the TensorFlow object detection API (object_detection/utils).
from object_detection.utils import label_map_util
client = Algorithmia.client()
## We define the graph, category index, and model name in advance, and provide them with default values.
GRAPH, CATEGORY_INDEX, MODEL_NAME = (tf.Graph(), '', 'default')
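## The gist references several module-level names without defining them. The values below are
## illustrative placeholders only, not part of the original gist; adjust them to your own setup.
SIMD_ALGO = 'util/SmartImageDownloader/0.2.x'  # image downloader algorithm used by get_image()
NUM_CLASSES = 90                               # e.g. 90 for COCO-trained detection models
GPU_MEMORY_FRACTION = 0.6                      # fraction of GPU memory this algorithm may use
MAX_BOXES = 20                                 # maximum number of detections to return
MIN_SCORE = 0.5                                # minimum detection confidence to keep
## AlgorithmError is raised below but never defined in the gist; a minimal stand-in is used here.
class AlgorithmError(Exception):
    pass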
# Fetches a remote image via the image downloader algorithm and returns a local file path
# with the original file extension preserved.
def get_image(url):
    output_url = client.algo(SIMD_ALGO).pipe({'image': str(url)}).result['savePath'][0]
    temp_file = client.file(output_url).getFile().name
    renamed_file = temp_file + '.' + output_url.split('.')[-1]
    os.rename(temp_file, renamed_file)
    return renamed_file
# This function assumes that your model is packaged as a tar.gz file; if not, extract your file accordingly.
def download_model(model_name):
    global GRAPH
    global CATEGORY_INDEX
    global MODEL_NAME
    download_base = 'data://some/base/'
    model_file = model_name + '.tar.gz'
    # Path to the frozen detection graph. This is the actual model that is used for the object detection.
    path_to_graph = model_name + '/frozen_inference_graph.pb'
    # List of the strings used to add the correct label to each box.
    path_to_labels = client.file("data://path/to/labels.pbtxt").getFile().name
    print(model_name)
    print(MODEL_NAME)
    if model_name != MODEL_NAME:
        print('model name not the same, reloading...')
        if not os.path.isfile(path_to_graph):
            try:
                local_file = client.file(download_base + model_file).getFile().name
            except Exception:
                raise AlgorithmError("AlgoError3000: invalid model name.")
            tar_file = tarfile.open(local_file)
            for file in tar_file.getmembers():
                file_name = os.path.basename(file.name)
                if 'frozen_inference_graph.pb' in file_name:
                    tar_file.extract(file, os.getcwd())
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.GraphDef()
            with tf.gfile.GFile(path_to_graph, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        label_map = label_map_util.load_labelmap(path_to_labels)
        categories = label_map_util.convert_label_map_to_categories(
            label_map, max_num_classes=NUM_CLASSES, use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        ## Set the global variables so they can be used later without repeating this process.
        GRAPH = detection_graph
        CATEGORY_INDEX = category_index
        MODEL_NAME = model_name
## We don't enable allow_growth, as a well-defined GPU memory use profile helps with scheduling.
## It also offers a slight performance improvement, which is always nice :)
def generate_gpu_config(memory_fraction):
    config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = memory_fraction
    return config
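# The helper below is not part of the original gist, but execute_tensorflow() relies on it.
# It is the standard image-to-array conversion used in the TensorFlow object detection
# tutorial, reproduced here so the snippet is self-contained.
def load_image_into_numpy_array(image):
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)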
# Since this function is only ever used inside a multiprocessing Process, it returns nothing.
# Instead, it mutates the shared result object defined in apply().
def execute_tensorflow(graph, image,
                       category_index, max_boxes, min_score, result):
    with graph.as_default():
        with tf.Session(graph=graph, config=generate_gpu_config(GPU_MEMORY_FRACTION)) as sess:
            image = Image.open(image).convert('RGB')
            image_np = load_image_into_numpy_array(image)
            height, width, _ = image_np.shape
            image_np_expanded = np.expand_dims(image_np, axis=0)
            image_tensor = graph.get_tensor_by_name('image_tensor:0')
            boxes = graph.get_tensor_by_name('detection_boxes:0')
            scores = graph.get_tensor_by_name('detection_scores:0')
            classes = graph.get_tensor_by_name('detection_classes:0')
            num_detections = graph.get_tensor_by_name('num_detections:0')
            (boxes, scores, classes, num_detections) = sess.run(
                [boxes, scores, classes, num_detections],
                feed_dict={image_tensor: image_np_expanded})
    boxes = np.squeeze(boxes)
    classes = np.squeeze(classes).astype(np.int32)
    scores = np.squeeze(scores)
    prepare_output(height, width, boxes, classes, scores, category_index, result)
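# The gist calls prepare_output() but does not include it. The version below is a minimal,
# illustrative stand-in: it converts the normalized box coordinates to pixel coordinates and
# appends one dict per detection to the shared result list, keeping at most MAX_BOXES
# detections that score above MIN_SCORE. The output format of the original helper is unknown.
def prepare_output(height, width, boxes, classes, scores, category_index, result):
    for box, cls, score in zip(boxes[:MAX_BOXES], classes[:MAX_BOXES], scores[:MAX_BOXES]):
        if score < MIN_SCORE:
            continue
        ymin, xmin, ymax, xmax = box
        label = category_index.get(cls, {}).get('name', str(cls))
        result.append({
            'class': label,
            'confidence': float(score),
            'box': {
                'x1': int(xmin * width), 'y1': int(ymin * height),
                'x2': int(xmax * width), 'y2': int(ymax * height)
            }
        })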
def apply(input):
    model_name = "some_default_model_file"
    output = None
    if isinstance(input, str):
        image = get_image(input)
    elif isinstance(input, dict):
        # - process your input fields here - (the field names below are illustrative)
        if 'image' in input:
            image = get_image(input['image'])
        if 'output' in input:
            output = input['output']
        if 'model' in input:
            model_name = input['model']
    download_model(model_name)
    # Don't forget to put all of the completed work into a multiprocessing-friendly structure,
    # like this managed list.
    result = multiprocessing.Manager().list()
    # execute_tensorflow is run in a separate process so that when the job is complete
    # we can kill the GPU context and release its memory.
    p = multiprocessing.Process(target=execute_tensorflow,
                                args=(GRAPH, image,
                                      CATEGORY_INDEX, MAX_BOXES, MIN_SCORE, result))
    p.start()
    p.join()
    box_output = [x for x in result]
    if output:
        im = Image.open(image).convert('RGB')
        # transform_image (defined elsewhere in the full algorithm) is assumed to draw the
        # detected boxes on the image and save it to the requested output location.
        image = transform_image(im, output, box_output)
        return {'data': box_output, 'image': image}
    else:
        return {'data': box_output}


# Pre-load the default model when the container starts, so the first API call doesn't pay
# the model download and graph construction cost.
download_model("some_default_model_file")
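# Example call (illustrative only; the algorithm path, API key, and input fields are
# placeholders, not part of the original gist):
#
#   import Algorithmia
#   client = Algorithmia.client('YOUR_API_KEY')
#   response = client.algo('username/tensorflow_object_detection/0.1.x').pipe(
#       {'image': 'https://example.com/dog.jpg', 'model': 'some_default_model_file'})
#   print(response.result)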