Skip to content

Instantly share code, notes, and snippets.

Created February 2, 2018 18:52
Show Gist options
  • Save datlife/2c39a1893e689130c9a18ff14ec452a0 to your computer and use it in GitHub Desktop.
Save datlife/2c39a1893e689130c9a18ff14ec452a0 to your computer and use it in GitHub Desktop.
Export pre-trained TF Object Detection API model to Tensorflow Serving
This script converts a pre-trained TF model into a servable version for TF Serving.
A pre-trained model can be downloaded from the official TF Object Detection API model zoo.
* A directory containing the pre-trained model (downloadable from the model zoo above).
* Edit the three variables `frozen_graph`, `model_name`, and `base_dir` accordingly.
* A TF Servable model
from __future__ import print_function

import os

import tensorflow as tf
# Load frozen graph utils
from tensorflow.python.util import compat
from tensorflow.python.platform import gfile
# TF Libraries to export model into .pb file
from tensorflow.python.client import session
from tensorflow.python.saved_model import signature_constants
from tensorflow.core.protobuf import rewriter_config_pb2
# Graph Transform Tool, used to quantize the frozen model
from tensorflow.tools.graph_transforms import TransformGraph
# Name of the servable. TF Serving clients must send this exact name in their
# prediction requests, so set it carefully.
model_name = 'ssd_inception_v2_coco'

# Root output directory. The servable is written to
# <base_dir>/<model_name>/<version>, i.e. alongside the downloaded model here.
base_dir = './detector'

# Frozen inference graph from the model downloaded from the official
# pretrained model zoo.
frozen_graph = '{}/{}/frozen_inference_graph.pb'.format(base_dir, model_name)
def _main_():
    """Export the frozen detection graph as a TF Serving SavedModel.

    Steps:
      1. Load the frozen ``GraphDef`` stored at ``frozen_graph``.
      2. Rewrite it with the Graph Transform Tool (batch-norm folding followed
         by 8-bit weight quantization).
      3. Save the rewritten graph, with a predict signature, to
         ``<base_dir>/<model_name>/<version>``.

    A frozen graph stores its weights as constants, so the exported
    ``variables/`` directory is expected to be empty; the weights live in
    ``saved_model.pb``.

    ``SavedModelBuilder`` refuses to write into an existing directory, so bump
    ``version`` (or remove the old export) before re-running.
    """
    version = 1
    export_path = os.path.join(base_dir, model_name, str(version))

    # Inference pipeline tensor names (node names, without the ':0' suffix).
    input_name = 'image_tensor'
    output_names = ['detection_boxes', 'detection_classes',
                    'detection_scores', 'num_detections']

    graph_def = load_graph_from_pb(frozen_graph)
    quantized_graph = _quantize_graph(graph_def, input_name, output_names)
    _export_saved_model(quantized_graph, export_path, input_name, output_names)

    print("\n\nModel is ready for TF Serving. (saved at {}/saved_model.pb)".format(export_path))


def _quantize_graph(graph_def, input_name, output_names):
    """Return ``graph_def`` rewritten with folded batch norms and 8-bit weights.

    Args:
      graph_def: Frozen ``tf.GraphDef`` to transform.
      input_name: Name of the graph's input node.
      output_names: List of output node names to preserve.

    Returns:
      The transformed ``tf.GraphDef``.
    """
    # Batch-norm folding must run before quantize_weights: the folds look for
    # float Const weights feeding Conv2D/MatMul, which quantize_weights replaces
    # with quantized constants plus a dequantize op.
    transforms = ["add_default_attributes",
                  "fold_batch_norms", "fold_old_batch_norms",
                  "quantize_weights", "round_weights"]
    # TransformGraph joins `inputs`/`outputs` with ","; a bare string would be
    # split into single characters, so always pass lists.
    return TransformGraph(input_graph_def=graph_def,
                          inputs=[input_name],
                          outputs=list(output_names),
                          transforms=transforms)


def _export_saved_model(graph_def, export_path, input_name, output_names):
    """Save ``graph_def`` to ``export_path`` as a SavedModel for TF Serving.

    One predict signature (input key ``'inputs'``, one output per entry of
    ``output_names``) is registered under both ``'predict_images'`` and the
    default serving signature key.
    """
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')

        # Look tensors up in the graph being saved, so the signature always
        # matches the exported graph's names, dtypes and shapes.
        def _tensor_info(node_name):
            return tf.saved_model.utils.build_tensor_info(
                graph.get_tensor_by_name(node_name + ':0'))

        tensor_info_inputs = {'inputs': _tensor_info(input_name)}
        tensor_info_outputs = {name: _tensor_info(name) for name in output_names}

        detection_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=tensor_info_inputs,
            outputs=tensor_info_outputs,
            method_name=signature_constants.PREDICT_METHOD_NAME)

        # NOTE: a ConfigProto passed to this session would only configure this
        # local export session; it is not stored in the SavedModel. Serving-time
        # options (graph rewrites, XLA) must be configured in TF Serving itself.
        with tf.Session(graph=graph) as sess:
            builder = tf.saved_model.builder.SavedModelBuilder(export_path)
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'predict_images': detection_signature,
                    signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: detection_signature,
                })
            builder.save()
def load_graph_from_pb(model_filename):
    """Read and parse a binary serialized ``GraphDef`` file.

    Args:
      model_filename: Path to a binary ``.pb`` file, e.g. a frozen inference
        graph. Any path ``gfile`` can open is accepted.

    Returns:
      The parsed ``tf.GraphDef``.
    """
    # Parsing a GraphDef is pure protobuf work; no session or graph is needed.
    with gfile.FastGFile(model_filename, 'rb') as f:
        data = compat.as_bytes(f.read())
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(data)
    return graph_def
if __name__ == '__main__':
    # Run the export only when executed as a script, not on import.
    _main_()
Copy link

nairsgithub commented Sep 18, 2018

I am not getting any variables when using this script.

Copy link

This does not create variables.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment