Print the number of trainable parameters in each layer (parameter tensor) of a model, followed by the model total.
from prettytable import PrettyTable
def count_parameters(model):
    """Print a table of trainable parameter counts and return their total.

    Iterates ``model.named_parameters()``. For every parameter whose
    ``requires_grad`` flag is set, it adds a row ``[name, numel]`` to a
    PrettyTable and accumulates the element count. It then prints the table
    and a "Total Trainable Params" line.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs. Each parameter must provide a
            ``requires_grad`` attribute and a ``numel()`` method. This is
            presumably a ``torch.nn.Module``; the snippet does not confirm it.

    Returns:
        int: Total number of elements across all trainable parameters.
    """
    # NOTE: the first column is headed "Modules", but each row is actually a
    # single named parameter tensor (e.g. a weight or a bias), not a module.
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        # Frozen parameters are left out of both the table and the total.
        if not parameter.requires_grad:
            continue
        num_elements = parameter.numel()
        table.add_row([name, num_elements])
        total_params += num_elements
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment
This comment has been minimized.