Created
June 1, 2021 19:41
-
-
Save EckoTan0804/3e45221f75828ffaca74cf931bcb8e8c to your computer and use it in GitHub Desktop.
Summary of a model's parameters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from prettytable import PrettyTable | |
def count_parameters(model, module_keywords=None, default_module="backbone"):
    """Print and return the number of trainable parameters per model component.

    Each trainable parameter (``parameter.requires_grad`` is true) is assigned
    to the first component in ``module_keywords`` that has any keyword
    appearing as a substring of the parameter's name. Parameters matching no
    keyword are counted under ``default_module``. The per-component counts are
    printed as a PrettyTable, followed by the grand total.

    Args:
        model: Object exposing ``named_parameters()``, which yields
            ``(name, parameter)`` pairs; each parameter must provide
            ``requires_grad`` and ``numel()`` (e.g. a ``torch.nn.Module``).
        module_keywords: Optional ordered mapping from component name to a
            keyword string or an iterable of keyword strings. Components are
            tried in mapping order. The default reproduces the original
            layout: ``"global_encoder"`` -> ``"transformer"``, and
            ``"deconv_layers"`` / ``"final_layer"`` -> ``"head"``.
        default_module: Component that receives parameters matching no
            keyword. Defaults to ``"backbone"``.

    Returns:
        int: Total number of trainable parameters across all components.
    """
    if module_keywords is None:
        # Checked in this order, so a name matching both groups counts as
        # "transformer", exactly as in the original if/elif chain.
        module_keywords = {
            "transformer": ("global_encoder",),
            "head": ("deconv_layers", "final_layer"),
        }

    # Normalise once, outside the parameter loop. A bare string is wrapped
    # in a tuple so it is not iterated character by character.
    keyword_groups = [
        (module, (keywords,) if isinstance(keywords, str) else tuple(keywords))
        for module, keywords in module_keywords.items()
    ]

    # Put the default component first so the table rows keep the original
    # order (backbone, transformer, head). Every component gets a row even
    # when it has no trainable parameters.
    module_param_dict = {default_module: 0}
    module_param_dict.update((module, 0) for module, _ in keyword_groups)

    for param_name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue  # frozen parameters are excluded from the summary
        target = next(
            (
                module
                for module, keywords in keyword_groups
                if any(keyword in param_name for keyword in keywords)
            ),
            default_module,
        )
        module_param_dict[target] += parameter.numel()

    table = PrettyTable(["Module", "Parameters"])
    for module, num_param in module_param_dict.items():
        table.add_row([module, num_param])
    total_params = sum(module_param_dict.values())

    print(table)
    print(f"Total number of trainable Parameters: {total_params}")
    return total_params
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment