@hppRC
Created December 15, 2023 03:27
A utility for printing model parameter counts in a readable format
from torch import nn


def format_param_with_unit(num_params: int) -> str:
    """Format a raw parameter count as a human-readable string (e.g. 12.35M)."""
    if num_params >= 1000 * 1000 * 1000:
        unit = "B"
        num_params /= 1000 * 1000 * 1000
    elif num_params >= 1000 * 1000:
        unit = "M"
        num_params /= 1000 * 1000
    elif num_params >= 1000:
        unit = "K"
        num_params /= 1000
    else:
        unit = " "
    # Right-align to a fixed width of 7 characters (digits + "." + 2 decimals + unit).
    return f"{num_params:.2f}{unit}".rjust(4 + 2 + 1)
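
# Example outputs for illustration (assuming the formatting above):
#   format_param_with_unit(1_500)         -> "  1.50K"
#   format_param_with_unit(12_345_678)    -> " 12.35M"
#   format_param_with_unit(7_000_000_000) -> "  7.00B"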

def print_params(module: nn.Module) -> None:
    """Print a summary of total / trainable / frozen / embedding parameter counts."""
    # Count trainable vs. frozen parameters.
    num_training_params, num_freezed_params = 0, 0
    for _, param in module.named_parameters():
        if param.requires_grad:
            num_training_params += param.numel()
        else:
            num_freezed_params += param.numel()
    num_total_params = num_training_params + num_freezed_params

    # Count parameters belonging to embedding layers separately.
    num_emb_params = 0
    for mod in module.modules():
        if isinstance(mod, nn.Embedding):
            for p in mod.parameters():
                num_emb_params += p.numel()

    print("=" * 80)
    print("Params summary")
    params_str = map(
        format_param_with_unit,
        (
            num_total_params,
            num_training_params,
            num_freezed_params,
            num_emb_params,
            num_total_params - num_emb_params,
        ),
    )
    names = ["total", "trainable", "freezed", "embedding", "total w/o emb"]
    max_len = max(len(n) for n in names)
    for name, param_str in zip(names, params_str, strict=True):
        print(f"{name.ljust(max_len)}:\t{param_str}")
    print("=" * 80)
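
# Minimal usage sketch (assumes PyTorch is installed; the tiny model below is
# hypothetical and exists only to demonstrate the printed summary):
if __name__ == "__main__":
    model = nn.Sequential(
        nn.Embedding(30_000, 256),
        nn.Linear(256, 256),
    )
    # Freeze the embedding so the trainable/frozen split shows up in the output.
    for p in model[0].parameters():
        p.requires_grad = False
    print_params(model)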