Skip to content

Instantly share code, notes, and snippets.

@EckoTan0804
Created June 1, 2021 19:41
Show Gist options
  • Save EckoTan0804/3e45221f75828ffaca74cf931bcb8e8c to your computer and use it in GitHub Desktop.
Save EckoTan0804/3e45221f75828ffaca74cf931bcb8e8c to your computer and use it in GitHub Desktop.
Summary of a model's parameters
from prettytable import PrettyTable
def count_parameters(model, *, groups=None, default_group="backbone", verbose=True):
    """Count a model's trainable parameters per component and print a summary table.

    Every parameter yielded by ``model.named_parameters()`` whose
    ``requires_grad`` flag is set is assigned to the first group in *groups*
    that has a keyword occurring as a substring of the parameter's name.
    Parameters matching no keyword are counted under *default_group*.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. a ``torch.nn.Module``.
        groups: Mapping of group name -> keyword(s). Each value may be a
            single string or an iterable of strings. Groups are checked in
            insertion order, and the first match wins. Defaults to the
            backbone/transformer/head layout this helper was written for:
            ``"transformer"`` for names containing ``"global_encoder"``, and
            ``"head"`` for names containing ``"deconv_layers"`` or
            ``"final_layer"``.
        default_group: Group name for parameters that match no keyword.
        verbose: If true (the default), print a PrettyTable with one row per
            group, followed by the total. If false, print nothing.

    Returns:
        Total number of trainable parameters, summed over all groups.
    """
    if groups is None:
        groups = {
            "transformer": ("global_encoder",),
            "head": ("deconv_layers", "final_layer"),
        }

    # Normalise a bare string to a 1-tuple. Otherwise `any(k in name ...)`
    # would iterate its characters and match almost every parameter.
    keyword_map = {
        group: (keywords,) if isinstance(keywords, str) else tuple(keywords)
        for group, keywords in groups.items()
    }

    # Default group first, so the row order matches the original
    # backbone / transformer / head table.
    counts = dict.fromkeys([default_group, *keyword_map], 0)

    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        group = next(
            (
                group_name
                for group_name, keywords in keyword_map.items()
                if any(keyword in name for keyword in keywords)
            ),
            default_group,
        )
        counts[group] += parameter.numel()

    total_params = sum(counts.values())

    if verbose:
        table = PrettyTable(["Module", "Parameters"])
        for group, num_param in counts.items():
            table.add_row([group, num_param])
        print(table)
        print(f"Total number of trainable Parameters: {total_params}")
    return total_params
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment