Skip to content

Instantly share code, notes, and snippets.

@CookieBox26
Last active September 19, 2020 16:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save CookieBox26/90b55c0815b3f77ab1c566f5d73bd185 to your computer and use it in GitHub Desktop.
Save CookieBox26/90b55c0815b3f77ab1c566f5d73bd185 to your computer and use it in GitHub Desktop.
学習済みBERTモデルのパラメータ数を数える.
from transformers import BertModel
def main():
# 利用する学習済みBERTモデルの名前を指定する.
model_name = 'bert-large-cased'
# 学習済みモデルのインスタンスを生成する.
model = BertModel.from_pretrained(
pretrained_model_name_or_path=model_name,
)
# ブロックごとにブロック内の名前付きパラメータとパラメータ数を表示する.
# 2020/09/20 追記: BertModel に直接 named_parameters() や parameters() をよべるので
# 合計値のみに興味があるなら sum([param.numel() for param in model.parameters()]) でよい.
n = 0
blocks = [
('モデルの埋め込み層', model.embeddings),
('モデルのエンコーダ層', model.encoder),
('モデルのプーラー層', model.pooler),
]
for block in blocks:
print('-'*10, block[0], '-'*10)
if block[0] == 'モデルのエンコーダ層': # エンコーダ層は大きいので layer ごとにまとめる.
for i, layer in enumerate(block[1].layer):
n_layer = 0
for name, param in layer.named_parameters():
n_ = param.numel()
if i == 0: # 最初の1層だけ詳細に表示する.
print(name, n_)
n_layer += n_
print(f'エンコーダ層内の{i}層目計', n_layer)
n += n_layer
else:
for name, param in block[1].named_parameters():
n_ = param.numel()
print(name, n_)
n += n_
print('='*10, 'パラメータ数', '='*10)
print(n)
if __name__ == '__main__':
main()
@CookieBox26
Copy link
Author

----------- モデルの埋め込み層 ----------
word_embeddings.weight 29691904
position_embeddings.weight 524288
token_type_embeddings.weight 2048
LayerNorm.weight 1024
LayerNorm.bias 1024
---------- モデルのエンコーダ層 ----------
attention.self.query.weight 1048576
attention.self.query.bias 1024
attention.self.key.weight 1048576
attention.self.key.bias 1024
attention.self.value.weight 1048576
attention.self.value.bias 1024
attention.output.dense.weight 1048576
attention.output.dense.bias 1024
attention.output.LayerNorm.weight 1024
attention.output.LayerNorm.bias 1024
intermediate.dense.weight 4194304
intermediate.dense.bias 4096
output.dense.weight 4194304
output.dense.bias 1024
output.LayerNorm.weight 1024
output.LayerNorm.bias 1024
エンコーダ層内の0層目計 12596224
エンコーダ層内の1層目計 12596224
エンコーダ層内の2層目計 12596224
エンコーダ層内の3層目計 12596224
エンコーダ層内の4層目計 12596224
エンコーダ層内の5層目計 12596224
エンコーダ層内の6層目計 12596224
エンコーダ層内の7層目計 12596224
エンコーダ層内の8層目計 12596224
エンコーダ層内の9層目計 12596224
エンコーダ層内の10層目計 12596224
エンコーダ層内の11層目計 12596224
エンコーダ層内の12層目計 12596224
エンコーダ層内の13層目計 12596224
エンコーダ層内の14層目計 12596224
エンコーダ層内の15層目計 12596224
エンコーダ層内の16層目計 12596224
エンコーダ層内の17層目計 12596224
エンコーダ層内の18層目計 12596224
エンコーダ層内の19層目計 12596224
エンコーダ層内の20層目計 12596224
エンコーダ層内の21層目計 12596224
エンコーダ層内の22層目計 12596224
エンコーダ層内の23層目計 12596224
---------- モデルのプーラー層 ----------
dense.weight 1048576
dense.bias 1024
========== パラメータ数 ==========
333579264

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment