Skip to content

Instantly share code, notes, and snippets.

@tonyreina
Created November 5, 2019 23:30
Show Gist options
  • Save tonyreina/89a284b7cb6441a4afc3a4bbefd05199 to your computer and use it in GitHub Desktop.
Save tonyreina/89a284b7cb6441a4afc3a4bbefd05199 to your computer and use it in GitHub Desktop.
Summarize TensorFlow Graph for Inputs and Outputs
import argparse
import tensorflow as tf
import os
import sys
# Raise TensorFlow's native (C++) log threshold to quiet its console output.
# NOTE(review): this runs after `import tensorflow` above, so messages emitted
# during that import may not be suppressed -- consider moving it above the import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Node kinds never reported as graph outputs by summarize_graph(), even when
# nothing consumes them. Compared against both the op type and the last
# '/'-separated component of the node name.
unlikely_output_types = [
    'Const',
    'Assign',
    'NoOp',
    'Placeholder',
    'Assert',
]
def dump_for_tensorboard(graph_def: tf.GraphDef, logdir: str):
    """Write ``graph_def`` as a TensorBoard event file under ``logdir``.

    Args:
        graph_def: The graph definition to visualise.
        logdir: Directory that receives the event file.

    Raises:
        RuntimeError: If the event file cannot be written. The underlying
            exception is chained as ``__cause__``.
    """
    # TODO: graph_def is a deprecated argument of FileWriter, use graph instead
    print('Writing an event file for the tensorboard...')
    try:
        with tf.summary.FileWriter(logdir=logdir, graph_def=graph_def) as writer:
            writer.flush()
    except Exception as err:
        # The original code raised undefined names (`Error`, `refer_to_faq_msg`),
        # which turned every write failure into a NameError hiding the real cause.
        raise RuntimeError(
            'Cannot write an event file for the tensorboard to directory "{}".'.format(logdir)
        ) from err
    print('Done writing an event file.')
def children(op_name: str, graph: tf.Graph):
    """Return the set of operations that consume any output tensor of ``op_name``.

    Args:
        op_name: Name of the producing operation in ``graph``.
        graph: Graph to look the operation up in.

    Returns:
        A ``set`` of consumer operations; empty when nothing reads the outputs.
    """
    producer = graph.get_operation_by_name(op_name)
    return {
        consumer
        for tensor in producer.outputs
        for consumer in tensor.consumers()
    }
def summarize_graph(graph_def):
    """Describe the probable inputs and outputs of a graph definition.

    The definition is imported (with no name prefix) into a fresh ``tf.Graph``.

    Args:
        graph_def: Graph definition to inspect.

    Returns:
        A dict with two keys:

        * ``'inputs'``: maps each ``Placeholder`` node name to
          ``{'type': <dtype name>, 'shape': <shape string>}``. In the shape
          string spaces are removed and ``?`` is replaced by ``-1``.
        * ``'outputs'``: node names, in graph-def order, that have no
          consumers and whose op type and last ``/``-separated name
          component are both absent from ``unlikely_output_types``.
    """
    imported = tf.Graph()
    with imported.as_default():
        tf.import_graph_def(graph_def, name='')

    inputs = {}
    outputs = []
    for node in imported.as_graph_def().node:
        if node.op == 'Placeholder':
            dtype_name = tf.DType(node.attr['dtype'].type).name
            shape_text = str(tf.TensorShape(node.attr['shape'].shape))
            inputs[node.name] = {
                'type': dtype_name,
                'shape': shape_text.replace(' ', '').replace('?', '-1'),
            }

        # A node nothing consumes is a candidate output, unless it is a
        # bookkeeping op (constant, assignment, no-op, ...).
        is_leaf = not children(node.name, imported)
        short_name = node.name.split('/')[-1]
        if is_leaf and node.op not in unlikely_output_types and short_name not in unlikely_output_types:
            outputs.append(node.name)

    return {'inputs': inputs, 'outputs': outputs}
def print_summary(summary):
    """Print the inputs and outputs in a summary produced by ``summarize_graph``.

    Args:
        summary: Mapping with an ``'inputs'`` dict (name -> dict holding
            ``'type'`` and ``'shape'``) and an ``'outputs'`` list of names.
    """
    divider = '------------'
    graph_inputs = summary['inputs']
    graph_outputs = summary['outputs']

    print(divider)
    print('{} input(s) detected:'.format(len(graph_inputs)))
    for name, details in graph_inputs.items():
        print("Name: {}, type: {}, shape: {}".format(name, details['type'], details['shape']))

    print(divider)
    print('{} output(s) detected:'.format(len(graph_outputs)))
    for name in graph_outputs:
        print('Name: {}'.format(name))
    print('')
def _str2bool(value):
    """Parse a command-line boolean such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as True, so boolean flags are parsed explicitly here.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _parse_args():
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description='Freeze saved model')
    parser.add_argument('--model', type=str, help='Path to TF model folder', required=True)
    parser.add_argument('--output', type=str, help='Output layer name', required=False)
    # nargs='?' + const=True accepts both a bare "--summary" and the
    # previously supported "--summary True" / "--summary False".
    parser.add_argument('--summary', type=_str2bool, nargs='?', const=True, default=False,
                        help='Summarize only', required=False)
    parser.add_argument('--logs', type=_str2bool, nargs='?', const=True, default=False,
                        help='Dump logs for tensorboard', required=False)
    return parser.parse_args()


def main():
    """Load a SavedModel (SERVING tag) and either summarize or freeze it.

    Every node name of the loaded graph is printed first. Then:

    * With ``--summary``, the loaded graph's inputs/outputs are printed.
    * Otherwise, variables are folded into constants for the comma-separated
      ``--output`` node names. The frozen graph is written to ``frozen.pb`` in
      the current working directory and then summarized. If ``--output`` is
      missing, the process exits with status 1 and an error message.

    In either mode, ``--logs`` also writes a TensorBoard event file to ``logs``.
    """
    args = _parse_args()

    session_config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=session_config) as sess:
        tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], args.model)

        # Print every node name (useful for choosing --output names).
        for node in sess.graph.as_graph_def().node:
            print(node.name)

        if args.summary:
            print_summary(summarize_graph(sess.graph_def))
            if args.logs:
                dump_for_tensorboard(sess.graph_def, 'logs')
            return

        # Tolerate "a, b" as well as "a,b"; ignore empty entries.
        output_names = [name.strip() for name in (args.output or '').split(',') if name.strip()]
        if not output_names:
            # Report the error on stderr with a non-zero exit status.
            sys.exit('Please provide output layer name')

        # Freeze the graph
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,
            sess.graph_def,
            output_names)

    # Save the frozen graph
    with open('frozen.pb', 'wb') as f:
        f.write(frozen_graph_def.SerializeToString())

    print_summary(summarize_graph(frozen_graph_def))
    if args.logs:
        dump_for_tensorboard(frozen_graph_def, 'logs')
if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment