Skip to content

Instantly share code, notes, and snippets.

@CasiaFan
Created March 28, 2019 03:22
Show Gist options
  • Save CasiaFan/de2772aeaa100ecb01145b37d4517b22 to your computer and use it in GitHub Desktop.
Save CasiaFan/de2772aeaa100ecb01145b37d4517b22 to your computer and use it in GitHub Desktop.
Measure model FLOPs in TensorFlow
import tensorflow as tf
def load_pb(pb_model):
    """Load a serialized GraphDef file into a fresh ``tf.Graph``.

    Args:
        pb_model: Path to the binary ``.pb`` file to read.

    Returns:
        A new graph into which the parsed GraphDef has been imported with
        an empty name prefix (``name=''``).
    """
    parsed_def = tf.GraphDef()
    with tf.gfile.GFile(pb_model, "rb") as pb_file:
        parsed_def.ParseFromString(pb_file.read())

    # Import into a dedicated graph rather than whatever graph is the
    # current default.
    with tf.Graph().as_default() as imported_graph:
        tf.import_graph_def(parsed_def, name='')
    return imported_graph
def estimate_flops(pb_model):
    """Profile a frozen graph and report its total float operations.

    Loads the graph from ``pb_model`` with ``load_pb``, runs
    ``tf.profiler.profile`` with the ``float_operation()`` option set,
    prints the total, and returns it so callers can use the number.

    Args:
        pb_model: Path to the frozen GraphDef (``.pb``) file.

    Returns:
        The ``total_float_ops`` value reported by the profiler.
    """
    graph = load_pb(pb_model)
    with graph.as_default():
        # Placeholder inputs would result in incomplete shapes, so replace
        # them with constants when freezing the model.
        profiler_options = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(graph, options=profiler_options)

    total_float_ops = flops.total_float_ops
    print('Model {} needs {} FLOPS after freezing'.format(pb_model, total_float_ops))
    return total_float_ops
# Path of the frozen graph file to profile.
model = "frozen_inference_graph.pb"
estimate_flops(pb_model=model)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment