Skip to content

Instantly share code, notes, and snippets.

@zldrobit
Created October 18, 2018 02:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save zldrobit/f010526b755571d90a345a2c78a6fd23 to your computer and use it in GitHub Desktop.
Save zldrobit/f010526b755571d90a345a2c78a6fd23 to your computer and use it in GitHub Desktop.
import tensorflow as tf
import argparse
def profile(graph, cmd, logdir="./graph"):
    """Print the float-operation and trainable-parameter counts of *graph*.

    Args:
        graph: the ``tf.Graph`` to analyse.
        cmd: tfprof view passed to ``tf.profiler.profile``
            ('op', 'scope', 'graph' or 'code').
        logdir: directory the graph is written to via
            ``tf.summary.FileWriter`` (viewable in TensorBoard).
            Defaults to "./graph"; pass ``None`` to skip writing.

    Returns:
        ``(flops, params)``: the profiler results of the float-operation
        query and of the trainable-variables-parameter query.
    """
    # Freshly created and never filled by a session run, so it carries no
    # execution statistics; results come from the graph definition alone.
    run_meta = tf.RunMetadata()

    if logdir is not None:
        writer = tf.summary.FileWriter(logdir, graph)
        writer.close()

    opts = tf.profiler.ProfileOptionBuilder.float_operation()
    flops = tf.profiler.profile(graph, run_meta=run_meta, cmd=cmd, options=opts)

    opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
    params = tf.profiler.profile(graph, run_meta=run_meta, cmd=cmd, options=opts)

    print("ops {:,} --- params {:,}".format(flops.total_float_ops, params.total_parameters))
    return flops, params
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Tensorflow Model Analyzer')
    parser.add_argument('--meta_graph', type=str, default=None,
                        help='meta graph path')
    parser.add_argument('--graph_def', type=str, default=None,
                        help='graph def path')
    parser.add_argument('--input_tensor', type=str, default="image:0",
                        help='name of the input tensor in the graph def')
    parser.add_argument('--input_shape', type=str, default=None,
                        help='comma-separated input shape, e.g. 1,368,368,3 '
                             '(default: the stored shape with batch size 1)')
    parser.add_argument('--cmd', type=str, default='op',
                        choices=['op', 'scope', 'graph', 'code'],
                        help='op / scope / graph / code')
    args = parser.parse_args()

    # A missing model is a usage error: print usage and exit non-zero.
    if args.meta_graph is None and args.graph_def is None:
        parser.error("Input meta_graph or graph_def.")

    with tf.Graph().as_default() as g:
        if args.meta_graph is not None:
            # --meta_graph takes precedence when both model options are given.
            # clear_devices=True drops device assignments stored in the graph.
            tf.train.import_meta_graph(args.meta_graph, clear_devices=True)
        else:
            with tf.gfile.GFile(args.graph_def, "rb") as f:
                restored_graph_def = tf.GraphDef()
                restored_graph_def.ParseFromString(f.read())

            # Import into a throwaway graph first, only to look up the input
            # tensor's stored shape and dtype.
            with tf.Graph().as_default() as g2:
                tf.import_graph_def(
                    restored_graph_def,
                    input_map=None,
                    return_elements=None,
                    name=""
                )
            input_tensor0 = g2.get_tensor_by_name(args.input_tensor)

            if args.input_shape is not None:
                try:
                    shape = [int(dim) for dim in args.input_shape.split(',')]
                except ValueError:
                    parser.error("--input_shape must be comma-separated "
                                 "integers, got %r" % args.input_shape)
            else:
                if input_tensor0.shape.ndims is None:
                    parser.error("input tensor %s has unknown rank; "
                                 "pass --input_shape" % args.input_tensor)
                shape = input_tensor0.shape.as_list()
                print("shape", shape)
                if shape:
                    # Pin the leading (presumably batch) dimension to 1,
                    # presumably so shape-dependent FLOP statistics can be
                    # computed. Other unknown dimensions stay None; use
                    # --input_shape to fix them.
                    shape[0] = 1

            # Replace the stored input with a placeholder of the chosen shape,
            # keeping the replaced tensor's dtype instead of assuming float32.
            input_placeholder = tf.placeholder(
                input_tensor0.dtype.base_dtype, shape=shape, name="input")
            print("input.shape", input_placeholder.shape)
            tf.import_graph_def(
                restored_graph_def,
                input_map={args.input_tensor: input_placeholder},
                return_elements=None,
                name=""
            )

        profile(g, args.cmd)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment