{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Digging into what BERT does\n",
"\n",
"I did not understand what BERT does, so I simply dug into it step by step. Any mistakes are my own; if you notice anything, please point it out in a comment.\n",
"\n",
"\n",
"### Contents\n",
"\n",
"```\n",
"◆ (Setup) Create the tokenizer and the model.\n",
"◆ (Setup) Tokenize a sample sentence.\n",
"◆ Run it through the model to get feature vectors.\n",
"◆ Apply the embedding layer and the encoder stack step by step.\n",
"◆ A closer look at the embedding layer.\n",
"◆ A closer look at the encoder stack.\n",
"◆ A closer look at layer 0 of the encoder stack.\n",
"◆ A closer look at the self-attention sublayer of encoder layer 0.\n",
"```\n",
"\n",
"### Summary\n",
"Digging into what BERT does boils down to the following. \n",
"※ The dimensions, layer counts, and head counts are those of bert-large-cased.\n",
"\n",
"```\n",
"Model:\n",
"  Embedding layer:\n",
"    Embed each token's word into 1024 dimensions.\n",
"    Embed each token's position into 1024 dimensions.\n",
"    Embed each token's type into 1024 dimensions (※ all tokens share one type here).\n",
"    Sum the three embedding vectors.\n",
"    Normalize to obtain a sequence of 1024-dimensional feature vectors.\n",
"  Encoder stack:\n",
"    Encoder layer 0:\n",
"      Self-attention sublayer:\n",
"        Multi-head attention:\n",
"          Project each feature vector to 64 dimensions (Q). ※ Once for each of the 16 heads.\n",
"          Project each feature vector to 64 dimensions (K). ※ Once for each of the 16 heads.\n",
"          Project each feature vector to 64 dimensions (V). ※ Once for each of the 16 heads.\n",
"          Compute softmax(Q・Kᵀ/√64)・V. ※ Once for each of the 16 heads.\n",
"          At this point each token is a 64-dimensional feature. ※ One per head, 16 heads.\n",
"          Concatenate the results of the 16 heads.\n",
"          At this point each token is a 1024-dimensional feature.\n",
"        Fully connect each token's features to 1024 dimensions and normalize.\n",
"      Fully connect each token's features to 4096 dimensions.\n",
"      Fully connect each token's features to 1024 dimensions and normalize.\n",
"    Encoder layer 1:\n",
"    Encoder layer 2:\n",
"    (snip)\n",
"    Encoder layer 23:\n",
"      (snip)\n",
"      Fully connect each token's features to 1024 dimensions and normalize.\n",
"\n",
"→ So each token ends up as a 1024-dimensional feature vector.\n",
"```\n",
"\n",
"### References\n",
"- [transformers/modeling_bert.py at v3.1.0・huggingface/transformers](https://github.com/huggingface/transformers/blob/v3.1.0/src/transformers/modeling_bert.py)\n",
"  - The source code of transformers.BertModel.\n",
"- [[1706.03762] Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n",
"  - The multi-head attention in the self-attention sublayer (the BertSelfAttention class) corresponds to Section 3.2.2, Multi-Head Attention.\n",
"- [ML/test_bert_model.py at 342e0c0a3b60660ea5b2e6a2f1ada0f208663fda・CookieBox26/ML](https://github.com/CookieBox26/ML/blob/342e0c0a3b60660ea5b2e6a2f1ada0f208663fda/tests/transformers/test_bert_model.py#L53-L85)\n",
"  - A test-style inspection of what the layers making up BertModel actually are (note it uses a model with a sequence-labeling head on top, so what this notebook calls model corresponds to model.bert there). It was written by reading the transformers.BertModel source and picking out just what each layer is.\n",
"- [Counting the parameters of a pretrained BERT model.](https://gist.github.com/CookieBox26/90b55c0815b3f77ab1c566f5d73bd185)\n",
"  - A tally of the number of parameters in BertModel.\n",
"  - [This article](https://cookie-box.hatenablog.com/entry/2020/09/18/235941) shows the same output in a larger font.\n",
"  - It is consistent (as expected) with what this notebook traces."
]
},
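{
"cell_type": "markdown",
"metadata": {},
"source": [
"The multi-head attention step in the summary above can be sketched in plain PyTorch. This is a minimal illustration with freshly initialized (not pretrained) weights; only the shapes (16 heads × 64 dimensions = 1024) match bert-large-cased, and the variable names are mine, not the library's.\n",
"\n",
"```python\n",
"import math\n",
"import torch\n",
"\n",
"n_tokens, d_model, n_heads, d_head = 14, 1024, 16, 64\n",
"x = torch.randn(1, n_tokens, d_model)    # a sequence of feature vectors\n",
"w_q = torch.nn.Linear(d_model, d_model)  # the 16 heads' Q projections, packed into one matrix\n",
"w_k = torch.nn.Linear(d_model, d_model)\n",
"w_v = torch.nn.Linear(d_model, d_model)\n",
"\n",
"def split_heads(t):  # (1, 14, 1024) -> (1, 16, 14, 64)\n",
"    return t.view(1, n_tokens, n_heads, d_head).transpose(1, 2)\n",
"\n",
"q, k, v = split_heads(w_q(x)), split_heads(w_k(x)), split_heads(w_v(x))\n",
"# scaled dot-product attention per head: softmax(Q・Kᵀ/√64)・V\n",
"scores = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(d_head), dim=-1)\n",
"context = (scores @ v).transpose(1, 2).reshape(1, n_tokens, d_model)  # concat the 16 heads\n",
"print(context.size())  # torch.Size([1, 14, 1024])\n",
"```"
]
},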
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ (Setup) Create the tokenizer and the model.\n"
]
}
],
"source": [
"import torch\n",
"from transformers import BertTokenizer, BertModel\n",
"\n",
"\n",
"print('◆ (Setup) Create the tokenizer and the model.')\n",
"\n",
"# Name of the pretrained BERT model to use.\n",
"model_name = 'bert-large-cased'\n",
"\n",
"# Create the tokenizer matching the pretrained model.\n",
"tokenizer = BertTokenizer.from_pretrained(\n",
" pretrained_model_name_or_path=model_name,\n",
")\n",
"\n",
"# Create the pretrained model.\n",
"model = BertModel.from_pretrained(\n",
" pretrained_model_name_or_path=model_name,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ (Setup) Tokenize a sample sentence.\n",
"\n",
"The Empire State Building officially opened on May 1, 1931.\n",
"\n",
"◇ After tokenization (14 tokens)\n",
"\n",
"Token       ID\n",
"------------------\n",
"[CLS] 101\n",
"The 1109\n",
"Empire 2813\n",
"State 1426\n",
"Building 4334\n",
"officially 3184\n",
"opened 1533\n",
"on 1113\n",
"May 1318\n",
"1 122\n",
", 117\n",
"1931 3916\n",
". 119\n",
"[SEP] 102\n"
]
}
],
"source": [
"print('◆ (Setup) Tokenize a sample sentence.')\n",
"sentence = 'The Empire State Building officially opened on May 1, 1931.'\n",
"print('\\n' + sentence)\n",
"\n",
"tokens = tokenizer.tokenize(sentence)\n",
"tokens = [tokenizer.cls_token] + tokens + [tokenizer.sep_token]\n",
"ids = tokenizer.convert_tokens_to_ids(tokens)\n",
"\n",
"print(f'\\n◇ After tokenization ({len(tokens)} tokens)')\n",
"print('\\nToken       ID')\n",
"print('------------------')\n",
"for token, id_ in zip(tokens, ids):\n",
" print(token.ljust(11), str(id_))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ Run it through the model to get feature vectors.\n",
"\n",
"◇ Model input torch.Size([1, 14])\n",
"\n",
"◇ Model output torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Feature vector (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.2542, -0.1722, 0.4609, 0.1602])\n",
"The 1109 tensor([-0.6633, -0.6881, 0.1078, -0.1377])\n",
"Empire 2813 tensor([-0.5705, 0.6653, -0.6274, -0.3991])\n",
"State 1426 tensor([-0.2501, 0.8431, -0.2212, 0.7095])\n",
"Building 4334 tensor([-0.0373, 0.1590, 0.1410, 0.4667])\n",
"officially 3184 tensor([0.3344, 2.0546, 0.1929, 0.5689])\n",
"opened 1533 tensor([ 0.3731, 0.7722, -0.0037, 0.1575])\n",
"on 1113 tensor([-0.3517, 1.1093, 0.4453, 0.6579])\n",
"May 1318 tensor([-0.1288, 1.4521, 0.0427, 1.0977])\n",
"1 122 tensor([ 0.2981, 0.2612, -0.5292, 0.6982])\n",
", 117 tensor([-0.6542, 0.5261, 0.1120, -0.2468])\n",
"1931 3916 tensor([-0.5008, 1.9836, 0.0269, -1.3938])\n",
". 119 tensor([-0.4263, -0.1104, -0.1661, -0.7302])\n",
"[SEP] 102 tensor([-0.4687, -0.8222, 0.2919, -0.2878])\n"
]
}
],
"source": [
"print('◆ Run it through the model to get feature vectors.')\n",
"inputs = torch.tensor([ids])\n",
"outputs = model(inputs)\n",
"\n",
"print('\\n◇ Model input', inputs.size())\n",
"print('\\n◇ Model output', outputs[0].size())\n",
"print('\\nToken       ID    Feature vector (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, outputs[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ Apply the embedding layer and the encoder stack step by step.\n",
"\n",
"◇ Input torch.Size([1, 14])\n",
"\n",
"◇ After the embedding layer torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    After the embedding layer (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 0.0009, -0.0083, -0.1185, -0.0043])\n",
"The 1109 tensor([ 0.2233, -0.0305, 0.0369, 0.1422])\n",
"Empire 2813 tensor([-0.3681, -0.4137, 0.1736, 0.1354])\n",
"State 1426 tensor([0.2597, 0.1852, 0.1932, 0.5588])\n",
"Building 4334 tensor([ 0.1633, -0.9337, 0.7921, -0.8699])\n",
"officially 3184 tensor([ 0.7798, 0.1787, 0.4088, -0.2052])\n",
"opened 1533 tensor([-0.4937, -0.1895, 0.2692, 0.1615])\n",
"on 1113 tensor([-0.2223, 0.4795, -0.5399, -0.0934])\n",
"May 1318 tensor([ 0.7725, 0.0024, -0.1039, 0.4597])\n",
"1 122 tensor([0.9964, 0.4603, 0.0769, 0.4229])\n",
", 117 tensor([-0.2260, -0.1530, -0.2846, -0.0758])\n",
"1931 3916 tensor([-0.0137, 0.2279, 0.8815, -0.4292])\n",
". 119 tensor([ 0.5142, -0.2105, -0.3577, 0.0752])\n",
"[SEP] 102 tensor([ 0.1412, -0.1851, 0.1286, 0.2245])\n",
"\n",
"◇ After the encoder stack torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Feature vector (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.2542, -0.1722, 0.4609, 0.1602])\n",
"The 1109 tensor([-0.6633, -0.6881, 0.1078, -0.1377])\n",
"Empire 2813 tensor([-0.5705, 0.6653, -0.6274, -0.3991])\n",
"State 1426 tensor([-0.2501, 0.8431, -0.2212, 0.7095])\n",
"Building 4334 tensor([-0.0373, 0.1590, 0.1410, 0.4667])\n",
"officially 3184 tensor([0.3344, 2.0546, 0.1929, 0.5689])\n",
"opened 1533 tensor([ 0.3731, 0.7722, -0.0037, 0.1575])\n",
"on 1113 tensor([-0.3517, 1.1093, 0.4453, 0.6579])\n",
"May 1318 tensor([-0.1288, 1.4521, 0.0427, 1.0977])\n",
"1 122 tensor([ 0.2981, 0.2612, -0.5292, 0.6982])\n",
", 117 tensor([-0.6542, 0.5261, 0.1120, -0.2468])\n",
"1931 3916 tensor([-0.5008, 1.9836, 0.0269, -1.3938])\n",
". 119 tensor([-0.4263, -0.1104, -0.1661, -0.7302])\n",
"[SEP] 102 tensor([-0.4687, -0.8222, 0.2919, -0.2878])\n",
"\n",
"Same feature vectors as when the input is run through the model in one go.\n"
]
}
],
"source": [
"print('◆ Apply the embedding layer and the encoder stack step by step.')\n",
"print('\\n◇ Input', inputs.size())\n",
"\n",
"embedding_output = model.embeddings(input_ids=inputs)\n",
"print('\\n◇ After the embedding layer', embedding_output.size())\n",
"print('\\nToken       ID    After the embedding layer (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, embedding_output[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"encoder_outputs = model.encoder(embedding_output)\n",
"print('\\n◇ After the encoder stack', encoder_outputs[0].size())\n",
"print('\\nToken       ID    Feature vector (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, encoder_outputs[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
" \n",
"print('\\nSame feature vectors as when the input is run through the model in one go.')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ A closer look at the embedding layer.\n",
"\n",
"◇ Input torch.Size([1, 14])\n",
"\n",
"◇ After word embeddings torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Word embedding (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.0024, 0.0059, -0.0358, 0.0077])\n",
"The 1109 tensor([ 0.0061, -0.0106, -0.0627, 0.0032])\n",
"Empire 2813 tensor([-0.0505, -0.0351, -0.0346, 0.0025])\n",
"State 1426 tensor([ 0.0017, 0.0109, -0.0409, 0.0304])\n",
"Building 4334 tensor([-0.0075, -0.0626, 0.0479, -0.0738])\n",
"officially 3184 tensor([ 0.0501, 0.0033, -0.0030, -0.0114])\n",
"opened 1533 tensor([-0.0405, -0.0153, -0.0146, 0.0188])\n",
"on 1113 tensor([-0.0131, 0.0245, -0.1436, -0.0012])\n",
"May 1318 tensor([ 0.0423, -0.0112, -0.0781, 0.0224])\n",
"1 122 tensor([ 0.0410, 0.0323, -0.0739, 0.0234])\n",
", 117 tensor([-0.0459, -0.0025, -0.1194, -0.0038])\n",
"1931 3916 tensor([-0.0339, 0.0141, 0.0605, -0.0344])\n",
". 119 tensor([ 0.0196, -0.0102, -0.1244, 0.0067])\n",
"[SEP] 102 tensor([-0.0079, -0.0012, -0.0408, 0.0109])\n",
"\n",
"◇ Position embeddings torch.Size([1, 14, 1024])\n",
"\n",
"Token       Position    Position embedding (first 4 of 1024 dims)\n",
"--------------------------------------------------------------------\n",
"[CLS] 0 tensor([ 0.0090, -0.0025, -0.0563, 0.0070])\n",
"The 1 tensor([ 0.0045, -0.0037, 0.0358, 0.0020])\n",
"Empire 2 tensor([ 0.0179, -0.0049, 0.0310, 0.0031])\n",
"State 3 tensor([ 0.0130, -0.0100, 0.0408, 0.0050])\n",
"Building 4 tensor([ 0.0121, -0.0125, 0.0459, 0.0084])\n",
"officially 5 tensor([ 0.0038, -0.0016, 0.0394, -0.0045])\n",
"opened 6 tensor([ 0.0019, -0.0071, 0.0277, -0.0098])\n",
"on 7 tensor([-4.8045e-03, -2.2486e-05, 2.3070e-02, -5.3322e-03])\n",
"May 8 tensor([ 0.0080, -0.0025, 0.0264, 0.0039])\n",
"1 9 tensor([ 0.0261, -0.0150, 0.0529, 0.0009])\n",
", 10 tensor([ 0.0255, -0.0176, 0.0376, -0.0026])\n",
"1931 11 tensor([ 0.0232, -0.0170, 0.0392, -0.0025])\n",
". 12 tensor([ 0.0291, -0.0089, 0.0300, 0.0051])\n",
"[SEP] 13 tensor([ 0.0241, -0.0115, 0.0399, 0.0112])\n",
"\n",
"◇ Token-type embeddings torch.Size([1, 14, 1024])\n",
"\n",
"Token       Type            Token-type embedding (first 4 of 1024 dims)\n",
"--------------------------------------------------------------------\n",
"[CLS] 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"The 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"Empire 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"State 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"Building 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"officially 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"opened 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"on 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"May 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"1 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
", 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"1931 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
". 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"[SEP] 0 tensor([-0.0080, 0.0018, 0.0157, -0.0054])\n",
"\n",
"◇ Word + position + token-type embeddings torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Sum of the embeddings (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.0013, 0.0052, -0.0764, 0.0093])\n",
"The 1109 tensor([ 0.0026, -0.0126, -0.0111, -0.0002])\n",
"Empire 2813 tensor([-0.0406, -0.0383, 0.0121, 0.0002])\n",
"State 1426 tensor([0.0068, 0.0026, 0.0157, 0.0301])\n",
"Building 4334 tensor([-0.0034, -0.0733, 0.1096, -0.0707])\n",
"officially 3184 tensor([ 0.0459, 0.0035, 0.0521, -0.0213])\n",
"opened 1533 tensor([-0.0466, -0.0206, 0.0288, 0.0037])\n",
"on 1113 tensor([-0.0259, 0.0263, -0.1048, -0.0119])\n",
"May 1318 tensor([ 0.0423, -0.0119, -0.0360, 0.0209])\n",
"1 122 tensor([ 0.0591, 0.0191, -0.0053, 0.0189])\n",
", 117 tensor([-0.0283, -0.0183, -0.0661, -0.0118])\n",
"1931 3916 tensor([-0.0187, -0.0011, 0.1155, -0.0423])\n",
". 119 tensor([ 0.0407, -0.0174, -0.0787, 0.0065])\n",
"[SEP] 102 tensor([ 0.0082, -0.0110, 0.0149, 0.0167])\n",
"\n",
"◇ After LayerNorm torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    After LayerNorm (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 0.0009, -0.0083, -0.1185, -0.0043])\n",
"The 1109 tensor([ 0.2233, -0.0305, 0.0369, 0.1422])\n",
"Empire 2813 tensor([-0.3681, -0.4137, 0.1736, 0.1354])\n",
"State 1426 tensor([0.2597, 0.1852, 0.1932, 0.5588])\n",
"Building 4334 tensor([ 0.1633, -0.9337, 0.7921, -0.8699])\n",
"officially 3184 tensor([ 0.7798, 0.1787, 0.4088, -0.2052])\n",
"opened 1533 tensor([-0.4937, -0.1895, 0.2692, 0.1615])\n",
"on 1113 tensor([-0.2223, 0.4795, -0.5399, -0.0934])\n",
"May 1318 tensor([ 0.7725, 0.0024, -0.1039, 0.4597])\n",
"1 122 tensor([0.9964, 0.4603, 0.0769, 0.4229])\n",
", 117 tensor([-0.2260, -0.1530, -0.2846, -0.0758])\n",
"1931 3916 tensor([-0.0137, 0.2279, 0.8815, -0.4292])\n",
". 119 tensor([ 0.5142, -0.2105, -0.3577, 0.0752])\n",
"[SEP] 102 tensor([ 0.1412, -0.1851, 0.1286, 0.2245])\n",
"\n",
"◇ After Dropout torch.Size([1, 14, 1024]) ※ Dropout is a no-op because the model is not in training mode.\n",
"\n",
"Token       ID    After Dropout (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 0.0009, -0.0083, -0.1185, -0.0043])\n",
"The 1109 tensor([ 0.2233, -0.0305, 0.0369, 0.1422])\n",
"Empire 2813 tensor([-0.3681, -0.4137, 0.1736, 0.1354])\n",
"State 1426 tensor([0.2597, 0.1852, 0.1932, 0.5588])\n",
"Building 4334 tensor([ 0.1633, -0.9337, 0.7921, -0.8699])\n",
"officially 3184 tensor([ 0.7798, 0.1787, 0.4088, -0.2052])\n",
"opened 1533 tensor([-0.4937, -0.1895, 0.2692, 0.1615])\n",
"on 1113 tensor([-0.2223, 0.4795, -0.5399, -0.0934])\n",
"May 1318 tensor([ 0.7725, 0.0024, -0.1039, 0.4597])\n",
"1 122 tensor([0.9964, 0.4603, 0.0769, 0.4229])\n",
", 117 tensor([-0.2260, -0.1530, -0.2846, -0.0758])\n",
"1931 3916 tensor([-0.0137, 0.2279, 0.8815, -0.4292])\n",
". 119 tensor([ 0.5142, -0.2105, -0.3577, 0.0752])\n",
"[SEP] 102 tensor([ 0.1412, -0.1851, 0.1286, 0.2245])\n",
"\n",
"Identical to the output of the embedding layer.\n"
]
}
],
"source": [
"import torch\n",
"\n",
"\n",
"print('◆ A closer look at the embedding layer.')\n",
"print('\\n◇ Input', inputs.size())\n",
"\n",
"inputs_embeds = model.embeddings.word_embeddings(inputs)\n",
"print('\\n◇ After word embeddings', inputs_embeds.size())\n",
"print('\\nToken       ID    Word embedding (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, inputs_embeds[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"position_ids = model.embeddings.position_ids[:, :14]\n",
"position_embeddings = model.embeddings.position_embeddings(position_ids)\n",
"print('\\n◇ Position embeddings', position_embeddings.size())\n",
"print('\\nToken       Position    Position embedding (first 4 of 1024 dims)')\n",
"print('--------------------------------------------------------------------')\n",
"for token, pos_id, vec in zip(tokens, position_ids[0], position_embeddings[0]):\n",
" print(token.ljust(11), str(pos_id.item()).ljust(11), vec[:4].detach())\n",
"\n",
"token_type_ids = torch.zeros(inputs.size(), dtype=torch.long, device=position_ids.device)\n",
"token_type_embeddings = model.embeddings.token_type_embeddings(token_type_ids)\n",
"print('\\n◇ Token-type embeddings', token_type_embeddings.size())\n",
"print('\\nToken       Type            Token-type embedding (first 4 of 1024 dims)')\n",
"print('--------------------------------------------------------------------')\n",
"for token, token_id, vec in zip(tokens, token_type_ids[0], token_type_embeddings[0]):\n",
" print(token.ljust(11), str(token_id.item()).ljust(15), vec[:4].detach())\n",
"\n",
"embeddings = inputs_embeds + position_embeddings + token_type_embeddings\n",
"print('\\n◇ Word + position + token-type embeddings', embeddings.size())\n",
"print('\\nToken       ID    Sum of the embeddings (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, embeddings[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"embeddings = model.embeddings.LayerNorm(embeddings)\n",
"print('\\n◇ After LayerNorm', embeddings.size())\n",
"print('\\nToken       ID    After LayerNorm (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, embeddings[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"embeddings = model.embeddings.dropout(embeddings)\n",
"print('\\n◇ After Dropout', embeddings.size(), '※ Dropout is a no-op because the model is not in training mode.')\n",
"print('\\nToken       ID    After Dropout (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, embeddings[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"print('\\nIdentical to the output of the embedding layer.')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ A closer look at the encoder stack.\n",
"\n",
"◇ Up to the embedding layer torch.Size([1, 14, 1024])\n",
"\n",
"◇ After encoder layer 0 torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Feature vector (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 0.1590, -0.0149, -0.1153, 0.0482])\n",
"The 1109 tensor([-0.1219, -0.2602, 0.0420, -0.0295])\n",
"Empire 2813 tensor([-1.1064, -0.2894, -0.4232, -0.2743])\n",
"State 1426 tensor([-1.2130, -0.0012, 0.3907, 0.1322])\n",
"Building 4334 tensor([-0.8030, -1.0331, 0.1194, -1.5317])\n",
"officially 3184 tensor([ 0.1120, 0.2129, -0.4410, -0.0081])\n",
"opened 1533 tensor([-1.4200, -0.4410, 0.2487, -0.0915])\n",
"on 1113 tensor([-0.8189, 0.6773, -0.4591, -0.2131])\n",
"May 1318 tensor([ 0.0751, -0.7804, -0.2624, 0.6698])\n",
"1 122 tensor([ 1.0526, 0.3958, -0.5038, 0.3439])\n",
", 117 tensor([-0.6380, -0.4596, -0.3208, -0.0329])\n",
"1931 3916 tensor([-0.2772, 0.7000, 0.1715, -1.0733])\n",
". 119 tensor([ 0.2465, -0.3849, -0.2037, 0.1193])\n",
"[SEP] 102 tensor([ 0.0682, -0.0919, 0.0556, 0.1185])\n",
"\n",
"◇ After encoder layer 1 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 2 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 3 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 4 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 5 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 6 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 7 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 8 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 9 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 10 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 11 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 12 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 13 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 14 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 15 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 16 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 17 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 18 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 19 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 20 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 21 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 22 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After encoder layer 23 torch.Size([1, 14, 1024]) (details omitted)\n",
"\n",
"◇ After the encoder stack torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Feature vector (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.2542, -0.1722, 0.4609, 0.1602])\n",
"The 1109 tensor([-0.6633, -0.6881, 0.1078, -0.1377])\n",
"Empire 2813 tensor([-0.5705, 0.6653, -0.6274, -0.3991])\n",
"State 1426 tensor([-0.2501, 0.8431, -0.2212, 0.7095])\n",
"Building 4334 tensor([-0.0373, 0.1590, 0.1410, 0.4667])\n",
"officially 3184 tensor([0.3344, 2.0546, 0.1929, 0.5689])\n",
"opened 1533 tensor([ 0.3731, 0.7722, -0.0037, 0.1575])\n",
"on 1113 tensor([-0.3517, 1.1093, 0.4453, 0.6579])\n",
"May 1318 tensor([-0.1288, 1.4521, 0.0427, 1.0977])\n",
"1 122 tensor([ 0.2981, 0.2612, -0.5292, 0.6982])\n",
", 117 tensor([-0.6542, 0.5261, 0.1120, -0.2468])\n",
"1931 3916 tensor([-0.5008, 1.9836, 0.0269, -1.3938])\n",
". 119 tensor([-0.4263, -0.1104, -0.1661, -0.7302])\n",
"[SEP] 102 tensor([-0.4687, -0.8222, 0.2919, -0.2878])\n",
"\n",
"Same feature vectors as when the input is run through the model in one go.\n"
]
}
],
"source": [
"print('◆ A closer look at the encoder stack.')\n",
"embeddings = model.embeddings(input_ids=inputs)\n",
"print('\\n◇ Up to the embedding layer', embeddings.size())\n",
"\n",
"hidden_states = embeddings\n",
"for i_layer, layer in enumerate(model.encoder.layer):\n",
" layer_outputs = layer(hidden_states=hidden_states)\n",
" hidden_states = layer_outputs[0]\n",
" if i_layer == 0:\n",
"   print(f'\\n◇ After encoder layer {i_layer}', hidden_states.size())\n",
"   print('\\nToken       ID    Feature vector (first 4 of 1024 dims)')\n",
"   print('---------------------------------------------------------------')\n",
"   for token, id_, vec in zip(tokens, ids, hidden_states[0]):\n",
"     print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
" else:\n",
"   print(f'\\n◇ After encoder layer {i_layer}', hidden_states.size(), '(details omitted)')\n",
"\n",
"print('\\n◇ After the encoder stack', layer_outputs[0].size())\n",
"print('\\nToken       ID    Feature vector (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, layer_outputs[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"print('\\nSame feature vectors as when the input is run through the model in one go.')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ A closer look at layer 0 of the encoder stack.\n",
"\n",
"◇ Up to the embedding layer torch.Size([1, 14, 1024])\n",
"\n",
"◇ Self-attention sublayer of encoder layer 0 torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Self-attention output (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-2.2191, 0.5429, -1.0996, 0.0422])\n",
"The 1109 tensor([-1.0280, 0.0940, -0.2649, 0.3015])\n",
"Empire 2813 tensor([-2.0858, -0.2929, -0.9683, 0.1291])\n",
"State 1426 tensor([-0.4919, 0.4744, -0.3698, 0.8508])\n",
"Building 4334 tensor([-1.7632, -1.2738, 0.1818, -1.5367])\n",
"officially 3184 tensor([ 0.0181, 0.5400, 0.1537, -0.4500])\n",
"opened 1533 tensor([-2.6714, -0.1104, 0.0891, 0.0648])\n",
"on 1113 tensor([-2.0650, 0.9926, -1.0767, -0.1288])\n",
"May 1318 tensor([-0.3252, 0.0536, -0.4679, 0.6946])\n",
"1 122 tensor([ 0.6596, 0.5892, -0.6020, 0.8227])\n",
", 117 tensor([-1.7578, -0.4852, -1.1520, -0.2884])\n",
"1931 3916 tensor([-1.5812, 0.6358, 1.1979, -0.7883])\n",
". 119 tensor([ 0.3659, -0.6282, -1.1055, 0.1714])\n",
"[SEP] 102 tensor([ 0.0338, -0.5147, -0.5761, 0.5444])\n",
"\n",
"◇ Intermediate output of encoder layer 0 torch.Size([1, 14, 4096])\n",
"\n",
"Token       ID    Intermediate output (first 4 of 4096 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 1.8646, -0.1123, -0.1659, 0.7346])\n",
"The 1109 tensor([-5.9923e-04, -1.4980e-07, -4.1618e-02, -1.6558e-01])\n",
"Empire 2813 tensor([-7.5820e-03, -7.6221e-07, -1.0115e-02, -9.0719e-02])\n",
"State 1426 tensor([-0.0057, -0.0007, -0.0001, -0.0368])\n",
"Building 4334 tensor([-6.6138e-06, -1.8558e-02, -1.5911e-01, -9.8101e-02])\n",
"officially 3184 tensor([-5.7244e-02, -1.4014e-07, -7.6652e-05, -5.6438e-02])\n",
"opened 1533 tensor([-6.3999e-06, -3.6866e-08, -1.5436e-03, -3.0976e-03])\n",
"on 1113 tensor([-3.7424e-03, -3.0094e-08, -6.1249e-09, -8.7091e-02])\n",
"May 1318 tensor([-6.6016e-07, -5.5692e-08, -1.5565e-05, 4.9265e-02])\n",
"1 122 tensor([-9.2506e-03, -1.2187e-06, -6.8205e-07, -1.6503e-01])\n",
", 117 tensor([-2.3187e-03, -1.6806e-06, -2.6252e-07, -1.1637e-02])\n",
"1931 3916 tensor([-0.1572, -0.0003, -0.0002, -0.0074])\n",
". 119 tensor([-4.3688e-05, -3.6707e-04, -5.4393e-07, -1.7599e-02])\n",
"[SEP] 102 tensor([-3.0956e-04, -1.0528e-04, -5.4119e-05, -6.7550e-03])\n",
"\n",
"◇ Final output of encoder layer 0 torch.Size([1, 14, 1024])\n",
"\n",
"Token       ID    Layer output (first 4 of 1024 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 0.1590, -0.0149, -0.1153, 0.0482])\n",
"The 1109 tensor([-0.1219, -0.2602, 0.0420, -0.0295])\n",
"Empire 2813 tensor([-1.1064, -0.2894, -0.4232, -0.2743])\n",
"State 1426 tensor([-1.2130, -0.0012, 0.3907, 0.1322])\n",
"Building 4334 tensor([-0.8030, -1.0331, 0.1194, -1.5317])\n",
"officially 3184 tensor([ 0.1120, 0.2129, -0.4410, -0.0081])\n",
"opened 1533 tensor([-1.4200, -0.4410, 0.2487, -0.0915])\n",
"on 1113 tensor([-0.8189, 0.6773, -0.4591, -0.2131])\n",
"May 1318 tensor([ 0.0751, -0.7804, -0.2624, 0.6698])\n",
"1 122 tensor([ 1.0526, 0.3958, -0.5038, 0.3439])\n",
", 117 tensor([-0.6380, -0.4596, -0.3208, -0.0329])\n",
"1931 3916 tensor([-0.2772, 0.7000, 0.1715, -1.0733])\n",
". 119 tensor([ 0.2465, -0.3849, -0.2037, 0.1193])\n",
"[SEP] 102 tensor([ 0.0682, -0.0919, 0.0556, 0.1185])\n",
"\n",
"Same feature vectors as after applying encoder layer 0 above.\n"
]
}
],
"source": [
"print('◆ A closer look at layer 0 of the encoder stack.')\n",
"embeddings = model.embeddings(input_ids=inputs)\n",
"print('\\n◇ Up to the embedding layer', embeddings.size())\n",
"\n",
"hidden_states = embeddings\n",
"layer = model.encoder.layer[0]\n",
"\n",
"self_attention_outputs = layer.attention(hidden_states)\n",
"attention_output = self_attention_outputs[0]\n",
"print('\\n◇ Self-attention sublayer of encoder layer 0', attention_output.size())\n",
"print('\\nToken       ID    Self-attention output (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, attention_output[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"intermediate_output = layer.intermediate(attention_output)\n",
"print('\\n◇ Intermediate output of encoder layer 0', intermediate_output.size())\n",
"print('\\nToken       ID    Intermediate output (first 4 of 4096 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, intermediate_output[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"layer_output = layer.output(intermediate_output, attention_output)\n",
"print('\\n◇ Final output of encoder layer 0', layer_output.size())\n",
"print('\\nToken       ID    Layer output (first 4 of 1024 dims)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, layer_output[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"print('\\nSame feature vectors as after applying encoder layer 0 above.')"
]
},
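{
"cell_type": "markdown",
"metadata": {},
"source": [
"Reading modeling_bert.py at v3.1.0, the layer.output(intermediate_output, attention_output) call above is essentially a dense projection from 4096 back to 1024 dimensions, a residual addition of attention_output, and a LayerNorm (plus a dropout that is a no-op in eval mode). A minimal sketch with freshly initialized weights, so the values will not match the pretrained model:\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"attention_output = torch.randn(1, 14, 1024)     # stand-in for the self-attention output\n",
"intermediate_output = torch.randn(1, 14, 4096)  # stand-in for the intermediate output\n",
"dense = torch.nn.Linear(4096, 1024)\n",
"layer_norm = torch.nn.LayerNorm(1024)\n",
"# project back to 1024 dims, add the residual, then normalize\n",
"layer_output = layer_norm(dense(intermediate_output) + attention_output)\n",
"print(layer_output.size())  # torch.Size([1, 14, 1024])\n",
"```"
]
},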
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ A closer look at the self-attention sublayer of encoder layer 0.\n",
"\n",
"◇ Up to the embedding layer torch.Size([1, 14, 1024])\n",
"\n",
"◇ Self-attention Q (projected to 64 dims × 16 heads) torch.Size([1, 16, 14, 64])\n",
"\n",
"◇ Self-attention K (projected to 64 dims × 16 heads) torch.Size([1, 16, 14, 64])\n",
"\n",
"◇ Self-attention V (projected to 64 dims × 16 heads) torch.Size([1, 16, 14, 64])\n",
"\n",
"Token       ID    Q (head 0) (first 4 of 64 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([ 0.0053, -0.1864, -1.5590, -0.2854])\n",
"The 1109 tensor([ 0.3943, 0.4636, 1.0551, -1.1965])\n",
"Empire 2813 tensor([-0.6055, 1.3599, -1.1306, -0.3850])\n",
"State 1426 tensor([-1.6630, 0.5093, -0.8856, -0.6479])\n",
"Building 4334 tensor([0.8084, 0.2061, 0.4351, 0.8148])\n",
"officially 3184 tensor([-1.6065, -0.3829, -0.6481, 0.2430])\n",
"opened 1533 tensor([ 0.6485, -1.7961, 0.9005, -1.7972])\n",
"on 1113 tensor([-0.6037, -0.3198, -0.9668, -1.9916])\n",
"May 1318 tensor([-2.2796, 0.2057, 1.4450, -0.9425])\n",
"1 122 tensor([-1.9763, -1.2419, 1.9547, -1.6826])\n",
", 117 tensor([-0.4010, 0.2332, 0.5817, -0.6312])\n",
"1931 3916 tensor([-1.4629, 0.6146, 0.4683, -1.0621])\n",
". 119 tensor([-1.1577, 1.1318, 0.2125, -1.7602])\n",
"[SEP] 102 tensor([-1.5064, -1.0007, 0.8054, -0.1837])\n",
"\n",
"Token       ID    K (head 0) (first 4 of 64 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.3109, -1.0632, 0.9761, 0.3209])\n",
"The 1109 tensor([-0.4832, -0.5085, 0.4726, 0.4999])\n",
"Empire 2813 tensor([ 0.3147, -0.3506, 0.6618, -0.2500])\n",
"State 1426 tensor([ 1.9507, 0.2254, -1.0760, -0.2168])\n",
"Building 4334 tensor([-0.1862, -1.1780, -1.6273, 1.0553])\n",
"officially 3184 tensor([ 1.6586, 0.2016, -1.4912, -0.2609])\n",
"opened 1533 tensor([-0.3511, 0.4377, -1.2361, -2.0748])\n",
"on 1113 tensor([-0.2561, -0.5802, 0.3933, -1.4295])\n",
"May 1318 tensor([-0.4343, -1.3554, 1.1077, -0.7949])\n",
"1 122 tensor([-1.0242, 0.7306, -0.0188, -1.1051])\n",
", 117 tensor([-0.8041, 0.5653, 0.7031, -1.2465])\n",
"1931 3916 tensor([-0.0503, 0.2498, -0.5218, -2.9984])\n",
". 119 tensor([-0.9162, 0.2248, 0.6183, -1.5219])\n",
"[SEP] 102 tensor([-1.5568, 1.1062, -0.3203, -2.6940])\n",
"\n",
"Token       ID    V (head 0) (first 4 of 64 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.0005, 0.2435, -0.0739, 0.0539])\n",
"The 1109 tensor([ 0.5705, 0.5025, -0.1333, 0.9561])\n",
"Empire 2813 tensor([-0.3488, -0.1480, 0.2673, -0.4913])\n",
"State 1426 tensor([ 0.3877, -1.3560, 0.6915, -0.4490])\n",
"Building 4334 tensor([ 0.7626, -0.2158, 0.7697, -0.4891])\n",
"officially 3184 tensor([-0.2618, -0.9578, 0.4074, -0.1486])\n",
"opened 1533 tensor([ 0.5452, -0.3873, 1.7566, 0.2749])\n",
"on 1113 tensor([-0.5832, 0.1165, 0.5800, 0.4236])\n",
"May 1318 tensor([ 0.7438, -1.1668, 0.4729, 0.0466])\n",
"1 122 tensor([-0.5033, 0.2378, 0.7522, 0.4560])\n",
", 117 tensor([0.1744, 0.0023, 0.0168, 0.2345])\n",
"1931 3916 tensor([-0.9027, 0.1798, 0.3355, 0.3945])\n",
". 119 tensor([ 0.2487, -0.1832, 0.1535, -0.2618])\n",
"[SEP] 102 tensor([-0.3748, 0.5896, 0.5478, 0.5616])\n",
"\n",
"◇ Self-attention Q・Kᵀ (raw dot products) torch.Size([1, 16, 14, 14])\n",
"\n",
"◇ Self-attention Q・Kᵀ (divided by √64, then softmax) torch.Size([1, 16, 14, 14])\n",
"\n",
"Token       ID    softmax(Q・Kᵀ/√64) (head 0) (first 4 of 14 dims)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([6.7831e-05, 5.0022e-02, 9.2978e-02, 7.5121e-02])\n",
"The 1109 tensor([0.1171, 0.1757, 0.3496, 0.1450])\n",
"Empire 2813 tensor([0.0483, 0.0699, 0.0244, 0.4448])\n",
"State 1426 tensor([0.0012, 0.0269, 0.0613, 0.0087])\n",
"Building 4334 tensor([0.0218, 0.0140, 0.0903, 0.4639])\n",
"officially 3184 tensor([0.0678, 0.0136, 0.0044, 0.0199])\n",
"opened 1533 tensor([0.0662, 0.0032, 0.0064, 0.0201])\n",
"on 1113 tensor([0.0085, 0.0023, 0.0010, 0.0044])\n",
"May 1318 tensor([0.0400, 0.0053, 0.0057, 0.0056])\n",
"1 122 tensor([0.0029, 0.0018, 0.0031, 0.0028])\n",
", 117 tensor([0.0731, 0.0081, 0.0047, 0.0067])\n",
"1931 3916 tensor([0.3420, 0.0017, 0.0013, 0.0018])\n",
". 119 tensor([0.0246, 0.0031, 0.0026, 0.0025])\n",
"[SEP] 102 tensor([7.5419e-01, 5.8881e-04, 1.0644e-04, 3.1355e-04])\n",
"\n",
"◇ セルフアテンション層の Q・K・V = Attention(Q, K, V) torch.Size([1, 16, 14, 64])\n",
"\n",
"◇ セルフアテンション層の Attention(Q, K, V)(転置) torch.Size([1, 14, 16, 64])\n",
"\n",
"◇ セルフアテンション層の Attention(Q, K, V)(16ヘッドをconcat) torch.Size([1, 14, 1024])\n",
"\n",
"トークン ID セルフアテンション(1024次元のうち最初の4次元だけ)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-0.0948, -0.2211, 0.5652, 0.0773])\n",
"The 1109 tensor([ 0.0378, -0.1861, 0.2771, -0.0478])\n",
"Empire 2813 tensor([ 0.4830, -0.6592, 0.6199, -0.3148])\n",
"State 1426 tensor([ 0.6607, -0.2111, 0.7154, -0.4308])\n",
"Building 4334 tensor([ 0.1230, -0.9245, 0.5564, -0.2931])\n",
"officially 3184 tensor([ 0.3682, -0.2707, 1.0510, 0.0506])\n",
"opened 1533 tensor([-0.2751, -0.3856, 0.5004, 0.1303])\n",
"on 1113 tensor([ 0.3689, -0.6057, 0.8573, 0.1807])\n",
"May 1318 tensor([-0.3338, 0.1158, 0.6251, 0.3667])\n",
"1 122 tensor([ 0.4580, -0.8426, 0.4600, 0.1110])\n",
", 117 tensor([-0.2331, 0.0511, 0.4025, 0.2881])\n",
"1931 3916 tensor([-0.0811, 0.1821, 0.2231, 0.1854])\n",
". 119 tensor([-0.2823, 0.3668, 0.4430, 0.3902])\n",
"[SEP] 102 tensor([-0.0334, 0.2529, 0.0407, 0.1034])\n",
"\n",
"◇ セルフアテンション層の最終アウトプット(全結合後) torch.Size([1, 14, 1024])\n",
"\n",
"トークン ID セルフアテンション(1024次元のうち最初の4次元だけ)\n",
"---------------------------------------------------------------\n",
"[CLS] 101 tensor([-2.2191, 0.5429, -1.0996, 0.0422])\n",
"The 1109 tensor([-1.0280, 0.0940, -0.2649, 0.3015])\n",
"Empire 2813 tensor([-2.0858, -0.2929, -0.9683, 0.1291])\n",
"State 1426 tensor([-0.4919, 0.4744, -0.3698, 0.8508])\n",
"Building 4334 tensor([-1.7632, -1.2738, 0.1818, -1.5367])\n",
"officially 3184 tensor([ 0.0181, 0.5400, 0.1537, -0.4500])\n",
"opened 1533 tensor([-2.6714, -0.1104, 0.0891, 0.0648])\n",
"on 1113 tensor([-2.0650, 0.9926, -1.0767, -0.1288])\n",
"May 1318 tensor([-0.3252, 0.0536, -0.4679, 0.6946])\n",
"1 122 tensor([ 0.6596, 0.5892, -0.6020, 0.8227])\n",
", 117 tensor([-1.7578, -0.4852, -1.1520, -0.2884])\n",
"1931 3916 tensor([-1.5812, 0.6358, 1.1979, -0.7883])\n",
". 119 tensor([ 0.3659, -0.6282, -1.1055, 0.1714])\n",
"[SEP] 102 tensor([ 0.0338, -0.5147, -0.5761, 0.5444])\n",
"\n",
"エンコーダ層内の0層目のセルフアテンション層と同じになっている.\n"
]
}
],
"source": [
"from torch import nn\n",
"import math\n",
"\n",
"\n",
"print('◆ エンコーダ層内の0層目のセルフアテンション層についてもっと詳しくみる.')\n",
"embeddings = model.embeddings(input_ids=inputs)\n",
"print('\\n◇ 埋め込み層適用後まで', embeddings.size())\n",
"\n",
"hidden_states = embeddings\n",
"attn = model.encoder.layer[0].attention\n",
"\n",
"mixed_query_layer = attn.self.query(hidden_states)\n",
"mixed_key_layer = attn.self.key(hidden_states)\n",
"mixed_value_layer = attn.self.value(hidden_states)\n",
"query_layer = attn.self.transpose_for_scores(mixed_query_layer)\n",
"key_layer = attn.self.transpose_for_scores(mixed_key_layer)\n",
"value_layer = attn.self.transpose_for_scores(mixed_value_layer)\n",
"\n",
"print('\\n◇ セルフアテンション層の Q(64次元に写像 × 16ヘッド)', query_layer.size())\n",
"print('\\n◇ セルフアテンション層の K(64次元に写像 × 16ヘッド)', key_layer.size())\n",
"print('\\n◇ セルフアテンション層の V(64次元に写像 × 16ヘッド)', value_layer.size())\n",
"print('\\nトークン ID Q(0ヘッド目)(64次元のうち最初の4次元だけ)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, query_layer[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"print('\\nトークン ID K(0ヘッド目)(64次元のうち最初の4次元だけ)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, key_layer[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"print('\\nトークン ID V(0ヘッド目)(64次元のうち最初の4次元だけ)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, value_layer[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))\n",
"print('\\n◇ セルフアテンション層の Q・K(ドットプロダクトしただけ)', attention_scores.size())\n",
"attention_scores = attention_scores / math.sqrt(attn.self.attention_head_size)\n",
"attention_probs = nn.Softmax(dim=-1)(attention_scores)\n",
"attention_probs = attn.self.dropout(attention_probs)\n",
"print('\\n◇ セルフアテンション層の Q・K(√64で割ってソフトマックスまで)', attention_probs.size())\n",
"print('\\nトークン ID softmax(Q・K/√64)(0ヘッド目)(14次元のうち最初の4次元だけ)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, attention_probs[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"context_layer = torch.matmul(attention_probs, value_layer)\n",
"print('\\n◇ セルフアテンション層の Q・K・V = Attention(Q, K, V)', context_layer.size())\n",
"context_layer = context_layer.permute(0, 2, 1, 3).contiguous()\n",
"print('\\n◇ セルフアテンション層の Attention(Q, K, V)(転置)', context_layer.size())\n",
"new_context_layer_shape = context_layer.size()[:-2] + (attn.self.all_head_size,)\n",
"context_layer = context_layer.view(*new_context_layer_shape)\n",
"self_outputs = (context_layer,)\n",
"print('\\n◇ セルフアテンション層の Attention(Q, K, V)(16ヘッドをconcat)', self_outputs[0].size())\n",
"print('\\nトークン ID セルフアテンション(1024次元のうち最初の4次元だけ)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, self_outputs[0][0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
"\n",
"attention_output = attn.output(self_outputs[0], hidden_states)\n",
"print('\\n◇ セルフアテンション層の最終アウトプット(全結合後)', attention_output.size())\n",
"print('\\nトークン ID セルフアテンション(1024次元のうち最初の4次元だけ)')\n",
"print('---------------------------------------------------------------')\n",
"for token, id_, vec in zip(tokens, ids, attention_output[0]):\n",
" print(token.ljust(11), str(id_).ljust(5), vec[:4].detach())\n",
" \n",
"print('\\nエンコーダ層内の0層目のセルフアテンション層と同じになっている.')"
]
},
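  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "（補足）上のセルで確認した softmax(Q・K/√64)・V の計算（1ヘッド分）を、transformers に依存しない素の Python で書き下すと以下のようなイメージです。関数名・変数名は説明用の仮のもので、ドロップアウトやバッチ次元は省いた最小スケッチです。bert-large-cased ではこれをヘッドごと（d=64）に計算し、16ヘッド分の結果をconcatするのでした。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def softmax(xs):\n",
    "    # オーバーフロー対策に最大値を引いてから指数関数をとる\n",
    "    m = max(xs)\n",
    "    es = [math.exp(x - m) for x in xs]\n",
    "    s = sum(es)\n",
    "    return [e / s for e in es]\n",
    "\n",
    "\n",
    "def attention(Q, K, V):\n",
    "    # Q, K, V: トークンごとの d 次元ベクトルのリスト（1ヘッド分）\n",
    "    d = len(K[0])\n",
    "    out = []\n",
    "    for q in Q:\n",
    "        # 各トークンの K との内積を √d で割ってソフトマックス\n",
    "        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in K]\n",
    "        probs = softmax(scores)\n",
    "        # アテンション重み probs で V を加重平均する\n",
    "        out.append([sum(p * v[j] for p, v in zip(probs, V)) for j in range(len(V[0]))])\n",
    "    return out"
   ]
  },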
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}