Skip to content

Instantly share code, notes, and snippets.

@lgutzwil
Created February 27, 2019 01:09
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save lgutzwil/b04decf444d3c86fe5414956a3654623 to your computer and use it in GitHub Desktop.
Save lgutzwil/b04decf444d3c86fe5414956a3654623 to your computer and use it in GitHub Desktop.
DeepLab Export Code for Blog Post
# This script is based on code originally published at
# https://github.com/tensorflow/models/blob/master/research/deeplab/export_model.py
#
# ORIGINAL VERSION:
#
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tensorflow as tf
from tensorflow.python.client import session
from deeplab import common
from deeplab import input_preprocess
from deeplab import model
# Short aliases for the TF 1.x modules used throughout this script.
slim = tf.contrib.slim
flags = tf.app.flags
FLAGS = flags.FLAGS

# --- Where to read the checkpoint from and where to write the SavedModel ---
flags.DEFINE_string(
    "checkpoint_path",
    None,
    "Checkpoint path",
)
flags.DEFINE_string(
    "export_dir",
    None,
    "Base directory to output Tensorflow SavedModel.",
)
flags.DEFINE_integer(
    "model_version",
    1,
    "Model version number.",
)

# --- Model configuration ---
flags.DEFINE_integer(
    "num_classes",
    21,
    "Number of classes.",
)
flags.DEFINE_multi_integer(
    "crop_size",
    [513, 513],
    "Crop size [height, width].",
)
# Atrous rates depend on the backbone and on output_stride:
#   * `xception_65`: [12, 24, 36] when output_stride = 8,
#                    [6, 12, 18] when output_stride = 16.
#   * `mobilenet_v2`: None.
# The atrous_rates/output_stride used here may differ from those used during
# training/evaluation.
flags.DEFINE_multi_integer(
    "atrous_rates",
    None,
    "Atrous rates for atrous spatial pyramid pooling.",
)
flags.DEFINE_integer(
    "output_stride",
    8,
    "The ratio of input to output spatial resolution.",
)

# --- Inference behaviour ---
# For multi-scale inference, set this to [0.5, 0.75, 1.0, 1.25, 1.5, 1.75].
flags.DEFINE_multi_float(
    "inference_scales",
    [1.0],
    "The scales to resize images for inference.",
)
flags.DEFINE_bool(
    "add_flipped_images",
    False,
    "Add flipped images during inference or not.",
)
def _resize_label(label, label_size):
    """Resize a batch of label maps with nearest-neighbor interpolation.

    Args:
        label: 3-D tensor of labels. A trailing axis is appended before the
            resize (giving [1, height, width, 1] for the caller in this file)
            and removed again afterwards.
        label_size: Target [height, width] handed to
            `tf.image.resize_images`.

    Returns:
        The resized labels with the trailing axis removed, cast to tf.int64.
    """
    # tf.image.resize_images needs a channel axis, so add one temporarily.
    label = tf.expand_dims(label, 3)
    resized_label = tf.image.resize_images(
        label,
        label_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
        align_corners=True,
    )
    return tf.cast(tf.squeeze(resized_label, 3), tf.int64)


