Skip to content

Instantly share code, notes, and snippets.

@previtus
Created October 13, 2017 00:51
Show Gist options
  • Save previtus/a6a5153352c2fbaa8bfb88fffb835e73 to your computer and use it in GitHub Desktop.
Save previtus/a6a5153352c2fbaa8bfb88fffb835e73 to your computer and use it in GitHub Desktop.
def _shape_str(layer, attr):
    """Return ``str(getattr(layer, attr))``, or ``'multiple'`` if it is undefined."""
    try:
        # str() first: in Python 3 tuples/lists reject width specs such as
        # '<20' (TypeError), so the shape must be a string before formatting.
        return str(getattr(layer, attr))
    except AttributeError:
        # Keras raises AttributeError when a layer has several inbound nodes
        # with different shapes; Keras' own Model.summary prints 'multiple'.
        return 'multiple'


def _count_unique_params(weights, count_params):
    """Sum ``count_params`` over *weights*, counting each weight object once."""
    # Deduplicate by identity instead of set(): TF2 variables are unhashable,
    # and a weight listed twice must still be counted only once.
    unique = {id(w): w for w in weights}
    return int(sum(count_params(w) for w in unique.values()))


def short_summary(model, *, print_fn=None, count_params=None):
    """Print a compact one-line-per-layer summary of a Keras model.

    Each line shows the layer name, its class name, and its input and output
    shapes. Layers that own weights additionally report their trainable and
    non-trainable parameter counts.

    Args:
        model: Object exposing a ``layers`` iterable (e.g. a Keras ``Model``).
            Each layer must provide ``name``, ``trainable_weights`` and
            ``non_trainable_weights``; ``input_shape`` / ``output_shape`` are
            shown as ``'multiple'`` when reading them raises AttributeError.
        print_fn: Callable receiving each formatted line. ``None`` (default)
            means the builtin ``print``, like Keras' ``Model.summary``.
        count_params: Callable returning the number of scalars in one weight.
            ``None`` (default) means ``keras.backend.count_params``, imported
            lazily so that other backends (e.g. ``tf.keras``) can be supplied.

    Returns:
        None; output is produced only through ``print_fn``.
    """
    if print_fn is None:
        print_fn = print
    if count_params is None:
        from keras import backend as K
        count_params = K.count_params

    for layer in model.layers:
        trainable_count = _count_unique_params(layer.trainable_weights, count_params)
        non_trainable_count = _count_unique_params(layer.non_trainable_weights, count_params)
        line = '{:<10}[{:<10}]: {:<20} => {:<20}'.format(
            layer.name,
            layer.__class__.__name__,
            _shape_str(layer, 'input_shape'),
            _shape_str(layer, 'output_shape'),
        )
        # Weight-less layers (activations, reshapes, ...) get the short form.
        if trainable_count or non_trainable_count:
            line += ', with {} trainable + {} nontrainable'.format(
                trainable_count, non_trainable_count)
        print_fn(line)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment