Skip to content

Instantly share code, notes, and snippets.

@sergeyprokudin
Last active October 24, 2019 13:36
Show Gist options
  • Save sergeyprokudin/429c61e6536f5af5d9b0e36c660b3ae9 to your computer and use it in GitHub Desktop.
Save sergeyprokudin/429c61e6536f5af5d9b0e36c660b3ae9 to your computer and use it in GitHub Desktop.
Count trainable parameters and FLOPs per sample of a Keras model
import numpy as np
def count_conv_params_flops(conv_layer, verbose=1):
    """Count parameters and estimate per-sample FLOPs of a convolution layer.

    Each parameter is assumed to take part in one multiply-accumulate
    (2 FLOPs) at every output position, i.e.
    ``flops = 2 * n_params * n_output_positions``. All parameters reported
    by ``count_params()`` are included in that product.

    Args:
        conv_layer: Layer-like object exposing ``output.shape.as_list()``,
            ``count_params()`` and ``name``. If it has a ``data_format``
            attribute equal to ``'channels_first'`` the output is read as
            ``(batch, channels, *spatial)``; otherwise it is read as
            ``(batch, *spatial, channels)``.
        verbose: When truthy, print the layer's parameter and FLOP counts.

    Returns:
        Tuple ``(n_params, flops)`` of plain Python ints.

    Raises:
        TypeError: If a spatial output dimension is unknown (``None``),
            because a per-sample FLOP count is then undefined.
    """
    out_shape = conv_layer.output.shape.as_list()

    # Drop the batch axis and the channel axis; which end the channel axis
    # sits on depends on the layer's data format.
    if getattr(conv_layer, 'data_format', 'channels_last') == 'channels_first':
        spatial_dims = out_shape[2:]
    else:
        spatial_dims = out_shape[1:-1]

    # Plain integer product: exact for any size and never a float.
    n_cells_total = 1
    for dim in spatial_dims:
        if dim is None:
            # TypeError is kept (rather than ValueError) because that is what
            # multiplying by None raised before; only the message is clearer.
            raise TypeError(
                "layer %s has an unknown spatial output dimension %s; "
                "FLOPs per sample are undefined" % (conv_layer.name, out_shape))
        n_cells_total *= int(dim)

    n_conv_params_total = int(conv_layer.count_params())
    conv_flops = 2 * n_conv_params_total * n_cells_total
    if verbose:
        print("layer %s params: %s" % (conv_layer.name, "{:,}".format(n_conv_params_total)))
        print("layer %s flops: %s" % (conv_layer.name, "{:,}".format(conv_flops)))
    return n_conv_params_total, conv_flops
def count_dense_params_flops(dense_layer, verbose=1):
    """Count parameters and estimate per-sample FLOPs of a dense layer.

    Each parameter is assumed to take part in one multiply-accumulate
    (2 FLOPs) per application of the layer. When the output has axes between
    the batch axis and the last (units) axis, the layer is applied once per
    position along those axes, so
    ``flops = 2 * n_params * n_positions``. For an ordinary
    ``(batch, units)`` output ``n_positions`` is 1.

    Args:
        dense_layer: Layer-like object exposing ``output.shape.as_list()``,
            ``count_params()`` and ``name``.
        verbose: When truthy, print the layer's parameter and FLOP counts.

    Returns:
        Tuple ``(n_params, flops)`` of plain Python ints. Unknown (``None``)
        intermediate dimensions are counted as 1, so the FLOP figure is then
        per position along that axis rather than per sample.
    """
    out_shape = dense_layer.output.shape.as_list()

    # Number of positions the kernel is applied at: every axis except the
    # batch axis (first) and the units axis (last). A plain integer product
    # is used so an empty slice yields the int 1, not a float.
    n_cells_total = 1
    for dim in out_shape[1:-1]:
        if dim is None:
            # Unknown length (e.g. variable sequence): count it as 1 so such
            # layers keep working instead of raising.
            continue
        n_cells_total *= int(dim)

    n_dense_params_total = int(dense_layer.count_params())
    dense_flops = 2 * n_dense_params_total * n_cells_total
    if verbose:
        print("layer %s params: %s" % (dense_layer.name, "{:,}".format(n_dense_params_total)))
        print("layer %s flops: %s" % (dense_layer.name, "{:,}".format(dense_flops)))
    return n_dense_params_total, dense_flops
def count_model_params_flops(model):
    """Sum parameter counts and estimated FLOPs over a model's layers.

    Walks ``model.layers`` and dispatches on the layer's type string:
    any layer whose ``str(type(layer))`` contains ``Conv1D``, ``Conv2D`` or
    ``Conv3D`` is handled by ``count_conv_params_flops``; any other layer
    whose type string contains ``Dense`` is handled by
    ``count_dense_params_flops``. Every remaining layer is skipped with a
    printed warning and contributes nothing to either total.

    Args:
        model: Object exposing ``layers`` (an iterable of layers) and
            ``name``.

    Returns:
        Tuple ``(total_params, total_flops)`` over the counted layers.
    """
    conv_type_names = ('Conv1D', 'Conv2D', 'Conv3D')
    total_params = 0
    total_flops = 0
    for layer in model.layers:
        layer_type = str(type(layer))
        # Substring match on the type string; conv is checked before dense.
        if any(conv_type in layer_type for conv_type in conv_type_names):
            count_layer = count_conv_params_flops
        elif 'Dense' in layer_type:
            count_layer = count_dense_params_flops
        else:
            print("warning:: skipping layer: %s" % str(layer))
            continue
        params, flops = count_layer(layer)
        total_params += params
        total_flops += flops
    print("total params (%s) : %s" % (model.name, "{:,}".format(total_params)))
    print("total flops (%s) : %s" % (model.name, "{:,}".format(total_flops)))
    return total_params, total_flops
@forcefulowl
Copy link

its wrong bro

@sergeyprokudin
Copy link
Author

sergeyprokudin commented Jun 21, 2019

its wrong bro

Indeed, the previous version worked only for dense layers. This one should work for all conv and dense layers and give a warning for all other layers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment