Created
June 19, 2018 23:25
-
-
Save jihyeonRyu/3295d425f0ea5b9036dd7ff67f87f35b to your computer and use it in GitHub Desktop.
Save a TensorFlow model to .ckpt and .pb formats
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
load model | |
""" | |
def load_model(model_dir): | |
start = time.time() | |
model_exp = os.path.expanduser(model_dir) | |
if (os.path.isfile(model_exp)): | |
print('Model filename: %s' % model_exp) | |
with gfile.FastGFile(model_exp, 'rb') as f: | |
graph_def = tf.GraphDef() | |
graph_def.ParseFromString(f.read()) | |
tf.import_graph_def(graph_def, name='') | |
else: | |
print('Model directory: %s' % model_exp) | |
meta_file, ckpt_file = get_model_filenames(model_exp) | |
print('Metagraph file: %s' % meta_file) | |
print('Checkpoint file: %s' % ckpt_file) | |
saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file)) | |
saver.restore(tf.get_default_session(), os.path.join(model_exp, ckpt_file)) | |
end = time.time() | |
print("model loading time: ", end-start) | |
def get_model_filenames(model_dir):
    """Locate the metagraph file and the latest checkpoint in a model directory.

    Args:
        model_dir: directory expected to hold exactly one .meta file and one
            or more checkpoint files.

    Returns:
        (meta_file, ckpt_file): basenames of the metagraph and checkpoint.

    Raises:
        ValueError: if there is not exactly one .meta file, or if no
            checkpoint can be located.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if len(meta_files) == 0:
        raise ValueError('No meta file found in the output directory (%s)' % model_dir)
    elif len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the output directory (%s)' % model_dir)
    meta_file = meta_files[0]
    # Prefer the checkpoint-state record if TensorFlow wrote one.
    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_file = os.path.basename(ckpt.model_checkpoint_path)
        return meta_file, ckpt_file
    # Fall back to scanning filenames for the highest global step.
    # Fixes vs. original: removed a dead assignment that clobbered
    # `meta_files` unused; escaped the '.' before 'ckpt' in the regex; and
    # raise a clear ValueError instead of hitting a NameError on an unbound
    # `ckpt_file` when nothing matches.
    max_step = -1
    ckpt_file = None
    for f in files:
        step_str = re.match(r'(^model-[\w\- ]+\.ckpt-(\d+))', f)
        if step_str is not None and len(step_str.groups()) >= 2:
            step = int(step_str.groups()[1])
            if step > max_step:
                max_step = step
                ckpt_file = step_str.groups()[0]
    if ckpt_file is None:
        raise ValueError('No checkpoint file found in the directory (%s)' % model_dir)
    return meta_file, ckpt_file
""" | |
save ckpt meta file | |
""" | |
def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step): | |
# Save the model checkpoint | |
print('Saving variables') | |
start_time = time.time() | |
checkpoint_path = os.path.join(model_dir, 'model-%s.ckpt' % model_name) | |
saver.save(sess, checkpoint_path, global_step=step, write_meta_graph=False) | |
save_time_variables = time.time() - start_time | |
print('Variables saved in %.2f seconds' % save_time_variables) | |
metagraph_filename = os.path.join(model_dir, 'model-%s.meta' % model_name) | |
save_time_metagraph = 0 | |
if not os.path.exists(metagraph_filename): | |
print('Saving metagraph') | |
start_time = time.time() | |
saver.export_meta_graph(metagraph_filename) | |
save_time_metagraph = time.time() - start_time | |
print('Metagraph saved in %.2f seconds' % save_time_metagraph) | |
summary = tf.Summary() | |
# pylint: disable=maybe-no-member | |
summary.value.add(tag='time/save_variables', simple_value=save_time_variables) | |
summary.value.add(tag='time/save_metagraph', simple_value=save_time_metagraph) | |
summary_writer.add_summary(summary, step) | |
""" | |
save ckpt to pb file | |
""" | |
def save_ckpt_to_pb(sess, model_dir): | |
load_model(model_dir) | |
# Retrieve the protobuf graph definition and fix the batch norm nodes | |
input_graph_def = sess.graph.as_graph_def() | |
# Freeze the graph def | |
output_graph_def = freeze_graph_def(sess, input_graph_def, 'embeddings') | |
output_graph = os.path.join(model_dir, "model.pb") | |
with tf.gfile.GFile(output_graph, 'wb') as f: | |
f.write(output_graph_def.SerializeToString()) | |
print("%d ops in the final graph: %s" % (len(output_graph_def.node), output_graph)) | |
def freeze_graph_def(sess, input_graph_def, output_node_names):
    """Convert graph variables to constants, rewriting batch-norm assign ops.

    Args:
        sess: active session holding the variable values to bake in.
        input_graph_def: GraphDef to freeze (its nodes are mutated in place).
        output_node_names: comma-separated names of the output nodes.

    Returns:
        A frozen GraphDef with whitelisted variables folded into constants.
    """
    # Rewrite stateful batch-norm update ops into stateless equivalents so
    # the frozen (variable-free) graph remains executable.
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Redirect moving-average inputs to their '/read' snapshots.
            # Fixed: `xrange` is Python 2 only and raises NameError on
            # Python 3; use `range`.
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr: del node.attr['use_locking']
        elif node.op == 'AssignAdd':
            node.op = 'Add'
            if 'use_locking' in node.attr: del node.attr['use_locking']

    # Get the list of important nodes.
    # NOTE(review): these prefixes are specific to a FaceNet-style
    # InceptionResnetV1 graph — confirm them for other architectures.
    whitelist_names = []
    for node in input_graph_def.node:
        if (node.name.startswith('InceptionResnetV1') or node.name.startswith('embeddings') or
                node.name.startswith('phase_train') or node.name.startswith('Bottleneck') or
                node.name.startswith('Logits')):
            whitelist_names.append(node.name)

    # Replace all the variables in the graph with constants of the same values
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
    return output_graph_def
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment