@scheckmedia
Last active June 29, 2023 12:26
Using the TF 2.0 API (tf.compat.v1) to calculate total FLOPs and parameters
import tensorflow as tf

session = tf.compat.v1.Session()
graph = tf.compat.v1.get_default_graph()

with graph.as_default():
    with session.as_default():
        # Build MobileNetV2 on a placeholder with a fully defined shape (batch size 1),
        # so the profiler can compute per-op FLOPs from known tensor shapes.
        base_model = tf.keras.applications.MobileNetV2(
            weights='imagenet',
            input_tensor=tf.compat.v1.placeholder('float32', shape=(1, 224, 224, 3)))

        run_meta = tf.compat.v1.RunMetadata()

        # Count floating-point operations, grouped by op type.
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        # We use the Keras session graph in the call to the profiler.
        flops = tf.compat.v1.profiler.profile(graph=graph,
                                              run_meta=run_meta, cmd='op', options=opts)

        # Count trainable parameters.
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter()
        params = tf.compat.v1.profiler.profile(graph=graph,
                                               run_meta=run_meta, cmd='op', options=opts)

        print('Total FLOPs:', flops.total_float_ops)
        print('Total trainable params:', params.total_parameters)
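To sanity-check the profiler's totals by hand, here is a minimal sketch that computes FLOPs and parameters for a single convolution layer. It assumes the convention used by TensorFlow's registered FLOP statistics for convolutions, where one multiply-add counts as 2 FLOPs, and it does not count bias adds as FLOPs. The helper names conv2d_cost and depthwise_conv2d_cost are hypothetical and not part of any TensorFlow API.

```python
def conv2d_cost(h_out, w_out, k_h, k_w, c_in, c_out, use_bias=True):
    """Hypothetical helper: (FLOPs, params) of a standard 2-D convolution.

    Assumes 1 multiply-add = 2 FLOPs; bias adds are not counted as FLOPs.
    """
    # One multiply-add per kernel tap, per input channel, per output value.
    macs = h_out * w_out * k_h * k_w * c_in * c_out
    flops = 2 * macs
    params = k_h * k_w * c_in * c_out + (c_out if use_bias else 0)
    return flops, params


def depthwise_conv2d_cost(h_out, w_out, k_h, k_w, channels, use_bias=True):
    """Hypothetical helper: (FLOPs, params) of a depthwise convolution with
    depth multiplier 1, where each channel has its own k_h x k_w kernel."""
    macs = h_out * w_out * k_h * k_w * channels
    flops = 2 * macs
    params = k_h * k_w * channels + (channels if use_bias else 0)
    return flops, params


# 3x3, stride-2, 32-filter, bias-free conv on a 224x224x3 input -> 112x112x32 output.
print(conv2d_cost(112, 112, 3, 3, 3, 32, use_bias=False))  # (21676032, 864)
# 3x3 depthwise conv over 32 channels at 112x112, bias-free.
print(depthwise_conv2d_cost(112, 112, 3, 3, 32, use_bias=False))  # (7225344, 288)
```

Summing per-layer figures like these will generally not match the profiler's total exactly, since the profiler counts every op in the graph that has registered FLOP statistics, not only convolutions.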