print trainable parameters by layer
from prettytable import PrettyTable


def count_parameters(model):
    """Print a table of trainable parameters per named tensor and return the total."""
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters (requires_grad=False).
        if not parameter.requires_grad:
            continue
        param = parameter.numel()  # number of elements in this tensor
        table.add_row([name, param])
        total_params += param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
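
The function above lists one row per parameter tensor (for example `fc1.weight` and `fc1.bias` separately). If a per-module summary is more useful, a minimal sketch along the following lines may help. It is not part of the original gist: `count_parameters_by_module`, `FakeParameter`, and `FakeModel` are hypothetical names, and the fake classes exist only so the example runs without PyTorch installed. The sketch assumes the model exposes `named_parameters()` the way `torch.nn.Module` does, yielding dotted names and objects with `requires_grad` and `numel()`.

```python
from collections import defaultdict


def count_parameters_by_module(model):
    """Sum trainable parameter counts per top-level submodule.

    Hypothetical helper: accepts any object whose named_parameters()
    behaves like torch.nn.Module.named_parameters().
    """
    counts = defaultdict(int)
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue  # skip frozen parameters, as count_parameters does
        # Group by the first component of the dotted name,
        # e.g. "fc1.weight" -> "fc1".
        top_level = name.split(".", 1)[0]
        counts[top_level] += parameter.numel()
    return dict(counts)


# Stand-ins (assumptions for illustration only) that mimic the small part
# of the PyTorch interface the helper relies on.
class FakeParameter:
    def __init__(self, shape, requires_grad=True):
        self.shape = shape
        self.requires_grad = requires_grad

    def numel(self):
        # Product of the dimensions, like torch.Tensor.numel().
        total = 1
        for dim in self.shape:
            total *= dim
        return total


class FakeModel:
    def named_parameters(self):
        yield "embed.weight", FakeParameter((100, 16), requires_grad=False)
        yield "fc1.weight", FakeParameter((32, 16))
        yield "fc1.bias", FakeParameter((32,))
        yield "fc2.weight", FakeParameter((4, 32))
        yield "fc2.bias", FakeParameter((4,))


by_module = count_parameters_by_module(FakeModel())
total_trainable = sum(by_module.values())
print(by_module)
print(f"Total Trainable Params: {total_trainable}")
```

With a real PyTorch model, the same helper should work unchanged by passing the model in place of `FakeModel()`. Frozen modules (here `embed`) do not appear in the result at all, which matches how the original function skips them.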
@christinakim christinakim commented Nov 6, 2020

image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment