print trainable parameters by layer
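from prettytable import PrettyTable


def count_parameters(model):
    """Print a per-layer table of trainable parameter counts and return the total."""
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Skip frozen parameters; only trainable ones are counted.
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params += param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params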
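
For readers who want to try the counting logic without installing prettytable or PyTorch, here is a minimal sketch of the same idea using only the standard library. It assumes the model exposes `named_parameters()` yielding `(name, parameter)` pairs, and that each parameter has `requires_grad` and `numel()`, as `torch.nn.Module` does. The names `count_parameters_plain`, `_fake_param`, and `fake_model` are hypothetical and exist only for this illustration; they are not part of the original gist.

```python
from types import SimpleNamespace


def count_parameters_plain(model):
    """Same counting logic as count_parameters, without the prettytable dependency."""
    # Collect (name, element count) for trainable parameters only.
    rows = [
        (name, p.numel())
        for name, p in model.named_parameters()
        if p.requires_grad
    ]
    # Pad names to the longest one so the counts line up in a column.
    width = max((len(name) for name, _ in rows), default=7)
    for name, count in rows:
        print(f"{name:<{width}}  {count:>12,}")
    total = sum(count for _, count in rows)
    print(f"Total Trainable Params: {total}")
    return total


def _fake_param(n, trainable=True):
    # Hypothetical stand-in mimicking the two attributes used from a torch parameter.
    return SimpleNamespace(numel=lambda: n, requires_grad=trainable)


# Hypothetical stand-in for a model: a 4 -> 3 linear layer plus one frozen parameter.
fake_model = SimpleNamespace(
    named_parameters=lambda: iter([
        ("fc.weight", _fake_param(12)),
        ("fc.bias", _fake_param(3)),
        ("frozen.weight", _fake_param(100, trainable=False)),
    ])
)

total = count_parameters_plain(fake_model)
```

With a real PyTorch model, the same function should work by passing the module directly, since only `named_parameters()`, `requires_grad`, and `numel()` are used. The frozen parameter in the stand-in is excluded from the total.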


@christinakim christinakim commented Nov 6, 2020
