
@moodoki
Last active February 14, 2023 05:58
Freeze and export TensorFlow graph from checkpoint files
import os, argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))


def freeze_graph(model_folder, output_nodes='y_hat',
                 output_filename='frozen-graph.pb',
                 rename_outputs=None):

    # Load checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    output_graph = output_filename

    # Devices should be cleared to allow TensorFlow to control placement of
    # graph when loading on different machines
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)

    graph = tf.get_default_graph()

    onames = output_nodes.split(',')

    # Optionally rename output nodes for readability
    # https://stackoverflow.com/a/34399966/4190475
    if rename_outputs is not None:
        nnames = rename_outputs.split(',')
        with graph.as_default():
            for o, n in zip(onames, nnames):
                _out = tf.identity(graph.get_tensor_by_name(o + ':0'), name=n)
        onames = nnames

    input_graph_def = graph.as_graph_def()

    # Fix batch norm nodes
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, input_checkpoint)

        # In production, graph weights no longer need to be updated;
        # graph_util provides a utility to convert all variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def,
            onames  # unrelated nodes will be discarded
        )

        # Serialize and write to file
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())

        print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Prune and freeze weights from checkpoints into production models')
    parser.add_argument("--checkpoint_path",
                        default='ckpt',
                        type=str, help="Path to checkpoint files")
    parser.add_argument("--output_nodes",
                        default='y_hat',
                        type=str, help="Names of output nodes, comma separated")
    parser.add_argument("--output_graph",
                        default='frozen-graph.pb',
                        type=str, help="Output graph filename")
    parser.add_argument("--rename_outputs",
                        default=None,
                        type=str, help="Rename output nodes for better "
                                       "readability in production graph, to be "
                                       "specified in the same order as output_nodes")
    args = parser.parse_args()

    freeze_graph(args.checkpoint_path, args.output_nodes,
                 args.output_graph, args.rename_outputs)
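
For anyone who wants to use the resulting .pb, here is a minimal sketch of loading the frozen graph for inference with the TF1 API. The tensor names x:0 and y_hat:0, the (1, 784) input shape, and the filename are assumptions for illustration; substitute your model's actual names.

import numpy as np
import tensorflow as tf

# Minimal loading sketch (TF1 API). The tensor names 'x:0'/'y_hat:0' and
# the (1, 784) dummy input shape are assumptions; use your model's real ones.
with tf.gfile.GFile('frozen-graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

with tf.Session(graph=graph) as sess:
    x = graph.get_tensor_by_name('x:0')
    y_hat = graph.get_tensor_by_name('y_hat:0')
    dummy_batch = np.zeros((1, 784), dtype=np.float32)  # placeholder input
    print(sess.run(y_hat, feed_dict={x: dummy_batch}))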

charmby commented Jun 12, 2020

AttributeError: 'NoneType' object has no attribute 'model_checkpoint_path'


moodoki commented Jun 16, 2020

Wow, somehow I didn't get any notifications on the comments made on this code snippet until the last comment by @charmby. Thanks to everyone who helped answer others' queries :).

TensorFlow has evolved quite significantly since I shared this implementation, and using tf.SavedModel might be a far easier approach for new code.
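
As a rough sketch of that route in TF2 (the tiny Keras model below is purely illustrative, not related to the script above):

import tensorflow as tf

# Illustrative TF2/Keras model; any Keras model or tf.Module works here.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax', name='y_hat'),
])

# Export graph, weights and serving signatures into a single directory;
# no manual freezing or node renaming needed.
tf.saved_model.save(model, 'exported_model')

# Reload later (or elsewhere) for inference.
reloaded = tf.saved_model.load('exported_model')
infer = reloaded.signatures['serving_default']

The resulting exported_model directory is also what TensorFlow Serving and most conversion tools expect.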

For working with old code/models, this script might still be useful, so let me answer a few queries that weren't already taken care of by others :)

@lihan, @selcouthlyBlue I did this mainly because it saves everything in a single file, which makes distribution much easier. There should also be some inference performance benefits, since all model parameters are converted to constants instead of variables. The downside is that the saved model is no longer fine-tunable.

@charmby That error is due to the script not being able to find the model files in the path specified. The model path parameter should be the folder containing all the checkpoint files and not the .chkpt file itself.
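
For example, if the checkpoint, .meta, .index and .data-* files all sit in a directory called my_model (an illustrative name), pass that directory rather than any individual file, either on the command line as --checkpoint_path my_model or by calling the function directly:

# 'my_model' is an illustrative directory name; it should contain the
# 'checkpoint' index file plus the .meta/.index/.data-* files.
freeze_graph('my_model',
             output_nodes='y_hat',
             output_filename='frozen-graph.pb')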

@elham1992 This error is due to the graph not having a y_hat node present. You may want to check the output node name of the model you are using.

@elham1992, @soufianesabiri, @achalshah20 Apart from using summarize_graph as mentioned by others, if you have access to the code that was used to build the pre-trained graph, look at it and get it to print the output variable name. All TensorFlow graphs should work; slim isn't necessary. As good practice, you may want to name important nodes using the name='' parameter when creating the graph; if this parameter isn't specified, TensorFlow has a default naming convention that appends a running count to each node of the same type, e.g. conv_0, conv_1, etc. (see the sketch below). Otherwise, there's also this tool that lets you explore graphs in many different formats: https://github.com/lutzroeder/netron
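
For anyone hunting for node names under TF1, a quick sketch (the placeholder graph below is purely illustrative):

import tensorflow as tf

# Purely illustrative graph; 'logits' stands in for a real model's output.
logits = tf.placeholder(tf.float32, [None, 10], name='logits')

# Give the important node an explicit, readable name.
y_hat = tf.nn.softmax(logits, name='y_hat')

# List every node name in the default graph to find candidate outputs.
for node in tf.get_default_graph().as_graph_def().node:
    print(node.name)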
