Skip to content

Instantly share code, notes, and snippets.

@zldrobit
Last active October 18, 2018 02:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zldrobit/f9b59298afcd680ef89ee8660b3be54b to your computer and use it in GitHub Desktop.
Save zldrobit/f9b59298afcd680ef89ee8660b3be54b to your computer and use it in GitHub Desktop.
import tensorflow as tf
import argparse
import os
# Convert a frozen TensorFlow 1.x GraphDef (.pb file) into a SavedModel
# directory containing one SERVING-tagged meta graph with a single Predict
# signature mapping 'input' -> 'output'.
parser = argparse.ArgumentParser(description='Generate a saved model.')
parser.add_argument('--export_model_dir', type=str, default='./saved_model/the_model', help='export model directory')
parser.add_argument('--model_version', type=int, default=1, help='model version')
parser.add_argument('--model', type=str, default='the_model.pb', help='model pb file')
parser.add_argument("--input_tensor", default="input:0", help="input tensor", type=str)
parser.add_argument("--output_tensor", default="output:0", help="output tensor", type=str)
parser.add_argument("--signature_name", default="output", type=str,
                    help="key of the prediction signature in the SavedModel "
                         "(use 'serving_default' for clients that do not "
                         "name a signature)")
args = parser.parse_args()

with tf.Session() as sess:
    # Read and deserialize the frozen GraphDef from disk.
    with tf.gfile.GFile(args.model, "rb") as f:
        restored_graph_def = tf.GraphDef()
        restored_graph_def.ParseFromString(f.read())

    # name="" imports the nodes without a name prefix, so the tensor names
    # passed on the command line can be looked up exactly as given.
    tf.import_graph_def(restored_graph_def, name="")

    graph = tf.get_default_graph()
    input_tensor = graph.get_tensor_by_name(args.input_tensor)
    output_tensor = graph.get_tensor_by_name(args.output_tensor)
    print('input tensor shape', input_tensor.shape)

    # Export into <export_model_dir>/<model_version>.
    export_path = os.path.join(
        tf.compat.as_bytes(args.export_model_dir),
        tf.compat.as_bytes(str(args.model_version)))
    # Decode for display so the path is not printed as a b'...' repr.
    print('Exporting trained model to', tf.compat.as_str(export_path))
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Wrap the input/output tensors as TensorInfo protos for the signature.
    tensor_info_input = tf.saved_model.utils.build_tensor_info(input_tensor)
    tensor_info_output = tf.saved_model.utils.build_tensor_info(output_tensor)

    # One input ('input') and one output ('output'), using the Predict API.
    prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'input': tensor_info_input},
        outputs={'output': tensor_info_output},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)

    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={args.signature_name: prediction_signature})

    # Write the SavedModel protobuf in text format (as_text=True).
    builder.save(as_text=True)

print('Done exporting!')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment