Skip to content

Instantly share code, notes, and snippets.

@monklof
Last active June 8, 2022 08:34
Show Gist options
  • Save monklof/bd8c7f230eddd03f13d1ff0959a0c110 to your computer and use it in GitHub Desktop.
Convert a TensorFlow checkpoint to a SavedModel (ResNet V2).
# -*- coding: utf-8 -*-
import os.path
# This is a placeholder for a Google-internal import.
import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import resnet_v2
from preprocessing import vgg_preprocessing as vgg
# Command-line flags (TF1 tf.app.flags API).
# NOTE(review): the default paths mention inception_v4 but this script builds a
# ResNet V2 graph — the defaults look stale; pass explicit values when running.
tf.app.flags.DEFINE_string('checkpoint_dir', '/opt/zhoulinyuan/inception_v4',
"""Directory where to read training checkpoints.""")
tf.app.flags.DEFINE_string('output_dir', '/tmp/inception_v4_porn_output',
"""Directory where to export inference model.""")
tf.app.flags.DEFINE_integer('model_version', 4,
"""Version number of the model.""")
tf.app.flags.DEFINE_integer('image_size', 224,
"""Needs to provide same value as in training.""")
FLAGS = tf.app.flags.FLAGS
# Number of output classes of the classifier head (passed to resnet_v2_50).
NUM_CLASSES = 3
# How many top predictions to return from tf.nn.top_k in the exported signature.
NUM_TOP_CLASSES = 3
def export():
    """Build the ResNet V2-50 inference graph, restore trained weights, and
    export a TensorFlow Serving SavedModel.

    The exported model has two signatures:
      * the default classification signature (serialized tf.Example in,
        top-K class labels and scores out), and
      * a 'predict_images' signature (raw JPEG strings in, classes/scores out).

    Reads FLAGS.checkpoint_dir for weights and writes the SavedModel to
    FLAGS.output_dir/<model_version>.

    Raises:
        ValueError: if FLAGS.checkpoint_dir is a directory containing no
            checkpoint.
    """
    with tf.Graph().as_default():
        # Input transformation: the serving input is a batch of serialized
        # tf.Example protos, each carrying a JPEG under 'image/encoded'.
        serialized_tf_example = tf.placeholder(tf.string, name='tf_example')
        feature_configs = {
            'image/encoded': tf.FixedLenFeature(shape=[], dtype=tf.string),
        }
        tf_example = tf.parse_example(serialized_tf_example, feature_configs)
        jpegs = tf_example['image/encoded']
        images = tf.map_fn(preprocess_image, jpegs, dtype=tf.float32)

        # Run inference: ResNet V2-50 logits -> class probabilities.
        with slim.arg_scope(resnet_v2.resnet_arg_scope()):
            logits, _ = resnet_v2.resnet_v2_50(
                images, NUM_CLASSES, is_training=False)
        probabilities = tf.nn.softmax(logits)

        # Transform output to a top-K result with human-readable labels.
        # NOTE(review): '0_xx'/'1_yy'/'2_zz' look like placeholder label names;
        # replace with the real class descriptions before deploying.
        values, indices = tf.nn.top_k(probabilities, NUM_TOP_CLASSES)
        class_descriptions = ['0_xx', '1_yy', '2_zz']
        class_tensor = tf.constant(class_descriptions)
        table = tf.contrib.lookup.index_to_string_table_from_tensor(class_tensor)
        classes = table.lookup(tf.to_int64(indices))

        saver = tf.train.Saver()
        with tf.Session() as sess:
            # Restore variables from training checkpoints.
            # Bug fix: Saver.restore expects a checkpoint path *prefix*, not a
            # directory. If the flag points at a directory, resolve the most
            # recent checkpoint inside it; a literal prefix still works as-is.
            checkpoint_path = FLAGS.checkpoint_dir
            if tf.gfile.IsDirectory(checkpoint_path):
                checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
                if checkpoint_path is None:
                    raise ValueError(
                        'No checkpoint found in %s' % FLAGS.checkpoint_dir)
            saver.restore(sess, checkpoint_path)

            # Drop the slim end_points collection before export; it is not
            # needed for serving and was explicitly cleared in the original.
            sess.graph.clear_collection('resnet_v2_50/_end_points')

            # Export inference model to <output_dir>/<model_version>.
            output_path = os.path.join(
                tf.compat.as_bytes(FLAGS.output_dir),
                tf.compat.as_bytes(str(FLAGS.model_version)))
            print('Exporting trained model to %s' % output_path)
            builder = tf.saved_model.builder.SavedModelBuilder(output_path)

            # Build the signature_def_map.
            classify_inputs_tensor_info = tf.saved_model.utils.build_tensor_info(
                serialized_tf_example)
            classes_output_tensor_info = tf.saved_model.utils.build_tensor_info(
                classes)
            scores_output_tensor_info = tf.saved_model.utils.build_tensor_info(
                values)
            classification_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={
                        tf.saved_model.signature_constants.CLASSIFY_INPUTS:
                            classify_inputs_tensor_info
                    },
                    outputs={
                        tf.saved_model.signature_constants.CLASSIFY_OUTPUT_CLASSES:
                            classes_output_tensor_info,
                        tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
                            scores_output_tensor_info
                    },
                    method_name=tf.saved_model.signature_constants.
                    CLASSIFY_METHOD_NAME))

            predict_inputs_tensor_info = tf.saved_model.utils.build_tensor_info(
                jpegs)
            prediction_signature = (
                tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={'images': predict_inputs_tensor_info},
                    outputs={
                        'classes': classes_output_tensor_info,
                        'scores': scores_output_tensor_info
                    },
                    method_name=tf.saved_model.signature_constants.
                    PREDICT_METHOD_NAME))

            # Initialize the label lookup table when the SavedModel is loaded.
            legacy_init_op = tf.group(
                tf.tables_initializer(), name='legacy_init_op')
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'predict_images':
                        prediction_signature,
                    tf.saved_model.signature_constants.
                    DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                        classification_signature,
                },
                legacy_init_op=legacy_init_op)
            builder.save()
            print('Successfully exported model to %s' % FLAGS.output_dir)
def preprocess_image(image_buffer):
    """Turn a JPEG-encoded string tensor into a preprocessed 3-D float image.

    Applies the VGG-style evaluation pipeline: decode, aspect-preserving
    resize to vgg._RESIZE_SIDE_MIN, central crop to FLAGS.image_size, and
    per-channel mean subtraction.
    """
    # Decode as RGB; height and width are dynamic until decode_jpeg runs.
    decoded = tf.image.decode_jpeg(image_buffer, channels=3)
    resized = vgg._aspect_preserving_resize(decoded, vgg._RESIZE_SIDE_MIN)
    cropped = vgg._central_crop([resized], FLAGS.image_size, FLAGS.image_size)[0]
    # Pin the static shape so downstream ops (the ResNet) see a fixed size.
    cropped.set_shape([FLAGS.image_size, FLAGS.image_size, 3])
    as_float = tf.to_float(cropped)
    return vgg._mean_image_subtraction(
        as_float, [vgg._R_MEAN, vgg._G_MEAN, vgg._B_MEAN])
def main(unused_argv=None):
    """tf.app.run() entry point: build the graph and export the SavedModel."""
    export()
if __name__ == '__main__':
    # Fix: the pasted source lost the indentation under this guard.
    # tf.app.run() parses FLAGS and dispatches to main().
    tf.app.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment