Skip to content

Instantly share code, notes, and snippets.

@hisashi-komine
Created February 23, 2018 07:04
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save hisashi-komine/dcc4d3e0271c729f1450af0b27f81cb2 to your computer and use it in GitHub Desktop.
Save hisashi-komine/dcc4d3e0271c729f1450af0b27f81cb2 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python
import argparse
import datetime
import os
import yaml
def main(args):
    """Export a saved Keras model as a frozen TensorFlow graph.

    Files written to ``args.output_dir`` (created if missing):

    * ``model.pbtxt``   -- text GraphDef of the session graph (pre-freeze),
    * ``ckpt*``         -- checkpoint of the session variables,
    * ``interface.yml`` -- names of the model's input and output tensors
      (a single string per key for single-tensor models, a list otherwise),
    * ``model.pb``      -- frozen graph with the checkpointed variables
      folded in as constants.

    Args:
        args: namespace with ``model_file`` (path passed to
            ``keras.models.load_model``) and ``output_dir``.
    """
    # Imported here rather than at module level so argument parsing
    # (including ``--help``) completes before TensorFlow/Keras are imported.
    import tensorflow as tf
    from keras import backend as K
    from keras.models import load_model
    from tensorflow.python.framework import graph_io
    from tensorflow.python.tools import freeze_graph

    def tensor_names(tensors):
        """Return the lone tensor's name, or a list of names if several."""
        names = [tensor.name for tensor in tensors]
        return names[0] if len(names) == 1 else names

    os.makedirs(args.output_dir, exist_ok=True)
    input_graph_name = 'model.pbtxt'
    input_graph_path = os.path.join(args.output_dir, input_graph_name)
    output_graph_path = os.path.join(args.output_dir, 'model.pb')
    checkpoint_prefix = os.path.join(args.output_dir, 'ckpt')
    interface_path = os.path.join(args.output_dir, 'interface.yml')

    # Fix the learning phase to inference *before* the model is built.
    # Otherwise layers with train/test behaviour (e.g. Dropout,
    # BatchNormalization) keep a dependency on the ``keras_learning_phase``
    # placeholder, which survives freezing and would have to be fed when
    # running the exported graph.
    K.set_learning_phase(0)
    model = load_model(args.model_file)

    # Save graph def (text format) & variables of the session Keras uses.
    sess = K.get_session()
    graph_io.write_graph(sess.graph, args.output_dir, input_graph_name,
                         as_text=True)
    checkpoint_path = tf.train.Saver().save(sess, checkpoint_prefix)

    # Save tensor names so consumers of model.pb know what to feed/fetch.
    with open(interface_path, 'w', encoding='utf-8') as f:
        yaml.dump(
            {'input': tensor_names(model.inputs),
             'output': tensor_names(model.outputs)},
            f,
            default_flow_style=False,
        )

    # Freeze graph to pb file. freeze_graph wants op names (tensor name
    # without the ':<index>' suffix), comma-separated for several outputs.
    # input_binary=False matches the text-format graph written above.
    freeze_graph.freeze_graph(
        input_graph=input_graph_path,
        input_saver=None,
        input_binary=False,
        input_checkpoint=checkpoint_path,
        output_node_names=','.join(t.op.name for t in model.outputs),
        restore_op_name=None,
        filename_tensor_name=None,
        output_graph=output_graph_path,
        clear_devices=True,
        initializer_nodes=None,
    )
def parse_args(argv=None):
    """Parse command-line options for the Keras-to-frozen-graph exporter.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``model_file``, ``output_dir`` and
        ``output_suffix``. An explicit ``--output-dir`` is used as-is;
        otherwise the output directory is ``export.<output_suffix>`` next to
        this script, where ``output_suffix`` defaults to the current local
        time formatted as ``YYYYmmddHHMMSS``.
    """
    suffix = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
    script_dir = os.path.dirname(__file__)

    parser = argparse.ArgumentParser(
        description='Export a Keras model as a frozen TensorFlow graph.'
    )
    parser.add_argument(
        '-m', '--model-file',
        type=str,
        default=os.path.join(script_dir, 'model.h5'),
        help='Keras model file to export',
    )
    parser.add_argument(
        '-o', '--output-dir',
        type=str,
        # Resolved after parsing so that --output-suffix can influence it.
        default=None,
        help='Path of output directory '
             '(default: export.<output-suffix> next to this script)',
    )
    parser.add_argument(
        '--output-suffix',
        type=str,
        default=suffix,
        help='Suffix of output directory, used when --output-dir is not '
             'given (default: current local timestamp)',
    )

    args = parser.parse_args(argv)
    if args.output_dir is None:
        args.output_dir = os.path.join(
            script_dir, 'export.{}'.format(args.output_suffix)
        )
    return args


if __name__ == "__main__":
    main(parse_args())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment