Last active
September 19, 2020 16:39
-
-
Save CookieBox26/90b55c0815b3f77ab1c566f5d73bd185 to your computer and use it in GitHub Desktop.
学習済みBERTモデルのパラメータ数を数える. (Counts the number of parameters of a pretrained BERT model.)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import BertModel


def main(model_name='bert-large-cased'):
    """Count and print the number of parameters of a pretrained BERT model.

    Parameters are reported per top-level block (embeddings / encoder /
    pooler). The encoder is large, so its parameters are summed per layer
    and only the first layer is printed in detail. The grand total is
    printed at the end.

    Args:
        model_name: name of the pretrained BERT model to load
            (default: 'bert-large-cased').
    """
    # Instantiate the pretrained model (downloads weights on first use).
    model = BertModel.from_pretrained(
        pretrained_model_name_or_path=model_name,
    )
    # NOTE: BertModel itself exposes named_parameters()/parameters(), so if
    # only the grand total is needed this whole loop can be replaced with
    # sum(param.numel() for param in model.parameters()).
    n = 0
    blocks = [
        ('モデルの埋め込み層', model.embeddings),
        ('モデルのエンコーダ層', model.encoder),
        ('モデルのプーラー層', model.pooler),
    ]
    for label, module in blocks:
        print('-' * 10, label, '-' * 10)
        if module is model.encoder:  # the encoder is large: aggregate per layer
            for i, layer in enumerate(module.layer):
                n_layer = 0
                for name, param in layer.named_parameters():
                    n_ = param.numel()
                    if i == 0:  # show per-parameter detail for the first layer only
                        print(name, n_)
                    n_layer += n_
                print(f'エンコーダ層内の{i}層目計', n_layer)
                n += n_layer
        else:
            for name, param in module.named_parameters():
                n_ = param.numel()
                print(name, n_)
                n += n_
    print('=' * 10, 'パラメータ数', '=' * 10)
    print(n)
if __name__ == '__main__':
    main()
Author
CookieBox26
commented
Sep 18, 2020
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment