Skip to content

Instantly share code, notes, and snippets.

@edge0701
Created June 19, 2017 15:29
Show Gist options
  • Save edge0701/dd0550cc46f83b7e0e0fd0c5b23fd392 to your computer and use it in GitHub Desktop.
Save edge0701/dd0550cc46f83b7e0e0fd0c5b23fd392 to your computer and use it in GitHub Desktop.
Export a TensorFlow model
import tensorflow as tf
from model import select_model, get_checkpoint
from tensorflow.python.framework import graph_util
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils
# Side length, in pixels, of the square 3-channel images the exported model
# accepts (main() builds its input placeholder from this value).
RESIZE_FINAL = 227

# Human-readable class names. main() uses len() of the chosen list as the
# number of model outputs; presumably entry i names output unit i.
GENDER_LIST = ['M', 'F']
AGE_LIST = [
    '(0, 2)',
    '(4, 6)',
    '(8, 12)',
    '(15, 20)',
    '(25, 32)',
    '(38, 43)',
    '(48, 53)',
    '(60, 100)',
]
# Command-line flags read by main() through FLAGS. Every flag is a string flag.
# Each entry is (flag name, default value, help text). The entries are
# registered in this order.
_STRING_FLAGS = (
    ('checkpoint', 'checkpoint',
     'Checkpoint basename'),
    ('class_type', 'age',
     'Classification type (age|gender)'),
    ('device_id', '/cpu:0',
     'What processing unit to execute inference on'),
    ('model_dir', '',
     'Model directory (where training data lives)'),
    ('export_dir', '/tmp/tf_exported_model/0',
     'Export directory'),
    ('model_type', 'default',
     'Type of convnet'),
    ('requested_step', '',
     'Within the model directory, a requested step to restore e.g., 9000'),
)

for _flag_name, _flag_default, _flag_help in _STRING_FLAGS:
    tf.app.flags.DEFINE_string(_flag_name, _flag_default, _flag_help)
# Don't leave the loop variables behind as module globals.
del _flag_name, _flag_default, _flag_help

FLAGS = tf.app.flags.FLAGS
def _build_signature_def_map(images, logits, label_list):
    """Return the SignatureDef map to store in the exported SavedModel.

    Args:
        images: the float32 image placeholder the model was built on.
        logits: the model output for ``images``; presumably shaped
            (batch, len(label_list)) -- TODO confirm against model.py.
        label_list: class names; entry i is assumed to name output column i.

    Returns:
        A dict mapping signature keys to SignatureDef protos:
          * ``'inputs'``: Predict signature, input ``'images'``, output
            ``'output'`` (the raw model output). It keeps its original key so
            existing clients keep working.
          * the default serving key: Predict signature, input ``'images'``,
            outputs ``'classes'`` (top-1 class name, string) and ``'scores'``
            (softmax over the model output).
    """
    raw_signature = signature_def_utils.predict_signature_def(
        inputs={'images': images},
        outputs={'output': logits})

    # A Classify signature (classification_signature_def) needs a string
    # tensor of serialized tf.Example protos as `examples` and string class
    # names as `classes`. A float image placeholder and float logits do not
    # fit that contract, so class names and scores are exposed through a
    # Predict signature instead.
    # NOTE(review): assumes model_fn returns unnormalized logits (softmax is
    # applied here) -- confirm against model.py.
    scores = tf.nn.softmax(logits, name='scores')
    class_names = tf.constant(label_list, dtype=tf.string)
    # argmax over the logits directly, so the chosen class does not depend on
    # the softmax assumption above.
    classes = tf.gather(class_names, tf.argmax(logits, 1), name='classes')
    labeled_signature = signature_def_utils.predict_signature_def(
        inputs={'images': images},
        outputs={'classes': classes, 'scores': scores})

    return {
        'inputs': raw_signature,
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: labeled_signature,
    }


def main(argv=None):
    """Restore a trained age/gender model and export it as a SavedModel.

    Builds the network selected by ``--model_type`` on a float32 image
    placeholder of shape [None, RESIZE_FINAL, RESIZE_FINAL, 3]. Restores its
    weights from the checkpoint located in ``--model_dir`` (optionally at
    ``--requested_step``). Writes a SavedModel tagged
    ``tag_constants.SERVING`` to ``--export_dir``, which must not already
    exist.

    Args:
        argv: remaining command-line arguments passed by ``tf.app.run``;
            unused.
    """
    with tf.Session() as sess:
        # Any class_type other than 'age' selects the gender labels.
        label_list = AGE_LIST if FLAGS.class_type == 'age' else GENDER_LIST
        model_fn = select_model(FLAGS.model_type)
        with tf.device(FLAGS.device_id):
            images = tf.placeholder(
                tf.float32, [None, RESIZE_FINAL, RESIZE_FINAL, 3])
            # Trailing positional args are kept from the original call;
            # presumably keep_prob=1 and is_training=False -- TODO confirm.
            logits = model_fn(len(label_list), images, 1, False)

        # An empty --requested_step is passed as None; presumably
        # get_checkpoint then picks the latest checkpoint.
        requested_step = FLAGS.requested_step or None
        model_checkpoint_path, _ = get_checkpoint(
            FLAGS.model_dir, requested_step, FLAGS.checkpoint)
        # The Saver is created after the model so it covers all model variables.
        tf.train.Saver().restore(sess, model_checkpoint_path)

        signature_def_map = _build_signature_def_map(images, logits, label_list)

        builder = saved_model_builder.SavedModelBuilder(FLAGS.export_dir)
        # Kept from the original export. This file creates no lookup tables
        # itself, so this initializer is expected to be a no-op here.
        legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
        builder.add_meta_graph_and_variables(
            sess, [tag_constants.SERVING],
            signature_def_map=signature_def_map,
            legacy_init_op=legacy_init_op)
        builder.save()
if __name__ == '__main__':
    # tf.app.run parses the command-line flags defined above, then calls
    # main() with the leftover arguments. Passing main explicitly matches
    # its default lookup of __main__.main.
    tf.app.run(main=main)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment