print trainable parameters by layer
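from prettytable import PrettyTable


def count_parameters(model):
    """Print a per-parameter table of trainable parameter counts and return the total."""
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters (those excluded from gradient updates).
        if not parameter.requires_grad:
            continue
        param = parameter.numel()  # number of scalar elements in this tensor
        table.add_row([name, param])
        total_params += param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params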
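
For environments where prettytable is not installed, the same loop can be sketched with plain string formatting. This is only an illustrative variant, not part of the original snippet: the names `summarize_parameters` and `format_rows` are hypothetical, and the code assumes `model` exposes `named_parameters()` yielding parameters with `requires_grad` and `numel()`, as a `torch.nn.Module` does. It also tallies frozen parameters separately, which the original skips.

```python
def summarize_parameters(model):
    """Return (rows, trainable_total, frozen_total) for a model.

    Assumes model.named_parameters() yields (name, parameter) pairs and
    each parameter has .requires_grad and .numel(), as in torch.nn.Module.
    """
    rows = []
    trainable_total = 0
    frozen_total = 0
    for name, parameter in model.named_parameters():
        count = parameter.numel()
        if parameter.requires_grad:
            rows.append((name, count))
            trainable_total += count
        else:
            # Frozen parameters are counted but not listed.
            frozen_total += count
    return rows, trainable_total, frozen_total


def format_rows(rows, total):
    """Render rows as an aligned plain-text table with thousands separators."""
    labels = ["Modules", "Total trainable"] + [name for name, _ in rows]
    width = max(len(label) for label in labels)
    lines = [f"{'Modules':<{width}}  {'Parameters':>12}"]
    lines += [f"{name:<{width}}  {count:>12,}" for name, count in rows]
    lines.append(f"{'Total trainable':<{width}}  {total:>12,}")
    return "\n".join(lines)


# Example usage, assuming PyTorch is installed:
#   import torch
#   model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.Linear(3, 2))
#   model[0].bias.requires_grad = False  # freeze one tensor
#   rows, trainable, frozen = summarize_parameters(model)
#   print(format_rows(rows, trainable))
```

Returning the rows instead of printing them makes the counts easy to log or compare between runs, at the cost of the nicer borders PrettyTable draws.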
christinakim commented Nov 6, 2020:

[image]