def generate_input_and_output_tensors():
    """Build the inference graph and return its input and output tensors.

    The graph is added to the current default graph and performs, in order:
      1. Accept a single uint8 image through a [1, None, None, 3] placeholder.
      2. Bilinearly resize it, keeping the aspect ratio, so that it fits in
         FLAGS.crop_size ([height, width]).
      3. Run `input_preprocess.preprocess_image_and_label` in eval mode.
      4. Predict labels, single-scale when FLAGS.inference_scales == [1.0],
         multi-scale otherwise.
      5. Slice the predictions to the preprocessed image's size and resize
         them (nearest neighbor) back to the original input height/width.

    Reads FLAGS.min_resize_value, FLAGS.max_resize_value, FLAGS.resize_factor,
    FLAGS.model_variant and FLAGS.image_pyramid, which are not defined in this
    script (presumably registered by `deeplab.common` -- confirm).

    Returns:
        A tuple `(input_image, semantic_predictions)`: the uint8 placeholder
        and the int64 label map resized to the placeholder's height/width.
    """
    input_image = tf.placeholder(tf.uint8, [1, None, None, 3])
    original_image_size = tf.shape(input_image)[1:3]

    # Float copies of the dynamic height/width, used for all ratio arithmetic.
    height = tf.cast(original_image_size[0], tf.float64)
    width = tf.cast(original_image_size[1], tf.float64)

    # Squeeze the dimension in axis=0 since preprocess_image_and_label assumes
    # image to be 3-D.
    image = tf.squeeze(input_image, axis=0)

    # Resize the image so that height <= FLAGS.crop_size[0] and
    # width <= FLAGS.crop_size[1]. Divide by the float64 tensors rather than
    # the int32 shape values: int / int32-tensor only yields a float ratio
    # under Python 3 true-division semantics, and an integer ratio would not
    # multiply with the float64 height/width below.
    height_ratio = FLAGS.crop_size[0] / height
    width_ratio = FLAGS.crop_size[1] / width
    resize_ratio = tf.minimum(height_ratio, width_ratio)
    target_height = tf.cast(tf.floor(resize_ratio * height), tf.int32)
    target_width = tf.cast(tf.floor(resize_ratio * width), tf.int32)
    target_size = (target_height, target_width)
    image = tf.image.resize_images(
        image,
        target_size,
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True,
    )

    # Apply the DeepLab eval-time preprocessing. The first return value is
    # only used for its spatial size; the second is what is fed to the model
    # (presumably padded to crop size -- confirm in deeplab.input_preprocess).
    resized_image, image, _ = input_preprocess.preprocess_image_and_label(
        image,
        label=None,
        crop_height=FLAGS.crop_size[0],
        crop_width=FLAGS.crop_size[1],
        min_resize_value=FLAGS.min_resize_value,
        max_resize_value=FLAGS.max_resize_value,
        resize_factor=FLAGS.resize_factor,
        is_training=False,
        model_variant=FLAGS.model_variant,
    )
    resized_image_size = tf.shape(resized_image)[:2]

    # Expand the dimension in axis=0, since the following operations assume
    # the image to be 4-D.
    image = tf.expand_dims(image, 0)

    model_options = common.ModelOptions(
        outputs_to_num_classes={common.OUTPUT_TYPE: FLAGS.num_classes},
        crop_size=FLAGS.crop_size,
        atrous_rates=FLAGS.atrous_rates,
        output_stride=FLAGS.output_stride,
    )

    # Run inference.
    if tuple(FLAGS.inference_scales) == (1.0,):
        tf.logging.info("Exported model performs single-scale inference.")
        predictions = model.predict_labels(
            image,
            model_options=model_options,
            image_pyramid=FLAGS.image_pyramid,
        )
    else:
        tf.logging.info("Exported model performs multi-scale inference.")
        predictions = model.predict_labels_multi_scale(
            image,
            model_options=model_options,
            eval_scales=FLAGS.inference_scales,
            add_flipped_images=FLAGS.add_flipped_images,
        )
    # Cast to float32 before the slice/resize below.
    predictions = tf.cast(predictions[common.OUTPUT_TYPE], tf.float32)

    # Crop the valid regions from the predictions.
    semantic_predictions = tf.slice(
        predictions,
        [0, 0, 0],
        [1, resized_image_size[0], resized_image_size[1]],
    )

    # Resize back the prediction to the original input size.
    semantic_predictions = _resize_label(
        semantic_predictions,
        original_image_size,
    )
    semantic_predictions = tf.identity(semantic_predictions)

    return input_image, semantic_predictions
def main(_):
    """Restore the checkpoint and export the graph as a SavedModel.

    The model is written to "<FLAGS.export_dir>/<FLAGS.model_version>" with
    the SERVING tag and a single predict signature registered under the key
    "detection_signature" (input "inputs", output "segmentation_map").

    Args:
        _: Unused positional argument supplied by `tf.app.run`.
    """
    tf.logging.set_verbosity(tf.logging.INFO)

    export_path = "{}/{}".format(FLAGS.export_dir, FLAGS.model_version)
    tf.logging.info("Prepare to export model to: %s", export_path)

    # Build everything in a fresh graph, with a session bound to that graph.
    with tf.Graph().as_default(), tf.Session() as sess:
        image_placeholder, label_map = generate_input_and_output_tensors()

        # Load the trained weights for the model variables just created.
        checkpoint_loader = tf.train.Saver(tf.model_variables())
        checkpoint_loader.restore(sess, FLAGS.checkpoint_path)

        exporter = tf.saved_model.builder.SavedModelBuilder(export_path)

        # Describe the serving interface: one image in, one label map out.
        signature_inputs = {
            "inputs": tf.saved_model.utils.build_tensor_info(
                image_placeholder
            ),
        }
        signature_outputs = {
            "segmentation_map": tf.saved_model.utils.build_tensor_info(
                label_map
            ),
        }
        predict_method = tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        detection_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs=signature_inputs,
                outputs=signature_outputs,
                method_name=predict_method,
            )
        )

        exporter.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={"detection_signature": detection_signature},
        )
        exporter.save()
if __name__ == "__main__":
    # Mark the flags this script declares as required, then hand control to
    # tf.app.run, which parses the flags and invokes main().
    for _required_flag in ("checkpoint_path", "export_dir", "model_version"):
        flags.mark_flag_as_required(_required_flag)
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment