Skip to content

Instantly share code, notes, and snippets.

@lukmanr
Last active September 23, 2020 05:56
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lukmanr/a4b89e8aace0f14d15663f9c0f661d3d to your computer and use it in GitHub Desktop.
TF Model Optimization 7
from tensorflow.tools.graph_transforms import TransformGraph
def get_graph_def_from_file(graph_filepath):
    """Read a binary GraphDef protobuf from disk and return it.

    Args:
        graph_filepath: Path to a binary-serialized GraphDef (``.pb``) file.

    Returns:
        The parsed ``tf.GraphDef`` message.
    """
    # Parsing happens under a fresh default graph, mirroring the original
    # code; nothing is imported into that graph.
    with ops.Graph().as_default():
        parsed_def = tf.GraphDef()
        with tf.gfile.GFile(graph_filepath, 'rb') as proto_file:
            serialized = proto_file.read()
        parsed_def.ParseFromString(serialized)
    return parsed_def
def optimize_graph(model_dir, graph_filename, transforms, output_node,
                   input_names=None, output_filename='optimized_model.pb'):
    """Run Graph Transform Tool transforms on a model and save the result.

    The source graph comes from a SavedModel in ``model_dir`` when
    ``graph_filename`` is None. Otherwise it is read from the binary GraphDef
    file ``os.path.join(model_dir, graph_filename)``. The transformed graph is
    written back into ``model_dir`` as a binary protobuf.

    Args:
        model_dir: Directory containing the source model. The optimized
            graph is also written here.
        graph_filename: Name of a binary GraphDef file inside ``model_dir``,
            or None to load the graph via ``get_graph_def_from_saved_model``.
        transforms: Transform specification strings passed unchanged to
            ``TransformGraph``.
        output_node: Name of the graph's output node, or a list/tuple of
            output node names.
        input_names: Optional sequence of input node names. If omitted, an
            empty list is used, which was the previous hard-coded behaviour.
            Transforms such as ``strip_unused_nodes`` need these names.
        output_filename: File name of the optimized graph inside
            ``model_dir``. Defaults to ``'optimized_model.pb'``.

    Returns:
        The path returned by ``tf.train.write_graph`` for the written file.
    """
    # Normalize to plain lists, the form TransformGraph takes for node names.
    input_names = [] if input_names is None else list(input_names)
    if isinstance(output_node, (list, tuple)):
        output_names = list(output_node)
    else:
        output_names = [output_node]

    if graph_filename is None:
        graph_def = get_graph_def_from_saved_model(model_dir)
    else:
        graph_def = get_graph_def_from_file(
            os.path.join(model_dir, graph_filename))

    optimized_graph_def = TransformGraph(
        graph_def,
        input_names,
        output_names,
        transforms)

    output_path = tf.train.write_graph(optimized_graph_def,
                                       logdir=model_dir,
                                       as_text=False,
                                       name=output_filename)
    print('Graph optimized!')
    return output_path
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.