Skip to content

Instantly share code, notes, and snippets.

@jihyeonRyu
Created June 19, 2018 23:25
Show Gist options
  • Save jihyeonRyu/3295d425f0ea5b9036dd7ff67f87f35b to your computer and use it in GitHub Desktop.
Save jihyeonRyu/3295d425f0ea5b9036dd7ff67f87f35b to your computer and use it in GitHub Desktop.
Save a TensorFlow model to .ckpt and .pb files
# load model
def load_model(model_dir, session=None):
    """Load a model from a frozen .pb file or from a checkpoint directory.

    If ``model_dir`` (after ``~`` expansion) is a file, it is parsed as a
    serialized ``GraphDef`` and imported into the default graph without a name
    prefix. Otherwise it is treated as a checkpoint directory: the single
    ``.meta`` file found by ``get_model_filenames`` is imported into the
    default graph and the checkpoint variables are restored into ``session``.
    The elapsed loading time is printed.

    Args:
        model_dir: Path to a frozen ``.pb`` file or a checkpoint directory.
        session: Session to restore checkpoint variables into. Defaults to
            ``tf.get_default_session()`` (the previous behavior). Unused when
            a ``.pb`` file is loaded.

    Raises:
        ValueError: If a checkpoint directory is given and no session is
            available, or if ``get_model_filenames`` rejects the directory.
    """
    start = time.time()
    model_exp = os.path.expanduser(model_dir)
    if os.path.isfile(model_exp):
        print('Model filename: %s' % model_exp)
        # tf.gfile.GFile replaces the deprecated gfile.FastGFile alias and
        # avoids needing a separate `gfile` import.
        with tf.gfile.GFile(model_exp, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # name='' keeps the original node names (no "import/" prefix).
        tf.import_graph_def(graph_def, name='')
    else:
        if session is None:
            session = tf.get_default_session()
        if session is None:
            # Fail early with a clear message instead of an obscure error
            # from Saver.restore(None, ...).
            raise ValueError('No session to restore the checkpoint into: pass '
                             '`session` or call load_model inside '
                             '`with sess.as_default():`')
        print('Model directory: %s' % model_exp)
        meta_file, ckpt_file = get_model_filenames(model_exp)
        print('Metagraph file: %s' % meta_file)
        print('Checkpoint file: %s' % ckpt_file)
        saver = tf.train.import_meta_graph(os.path.join(model_exp, meta_file))
        saver.restore(session, os.path.join(model_exp, ckpt_file))
    end = time.time()
    print("model loading time: ", end - start)
def get_model_filenames(model_dir):
    """Locate the meta-graph file and the newest checkpoint prefix in ``model_dir``.

    The directory must contain exactly one ``.meta`` file. The checkpoint is
    taken from TensorFlow's checkpoint state when one is available; otherwise
    the directory listing is scanned for names matching
    ``model-<name>.ckpt-<step>`` and the prefix with the largest step wins.

    Args:
        model_dir: Directory to search.

    Returns:
        ``(meta_file, ckpt_file)``: file names relative to ``model_dir``.
        ``ckpt_file`` is a checkpoint prefix such as ``model-foo.ckpt-1000``.

    Raises:
        ValueError: If there is not exactly one ``.meta`` file, or if no
            checkpoint can be found.
    """
    files = os.listdir(model_dir)
    meta_files = [s for s in files if s.endswith('.meta')]
    if not meta_files:
        raise ValueError('No meta file found in the output directory (%s)' % model_dir)
    elif len(meta_files) > 1:
        raise ValueError('There should not be more than one meta file in the output directory (%s)' % model_dir)
    meta_file = meta_files[0]

    ckpt = tf.train.get_checkpoint_state(model_dir)
    if ckpt and ckpt.model_checkpoint_path:
        return meta_file, os.path.basename(ckpt.model_checkpoint_path)

    # No usable checkpoint state: pick the highest-step prefix by file name.
    # The dot before "ckpt" is escaped so it only matches a literal '.'.
    step_pattern = re.compile(r'^(model-[\w\- ]+\.ckpt-(\d+))')
    ckpt_file = None
    max_step = -1
    for f in files:
        match = step_pattern.match(f)
        if match is None:
            continue
        step = int(match.group(2))
        if step > max_step:
            max_step = step
            ckpt_file = match.group(1)
    if ckpt_file is None:
        # Previously this path ended in an UnboundLocalError.
        raise ValueError('No checkpoint file found in the output directory (%s)' % model_dir)
    return meta_file, ckpt_file
# save ckpt meta file
def save_variables_and_metagraph(sess, saver, summary_writer, model_dir, model_name, step):
    """Checkpoint the variables and, if not yet present, export the meta graph.

    Variables are saved under ``model-<model_name>.ckpt`` with ``step`` as the
    global step and without a meta graph. The meta graph is written to
    ``model-<model_name>.meta`` only when that file does not exist yet. The
    time taken by each save is logged to ``summary_writer`` at ``step`` under
    ``time/save_variables`` and ``time/save_metagraph``. The latter is 0 when
    the export was skipped.
    """
    print('Saving variables')
    t0 = time.time()
    saver.save(sess, os.path.join(model_dir, 'model-%s.ckpt' % model_name),
               global_step=step, write_meta_graph=False)
    variables_seconds = time.time() - t0
    print('Variables saved in %.2f seconds' % variables_seconds)

    meta_path = os.path.join(model_dir, 'model-%s.meta' % model_name)
    metagraph_seconds = 0
    if not os.path.exists(meta_path):
        # The graph structure does not change between steps, so it is
        # exported only once.
        print('Saving metagraph')
        t0 = time.time()
        saver.export_meta_graph(meta_path)
        metagraph_seconds = time.time() - t0
        print('Metagraph saved in %.2f seconds' % metagraph_seconds)

    timing = tf.Summary()
    # pylint: disable=maybe-no-member
    for tag, seconds in (('time/save_variables', variables_seconds),
                         ('time/save_metagraph', metagraph_seconds)):
        timing.value.add(tag=tag, simple_value=seconds)
    summary_writer.add_summary(timing, step)
# save ckpt to pb file
def save_ckpt_to_pb(sess, model_dir, output_node_names='embeddings', output_filename='model.pb'):
    """Restore the checkpoint in ``model_dir`` into ``sess`` and write a frozen graph.

    Args:
        sess: Session whose graph receives the imported meta graph and whose
            restored variables are folded into constants.
        model_dir: Checkpoint directory (``~`` is expanded). The frozen graph
            is written inside it.
        output_node_names: Comma-separated names of the nodes to keep in the
            frozen graph. Defaults to ``'embeddings'`` as before.
        output_filename: Name of the file written into ``model_dir``.
            Defaults to ``'model.pb'`` as before.
    """
    # load_model imports into the default graph and restores into the default
    # session. Make `sess` and its graph the defaults so the variables are
    # restored into the same session that is frozen below.
    with sess.graph.as_default(), sess.as_default():
        load_model(model_dir)
    # Retrieve the protobuf graph definition; freeze_graph_def fixes up the
    # ref-typed (batch norm) nodes before converting variables to constants.
    input_graph_def = sess.graph.as_graph_def()
    output_graph_def = freeze_graph_def(sess, input_graph_def, output_node_names)
    # Expand '~' here too, matching how load_model resolves model_dir.
    output_graph = os.path.join(os.path.expanduser(model_dir), output_filename)
    with tf.gfile.GFile(output_graph, 'wb') as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph: %s" % (len(output_graph_def.node), output_graph))
def freeze_graph_def(sess, input_graph_def, output_node_names,
                     whitelist_prefixes=('InceptionResnetV1', 'embeddings', 'phase_train',
                                         'Bottleneck', 'Logits')):
    """Convert the variables of ``input_graph_def`` into constants.

    ``input_graph_def`` is first modified in place. Ref-typed ops are
    rewritten into their value counterparts: ``RefSwitch`` -> ``Switch``,
    ``AssignSub`` -> ``Sub``, ``AssignAdd`` -> ``Add``. Then only the
    variables whose node names start with one of ``whitelist_prefixes`` are
    replaced by constants holding their current values in ``sess``.

    Args:
        sess: Session holding the current variable values.
        input_graph_def: ``GraphDef`` to freeze. It is mutated in place.
        output_node_names: Comma-separated names of the output nodes to keep.
        whitelist_prefixes: Node-name prefix, or iterable of prefixes, selecting
            which variables are frozen. Defaults to the prefixes previously
            hard-coded here.

    Returns:
        The ``GraphDef`` returned by ``tf.graph_util.convert_variables_to_constants``.
    """
    value_ops = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Read the value rather than the ref for inputs named 'moving_*'
            # (presumably batch-norm moving statistics — confirm for new models).
            for index, input_name in enumerate(list(node.input)):
                if 'moving_' in input_name:
                    node.input[index] = input_name + '/read'
        elif node.op in value_ops:
            node.op = value_ops[node.op]
            # Sub/Add do not define 'use_locking', so drop the Assign* attr.
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    # Get the list of important nodes (str.startswith accepts a tuple).
    if isinstance(whitelist_prefixes, str):
        prefixes = (whitelist_prefixes,)
    else:
        prefixes = tuple(whitelist_prefixes)
    whitelist_names = [node.name for node in input_graph_def.node
                       if node.name.startswith(prefixes)]

    # Replace the whitelisted variables with constants of the same values.
    return tf.graph_util.convert_variables_to_constants(
        sess, input_graph_def, output_node_names.split(","),
        variable_names_whitelist=whitelist_names)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment