@christinakim
Last active November 6, 2020 16:44
Print trainable parameter counts by layer (PyTorch)
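```python
from prettytable import PrettyTable


def count_parameters(model):
    """Print a table of trainable parameter counts and return the total.

    `model` is a PyTorch nn.Module. Each row is one named parameter tensor
    (e.g. "fc1.weight"), so a layer with a bias contributes two rows.
    """
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue  # skip frozen parameters
        param = parameter.numel()  # number of elements in this tensor
        table.add_row([name, param])
        total_params += param
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
```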
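If `prettytable` is not installed, or you want the per-parameter counts as data rather than printed output, the same filtering can be split into a counting step and a formatting step. The sketch below is an illustrative variant, not part of the original gist: `parameter_counts` and `format_counts` are hypothetical names, and it assumes only that `model.named_parameters()` yields `(name, tensor)` pairs whose tensors have `.requires_grad` and `.numel()`, as a PyTorch `nn.Module` does.

```python
from typing import List, Tuple


def parameter_counts(model) -> Tuple[List[Tuple[str, int]], int]:
    """Return ([(name, count), ...], total) for trainable parameters only.

    Assumes `model.named_parameters()` yields (name, tensor) pairs where each
    tensor exposes `.requires_grad` and `.numel()` (e.g. a torch.nn.Module).
    """
    rows = [
        (name, p.numel())
        for name, p in model.named_parameters()
        if p.requires_grad  # frozen tensors are excluded, as in count_parameters
    ]
    total = sum(count for _, count in rows)
    return rows, total


def format_counts(rows: List[Tuple[str, int]], total: int) -> str:
    """Render the rows as a plain-text table without third-party packages."""
    width = max([len("Modules")] + [len(name) for name, _ in rows])
    lines = [f"{'Modules':<{width}}  Parameters"]
    lines += [f"{name:<{width}}  {count:>10}" for name, count in rows]
    lines.append(f"Total Trainable Params: {total}")
    return "\n".join(lines)
```

For example, for a PyTorch `nn.Linear(4, 3)` whose weight has been frozen with `requires_grad_(False)`, only the `bias` row (3 parameters) is reported, and the total is 3.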