@05jd
Created August 26, 2018 15:10
Counting FLOPs of a Keras model
"""Code snippet for counting the FLOPs of a Keras model.
Not a final version; it will be updated to improve usability.
"""
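import os.path
import tempfile

import tensorflow as tf


def count_flops(model):
    """Count the FLOPs of a Keras model.

    Written against TensorFlow 1.x (tf-1.6); it relies on TF1 graph/session APIs.

    # Args
        model: Model, the Keras model to profile.
    # Returns
        int, total number of floating-point operations in the model's graph.
    # Raises
        TypeError, if `model` is not an instance of Sequential or Model.
    """
    # Sequential is a subclass of Model, so one public-API check covers both
    # (avoids importing from the private tensorflow.python.keras package).
    if not isinstance(model, tf.keras.Model):
        raise TypeError(
            'Model is expected to be an instance of Sequential or Model, '
            'but got %s' % type(model))
    # Names of the ops that produce the model outputs.
    output_op_names = [_out_tensor.op.name for _out_tensor in model.outputs]
    sess = tf.keras.backend.get_session()
    # Replace variables with constants and keep only the subgraph the outputs
    # depend on, so training/initializer ops are not profiled.
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), output_op_names)
    # Round-trip the frozen GraphDef through a temporary file.
    with tempfile.TemporaryDirectory() as tmpdir:
        graph_file = os.path.join(tmpdir, 'graph.pb')
        with tf.gfile.GFile(graph_file, "wb") as f:
            f.write(frozen_graph_def.SerializeToString())
        with tf.gfile.GFile(graph_file, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
    # Import into a fresh graph and profile only the floating-point operations.
    with tf.Graph().as_default() as new_graph:
        tf.import_graph_def(graph_def, name='')
        tfprof_opts = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(new_graph, options=tfprof_opts)
        # Optional: also dump the imported graph to ./gg for inspection in TensorBoard.
        writer = tf.summary.FileWriter('gg', graph=new_graph)
        writer.close()
    # profile() returns a proto; the docstring promises an int.
    return flops.total_float_ops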
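

if __name__ == '__main__':
    # A fixed batch size of 1 keeps every tensor shape fully defined,
    # which the profiler needs to compute per-op FLOPs.
    vgg = tf.keras.applications.vgg16.VGG16(
        include_top=True, weights=None,
        input_tensor=tf.keras.Input(batch_shape=(1, 224, 224, 3)))
    flops = count_flops(vgg)
    print(flops)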
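As a rough cross-check on the profiler's number, the FLOPs of VGG16's convolution and dense layers can also be counted by hand. The sketch below is plain Python and does not touch TensorFlow. `conv2d_flops`, `dense_flops` and `vgg16_flops` are hypothetical helpers and are not part of this gist.

The sketch makes several assumptions:
- 3x3 "same"-padded convolutions with stride 1.
- 2x2 max-pooling with stride 2.
- A batch size of 1.
- Each multiply-add counts as 2 FLOPs. As far as I know, this is the convention TensorFlow's registered FLOP statistics use for Conv2D and MatMul.

Bias additions, activations, pooling and softmax are ignored, so the profiler's total may come out somewhat higher.

```python
"""Back-of-the-envelope FLOP count for VGG16 (hypothetical helpers, not part of the gist)."""


def conv2d_flops(out_h, out_w, in_ch, out_ch, k_h=3, k_w=3):
    # One multiply and one add per kernel tap per output element -> factor 2.
    return 2 * out_h * out_w * out_ch * in_ch * k_h * k_w


def dense_flops(in_features, out_features):
    # A (1 x in) @ (in x out) matmul: in_features multiply-adds per output unit.
    return 2 * in_features * out_features


# VGG16 conv stack: (number of 3x3 conv layers, output channels) per block.
VGG16_BLOCKS = [(2, 64), (2, 128), (3, 256), (3, 512), (3, 512)]


def vgg16_flops(input_size=224, in_ch=3, num_classes=1000):
    total = 0
    size, ch = input_size, in_ch
    for n_convs, out_ch in VGG16_BLOCKS:
        for _ in range(n_convs):
            # 'same' padding, stride 1: spatial size is unchanged.
            total += conv2d_flops(size, size, ch, out_ch)
            ch = out_ch
        size //= 2  # 2x2 max-pool with stride 2 (its cost is not counted)
    flat = size * size * ch  # 7 * 7 * 512 = 25088 for a 224x224 input
    for in_f, out_f in [(flat, 4096), (4096, 4096), (4096, num_classes)]:
        total += dense_flops(in_f, out_f)
    return total


if __name__ == '__main__':
    print(vgg16_flops())
```

With the defaults this gives 30,940,528,640 FLOPs (about 15.5 G multiply-adds) for a 224x224 input. The conv and dense layers dominate, so the profiler's result for the same model should be in this range.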
maguscl commented Jun 15, 2020

Hi!
The line `sess = tf.keras.backend.get_session()` must be: `sess = tf.compat.v1.keras.backend.get_session()`

Thanks for sharing the code. Very useful!

05jd (author) commented Jul 8, 2020

@maguscl I'm happy to hear that :) And thanks for pointing out this version-compatibility issue with the gist. It was written for tf-1.6.
But TF keeps changing quickly, so I have no plan to update this gist for the current TF version. Please treat it as just an old sample :)
