Print the number of trainable parameters in each layer of a PyTorch model
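from prettytable import PrettyTable

def count_parameters(model):
    # Two-column table: parameter name and its number of elements
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters; only count those the optimizer will update
        if not parameter.requires_grad:
            continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params  # accumulate the running total
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params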
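If prettytable is not installed, the same per-layer report can be produced with plain f-strings. The sketch below assumes PyTorch is available; the function name `count_trainable_parameters` is hypothetical and not part of the original snippet. Its total should always agree with the common one-liner `sum(p.numel() for p in model.parameters() if p.requires_grad)`, which makes a handy cross-check.

```python
import torch.nn as nn


def count_trainable_parameters(model: nn.Module) -> int:
    """Hypothetical dependency-free variant of count_parameters.

    Prints each trainable parameter's name and element count, then the total.
    """
    rows = [
        (name, p.numel())
        for name, p in model.named_parameters()
        if p.requires_grad  # frozen parameters are skipped
    ]
    # Column width: the longest parameter name, or the header if that is longer
    width = max([len("Modules")] + [len(name) for name, _ in rows])
    print(f"{'Modules':<{width}}  Parameters")
    for name, count in rows:
        print(f"{name:<{width}}  {count:>10}")
    total = sum(count for _, count in rows)
    print(f"Total Trainable Params: {total}")
    return total


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    # 10*5 + 5 (first Linear) + 5*2 + 2 (second Linear) = 67
    total = count_trainable_parameters(model)
    # Cross-check against the one-liner
    assert total == sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Freezing a layer (setting `requires_grad = False` on its parameters) removes it from the count, which is the main reason to filter on `requires_grad` when fine-tuning part of a network.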

christinakim (gist owner and author) commented Nov 6, 2020