@CookieBox26
Created April 5, 2020 03:21
Training a Transformer with PyTorch
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# PyTorch で Transformer を学習する\n",
"\n",
"### 参考文献\n",
"\n",
"SEQUENCE-TO-SEQUENCE MODELING WITH NN.TRANSFORMER AND TORCHTEXT \n",
"https://pytorch.org/tutorials/beginner/transformer_tutorial.html\n",
"- [torchtext.datasets.WikiText2](https://torchtext.readthedocs.io/en/latest/datasets.html#torchtext.datasets.WikiText2)\n",
"- [torchtext.data.Field](https://torchtext.readthedocs.io/en/latest/data.html#field)\n",
"- [torchtext.data.get_tokenizer](https://pytorch.org/text/data.html#torchtext.data.get_tokenizer)\n",
"- [neural network - PyTorch - contiguous() - Stack Overflow](https://stackoverflow.com/questions/48915810/pytorch-contiguous)\n",
" - view() や transpose() はインデックスをふりかえるだけで新しいオブジェクトを生成しないが、これをつけると新しいインデックスに適切なメモリ上の配置で新しいオブジェクトを生成する。プレ処理中に転置などをするなら処理速度上 contiguous しておいた方がよい場合がありそう。\n",
"- [torch.triu](https://pytorch.org/docs/stable/torch.html#torch.triu)\n",
" - デフォルトで行列を広義上三角にしてくる。オプションで狭義上三角にもできるし逆に対角要素より何行下まで残すこともできる。\n",
"- [torch.nn.Embedding](https://pytorch.org/docs/master/nn.html#torch.nn.Embedding)\n",
"- [torch.nn.Transformer](https://pytorch.org/docs/master/nn.html?highlight=nn%20transformer#torch.nn.Transformer)\n",
" - 以下の論文にもとづいている.\n",
" - [Attention Is All You Need](https://arxiv.org/abs/1706.03762)\n",
"- [torch.nn.TransformerEncoder](https://pytorch.org/docs/master/nn.html?highlight=nn%20transformerencoder#torch.nn.TransformerEncoder)\n",
"- [torch.nn.TransformerEncoderLayer](https://pytorch.org/docs/master/nn.html?highlight=transformerencoderlayer#torch.nn.TransformerEncoderLayer)\n",
"- [torch.nn.MultiheadAttention](https://pytorch.org/docs/master/nn.html?highlight=multiheadattention#torch.nn.MultiheadAttention)\n",
" - 今回はこのモジュールを生でつかうことはしない。TransformerEncoderLayer を定義する。"
]
},
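{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch added here (not part of the original tutorial) checks the `contiguous()` and `torch.triu` notes above on toy tensors. `t()` returns a view that shares memory with the original, and `view()` refuses the transposed layout until `contiguous()` copies it. The `diagonal` argument of `torch.triu` selects which diagonals survive."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"x = torch.arange(6).view(2, 3)  # [[0, 1, 2], [3, 4, 5]]\n",
"xt = x.t()  # transposed view of shape (3, 2); no data is copied\n",
"print(xt.data_ptr() == x.data_ptr())  # True: same underlying memory\n",
"print(xt.is_contiguous())  # False: memory order no longer matches the index order\n",
"\n",
"try:\n",
"    xt.view(-1)  # view() needs a memory layout compatible with the requested shape\n",
"except RuntimeError as e:\n",
"    print('view() failed:', type(e).__name__)\n",
"\n",
"flat = xt.contiguous().view(-1)  # contiguous() copies into the new order, so view() now works\n",
"print(flat)  # tensor([0, 3, 1, 4, 2, 5])\n",
"\n",
"m = torch.ones(4, 4)\n",
"print(torch.triu(m))  # upper triangle including the main diagonal\n",
"print(torch.triu(m, diagonal=1))  # strictly upper triangular\n",
"print(torch.triu(m, diagonal=-1))  # also keeps one diagonal below the main one"
]
},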
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ まず Wikitext-2 データを取得する\n",
"◇ 訓練データ語数: 2086708\n",
" 訓練データ冒頭: ['<eos>', '=', 'valkyria', 'chronicles', 'iii', '=', '<eos>', '<eos>', 'senjō', 'no', 'valkyria', '3', '<unk>', 'chronicles', '(', 'japanese', '戦場のヴァルキュリア3', ',', 'lit', '.', 'valkyria', 'of', 'the', 'battlefield', '3', ')', ',', 'commonly', 'referred', 'to', 'as', 'valkyria', 'chronicles', 'iii', 'outside', 'japan', ',', 'is', 'a', 'tactical']\n",
"◇ 評価データ語数: 218177\n",
" 評価データ冒頭: ['<eos>', '=', 'homarus', 'gammarus', '=', '<eos>', '<eos>', 'homarus', 'gammarus', ',', 'known', 'as', 'the', 'european', 'lobster', 'or', 'common', 'lobster', ',', 'is', 'a', 'species', 'of', '<unk>', 'lobster', 'from', 'the', 'eastern', 'atlantic', 'ocean', ',', 'mediterranean', 'sea', 'and', 'parts', 'of', 'the', 'black', 'sea', '.']\n",
"◇ テストデータ語数: 246217\n",
" テストデータ冒頭: ['<eos>', '=', 'robert', '<unk>', '=', '<eos>', '<eos>', 'robert', '<unk>', 'is', 'an', 'english', 'film', ',', 'television', 'and', 'theatre', 'actor', '.', 'he', 'had', 'a', 'guest', '@-@', 'starring', 'role', 'on', 'the', 'television', 'series', 'the', 'bill', 'in', '2000', '.', 'this', 'was', 'followed', 'by', 'a']\n",
"◇ 訓練データの語彙を取得する\n",
"語彙数: 28785\n",
"0 <unk>\n",
"1 <pad>\n",
"2 <sos>\n",
"3 <eos>\n",
"4 the\n",
"5 ,\n",
"6 .\n",
"7 of\n",
"8 and\n",
"9 in\n",
"10 to\n",
"11 a\n",
"12 =\n",
"13 was\n",
"14 '\n",
"15 @-@\n",
"16 on\n",
"17 as\n",
"18 s\n",
"19 that\n",
"20 for\n"
]
}
],
"source": [
"import torch\n",
"import torchtext\n",
"from torchtext.data.utils import get_tokenizer\n",
"\n",
"print('◆ まず Wikitext-2 データを取得する')\n",
"# torchtext.data.Field でテキストデータへの前処理を定義しておく\n",
"# トークナイザは今回は basic_english というのをつかう(これは空白で split するより先に正規化するらしい)\n",
"# > which normalize the string first and split by space.\n",
"# 今回は init_token と eos_token を定義しておく(全ての例文に文頭トークンと文末トークンが付加される;デフォルトは付加しない)\n",
"# 今回は小文字化 lower=True にしておく(デフォルトは False)\n",
"TEXT = torchtext.data.Field(tokenize=get_tokenizer(\"basic_english\"), init_token='<sos>', eos_token='<eos>', lower=True)\n",
"# 訓練データ、評価データ、テストデータを取得する\n",
"train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)\n",
"print('◇ 訓練データ語数: {}'.format(len(train_txt.examples[0].text)))\n",
"print(' 訓練データ冒頭: ', train_txt.examples[0].text[:40])\n",
"print('◇ 評価データ語数: {}'.format(len(val_txt.examples[0].text)))\n",
"print(' 評価データ冒頭: ', val_txt.examples[0].text[:40])\n",
"print('◇ テストデータ語数: {}'.format(len(test_txt.examples[0].text)))\n",
"print(' テストデータ冒頭: ', test_txt.examples[0].text[:40])\n",
"\n",
"print('◇ 訓練データの語彙を取得する')\n",
"TEXT.build_vocab(train_txt) # 訓練データをもとに語彙を構成(これにもとづいて各単語を数字にする)\n",
"print('語彙数: {}'.format(len(TEXT.vocab.stoi)))\n",
"for token, i in TEXT.vocab.stoi.items():\n",
" print(i, token)\n",
" if i == 20:\n",
" break"
]
},
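{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following is a hypothetical pure-Python sketch of what building a vocabulary amounts to. `build_vocab` and `numericalize` here are our own illustrative functions, not torchtext's implementation. The sketch counts the tokens, puts the special tokens first, and sorts the remaining words by descending frequency, breaking ties alphabetically. Unknown words are mapped to `<unk>` (ID 0). The ordering rule is an assumption chosen to resemble the listing above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"def build_vocab(tokens, specials=('<unk>', '<pad>', '<sos>', '<eos>')):\n",
"    # Special tokens come first, in the order given\n",
"    itos = list(specials)\n",
"    counts = Counter(t for t in tokens if t not in specials)\n",
"    # Remaining words: most frequent first, ties broken alphabetically\n",
"    for word, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):\n",
"        itos.append(word)\n",
"    stoi = {word: i for i, word in enumerate(itos)}\n",
"    return itos, stoi\n",
"\n",
"def numericalize(tokens, stoi):\n",
"    # Words missing from the vocabulary map to the ID of '<unk>'\n",
"    return [stoi.get(t, stoi['<unk>']) for t in tokens]\n",
"\n",
"tokens = '<eos> = the cat sat on the mat . the end <eos>'.split()\n",
"itos, stoi = build_vocab(tokens)\n",
"print(itos[:6])  # ['<unk>', '<pad>', '<sos>', '<eos>', 'the', '.']\n",
"print(numericalize(['the', 'dog', 'sat'], stoi))  # [4, 0, 11]"
]
},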
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ データを固定長ごとに区切って整形する\n",
"訓練データ torch.Size([104335, 20])\n",
"評価データ torch.Size([21817, 10])\n",
"テストデータ torch.Size([24621, 10])\n",
"\n",
"◇ 訓練データの 0, 1, 2, 33, 34, 35 個目の単語列(数字列版)\n",
"単語列0 tensor([ 3, 25, 1849, 570, 7, 5, 5, 9258, 4, 56, 0, 7,\n",
" 6, 6634, 4, 6603, 6, 5, 65, 30])\n",
"単語列1 tensor([ 12, 66, 13, 4889, 458, 8, 1045, 21, 19094, 34,\n",
" 147, 4, 0, 10, 2280, 2294, 58, 35, 2438, 4064])\n",
"単語列2 tensor([ 3852, 13667, 2962, 68, 6, 28374, 39, 417, 0, 2034,\n",
" 29, 88, 27804, 350, 7, 17, 4811, 902, 33, 20])\n",
"単語列33 tensor([ 884, 28, 27, 435, 12, 6, 63, 133, 6, 0, 13, 31,\n",
" 3, 43, 8, 2997, 78, 5977, 52, 181])\n",
"単語列34 tensor([ 632, 4, 127, 6, 3, 4775, 4, 4, 25, 23, 124, 5223,\n",
" 12, 194, 4, 8, 36, 2142, 139, 37])\n",
"単語列35 tensor([ 979, 725, 4, 10997, 3, 8, 677, 542, 112, 55,\n",
" 126, 11, 12, 6, 6962, 4113, 5196, 37, 27, 9177])\n",
"\n",
"◇ 訓練データの 0, 1, 2 個目の単語列(単語列に翻訳版)\n",
"単語列0 ['<eos>', '@', 'settlement', 'heavy', 'of', ',', ',', 'lined', 'the', 'she', '<unk>', 'of', '.', 'interception', 'the', 'dried', '.', ',', 'would', 'his']\n",
"単語列1 ['=', '1', 'was', 'rains', 'ireland', 'and', 'starting', 'with', 'hairy', 'had', 'found', 'the', '<unk>', 'to', 'possibility', 'heads', 'other', 'which', 'receive', 'gift']\n",
"単語列2 ['valkyria', 'rebounds', 'rapid', 'over', '.', 'truely', 'their', 'lead', '<unk>', 'recently', 'at', 'year', 'sharif', 'take', 'of', 'as', 'symptoms', 'usually', 'an', 'for']\n",
"\n",
"今回学習するモデルは、\n",
"Inputs として 単語列0 を入力したら 単語列1 っぽい確率分布が出てくる.\n",
"Inputs として 単語列1 を入力したら 単語列2 っぽい確率分布が出てくる.\n",
"任意の長さの単語列を入力したら、それに続きそうな同じ長さの単語列が出てくる.\n",
"たぶんそんなモデル.\n"
]
}
],
"source": [
"print('◆ データを固定長ごとに区切って整形する')\n",
"\n",
"def batchify(data, seq_len):\n",
" data = TEXT.numericalize([data.examples[0].text]) # データ内の単語列を数字列へ\n",
" # print(data.size()) # torch.Size([2086708, 1])\n",
" nbatch = data.size(0) // seq_len \n",
" data = data.narrow(0, 0, nbatch * seq_len) # 系列長で割り切れない単語は捨てる ヾ(;ω;) \n",
" # print(data.size()) # torch.Size([2086700, 1])\n",
" data = data.view(seq_len, -1).t().contiguous() # 次元を (ほげ, 系列長) にする\n",
" # print(data.size()) # torch.Size([104335, 20])\n",
" return data.to('cpu')\n",
"\n",
"seq_len = 20\n",
"eval_seq_len = 10\n",
"\n",
"train_data = batchify(train_txt, seq_len)\n",
"val_data = batchify(val_txt, eval_seq_len)\n",
"test_data = batchify(test_txt, eval_seq_len)\n",
"\n",
"print('訓練データ', train_data.size())\n",
"print('評価データ', val_data.size())\n",
"print('テストデータ', test_data.size())\n",
"\n",
"print('\\n◇ 訓練データの 0, 1, 2, 33, 34, 35 個目の単語列(数字列版)')\n",
"print('単語列0', train_data[0,:])\n",
"print('単語列1', train_data[1,:])\n",
"print('単語列2', train_data[2,:])\n",
"print('単語列33', train_data[33,:])\n",
"print('単語列34', train_data[34,:])\n",
"print('単語列35', train_data[35,:])\n",
"\n",
"print('\\n◇ 訓練データの 0, 1, 2 個目の単語列(単語列に翻訳版)')\n",
"print('単語列0', [TEXT.vocab.itos[i] for i in train_data[0,:]])\n",
"print('単語列1', [TEXT.vocab.itos[i] for i in train_data[1,:]])\n",
"print('単語列2', [TEXT.vocab.itos[i] for i in train_data[2,:]])\n",
"\n",
"print('\\n今回学習するモデルは、')\n",
"print('Inputs として 単語列0 を入力したら 単語列1 っぽい確率分布が出てくる.')\n",
"print('Inputs として 単語列1 を入力したら 単語列2 っぽい確率分布が出てくる.')\n",
"print('任意の長さの単語列を入力したら、それに続きそうな同じ長さの単語列が出てくる.')\n",
"print('たぶんそんなモデル.')"
]
},
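{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is a toy check of the reshaping in `batchify`. It uses the made-up IDs 0-25 instead of WikiText-2, and `batchify_ids` is a hypothetical helper that skips the numericalization step. With 26 IDs and 4 columns, the last 2 IDs are dropped and each column becomes a contiguous run of the original stream. Reading down a column gives consecutive tokens, while a row collects tokens from 4 distant places."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def batchify_ids(ids, bsz):\n",
"    # Same reshaping as batchify above, applied to an already-numericalized 1-D tensor\n",
"    nbatch = ids.size(0) // bsz\n",
"    ids = ids.narrow(0, 0, nbatch * bsz)  # drop the remainder\n",
"    return ids.view(bsz, -1).t().contiguous()  # shape (nbatch, bsz)\n",
"\n",
"toy = torch.arange(26)  # stand-in for a stream of 26 token IDs\n",
"grid = batchify_ids(toy, 4)\n",
"print(grid.shape)  # torch.Size([6, 4])\n",
"print(grid[:, 0])  # column 0 is a contiguous run: 0, 1, 2, 3, 4, 5\n",
"print(grid[0])  # row 0 mixes distant tokens: 0, 6, 12, 18"
]
},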
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ 1バッチの学習に必要なデータを取得する関数を用意する\n",
"◇ data と targets が得られる\n",
"torch.Size([35, 20])\n",
"torch.Size([700])\n",
"◇ data には Inputs から入れる用に訓練データの 0~34 個目の単語列が入っている\n",
"tensor([ 3, 25, 1849, 570, 7]) 単語列0\n",
"tensor([632, 4, 127, 6, 3]) 単語列34\n",
"◇ targets には損失計算用に訓練データの 1~35 個目の単語列が入っている(損失計算時の便利用に1次元配列にしてある)\n",
"tensor([ 12, 66, 13, 4889, 458]) 単語列1\n",
"tensor([ 979, 725, 4, 10997, 3]) 単語列35\n"
]
}
],
"source": [
"print('◆ 1バッチの学習に必要なデータを取得する関数を用意する')\n",
"bptt = 35 # bptt 個の単語列を一度にモデルに流す\n",
"def get_batch(source, i):\n",
" seq_len = min(bptt, len(source) - 1 - i)\n",
" data = source[i:i+seq_len]\n",
" target = source[i+1:i+1+seq_len].view(-1)\n",
" return data, target\n",
"\n",
"data, targets = get_batch(train_data, 0) # 1バッチ目の学習に必要なデータ\n",
"print('◇ data と targets が得られる')\n",
"print(data.size())\n",
"print(targets.size())\n",
"\n",
"print('◇ data には Inputs から入れる用に訓練データの 0~34 個目の単語列が入っている')\n",
"print(data[0, :5], '単語列0')\n",
"print(data[34, :5], '単語列34')\n",
"print('◇ targets には損失計算用に訓練データの 1~35 個目の単語列が入っている(損失計算時の便利用に1次元配列にしてある)')\n",
"print(targets[:5], '単語列1')\n",
"print(targets[680:685], '単語列35')"
]
},
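{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same slicing on a toy 6 x 4 grid makes the one-row shift explicit. The grid is what `batchify` would produce from the IDs 0-23, so every column counts up by one. Each entry of `targets` is the token that follows the input at the same position, and the last window is shortened so that it still has targets. `bptt = 3` is a toy value, not the notebook's 35."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"bptt = 3  # toy window length (the notebook uses 35)\n",
"\n",
"def get_batch(source, i):\n",
"    # Same logic as get_batch above: inputs are rows i..i+seq_len-1, targets are the rows one step later\n",
"    seq_len = min(bptt, len(source) - 1 - i)\n",
"    data = source[i:i + seq_len]\n",
"    target = source[i + 1:i + 1 + seq_len].view(-1)\n",
"    return data, target\n",
"\n",
"grid = torch.arange(24).view(4, 6).t().contiguous()  # 6 rows x 4 columns; columns are runs 0-5, 6-11, 12-17, 18-23\n",
"data, targets = get_batch(grid, 0)\n",
"print(data)  # rows 0-2\n",
"print(targets)  # rows 1-3, flattened row by row\n",
"# Each target is the token that follows the input at the same (row, column)\n",
"print(torch.equal(targets.view(data.shape), data + 1))  # True\n",
"# Starting at row 3 only rows 3-4 still have a following row, so the window shrinks\n",
"print(get_batch(grid, 3)[0].shape)  # torch.Size([2, 4])"
]
},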
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ モデル本体より先に、PositionalEncoding という単語列を埋め込んだテンソルに少しプレ処理する機能を用意する\n",
"埋め込み次元数の半分の長さ(100)まで指数的に減衰する成分が用意される\n",
"tensor([1.0000, 0.9120, 0.8318, 0.7586, 0.6918, 0.6310])\n",
"tensor([0.0002, 0.0002, 0.0001, 0.0001, 0.0001, 0.0001])\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 648x144 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"pe というテンソルの i 個目の偶数番目と奇数番目に減衰する正弦波と余弦波が用意される\n",
"torch.Size([5000, 200])\n",
"tensor([[ 0.0000, 1.0000, 0.0000, 1.0000],\n",
" [ 0.8415, 0.5403, 0.7907, 0.6122],\n",
" [ 0.9093, -0.4161, 0.9681, -0.2505]])\n",
"tensor([[ 0.9563, -0.2925, 0.9055, -0.4243],\n",
" [ 0.2705, -0.9627, 0.2192, -0.9757],\n",
" [-0.6639, -0.7478, -0.6374, -0.7705]])\n"
]
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 648x288 with 2 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"pe というテンソルはこうなる\n",
"torch.Size([5000, 1, 200])\n",
"tensor([[[ 0.0000, 1.0000, 0.0000, 1.0000]],\n",
"\n",
" [[ 0.8415, 0.5403, 0.7907, 0.6122]],\n",
"\n",
" [[ 0.9093, -0.4161, 0.9681, -0.2505]]])\n",
"tensor([[[ 0.9563, -0.2925, 0.9055, -0.4243]],\n",
"\n",
" [[ 0.2705, -0.9627, 0.2192, -0.9757]],\n",
"\n",
" [[-0.6639, -0.7478, -0.6374, -0.7705]]])\n",
"\n",
"◆ PositionalEncoder に単語列を埋め込んだテンソルを流してみる\n",
"◇ 訓練データの 0~34 個目の単語列\n",
"torch.Size([35, 20])\n",
"◇ 訓練データの 0~34 個目の単語列を200次元に埋め込み\n",
" (各単語列長は本当は20だが、5だけ表示)\n",
" (各単語の次元は本当は200だが、4だけ表示)\n",
"torch.Size([35, 20, 200])\n",
"tensor([[ 0.1248, 0.0463, 0.9071, -0.4357],\n",
" [-1.1538, -1.9721, 0.0941, -1.3863],\n",
" [ 2.2940, 1.2065, 0.6837, -1.8190],\n",
" [ 1.1506, 0.5202, 0.5958, 2.0696],\n",
" [-0.0296, -0.7435, 0.1943, 0.9146]], grad_fn=<SliceBackward>) 単語列0\n",
"tensor([[-0.0889, 2.2797, -0.9817, -0.5639],\n",
" [ 1.3507, -0.5374, -0.4463, -1.1063],\n",
" [-0.0830, 1.2043, -0.6070, 1.6978],\n",
" [-0.8962, 2.1992, 0.9791, 0.0878],\n",
" [-0.6931, -1.2100, 1.3738, -0.9831]], grad_fn=<SliceBackward>) 単語列1\n",
"tensor([[ 9.7651e-01, 5.0865e-01, -1.3232e+00, -1.2150e+00],\n",
" [ 7.5843e-01, 1.4521e+00, -1.0745e+00, 1.5824e-03],\n",
" [-1.4427e-02, -3.1398e-01, -1.7494e-01, 3.2904e-01],\n",
" [-1.1697e+00, -2.6854e-01, 8.9997e-02, 2.1359e+00],\n",
" [-5.5698e-01, -3.9223e-02, 1.7132e+00, -4.5476e-02]],\n",
" grad_fn=<SliceBackward>) 単語列2\n",
"◇ 訓練データの 0~34 個目の単語列を200次元に埋め込みして PositionalEncoding\n",
"torch.Size([35, 20, 200])\n",
"tensor([[ 0.1248, 1.0463, 0.9071, 0.5643],\n",
" [-1.1538, -0.9721, 0.0941, -0.3863],\n",
" [ 2.2940, 2.2065, 0.6837, -0.8190],\n",
" [ 1.1506, 1.5202, 0.5958, 3.0696],\n",
" [-0.0296, 0.2565, 0.1943, 1.9146]], grad_fn=<SliceBackward>) 単語列0\n",
"tensor([[ 0.7525, 2.8200, -0.1910, 0.0482],\n",
" [ 2.1922, 0.0029, 0.3444, -0.4941],\n",
" [ 0.7585, 1.7446, 0.1837, 2.3100],\n",
" [-0.0547, 2.7395, 1.7699, 0.6999],\n",
" [ 0.1483, -0.6697, 2.1645, -0.3710]], grad_fn=<SliceBackward>) 単語列1\n",
"tensor([[ 1.8858, 0.0925, -0.3551, -1.4655],\n",
" [ 1.6677, 1.0360, -0.1064, -0.2489],\n",
" [ 0.8949, -0.7301, 0.7932, 0.0785],\n",
" [-0.2604, -0.6847, 1.0581, 1.8854],\n",
" [ 0.3523, -0.4554, 2.6813, -0.2960]], grad_fn=<SliceBackward>) 単語列2\n"
]
}
],
"source": [
"import math\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"%matplotlib inline\n",
"import matplotlib.pyplot as plt\n",
"from pylab import rcParams\n",
"rcParams['figure.figsize'] = 9, 4\n",
"rcParams['font.size'] = 12\n",
"rcParams['font.family']='Ume Hy Gothic O5'\n",
"\n",
"print('◆ モデル本体より先に、PositionalEncoding という単語列を埋め込んだテンソルに少しプレ処理する機能を用意する')\n",
"\n",
"class PositionalEncoding(nn.Module):\n",
"\n",
" def __init__(self, d_model, dropout=0.1, max_len=5000, debug=False):\n",
" super(PositionalEncoding, self).__init__()\n",
" self.dropout = nn.Dropout(p=dropout)\n",
" pe = torch.zeros(max_len, d_model)\n",
" position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # [[0], [1], ..., [4999]]\n",
" div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
" if debug:\n",
" print('埋め込み次元数の半分の長さ(100)まで指数的に減衰する成分が用意される')\n",
" print(div_term[:6])\n",
" print(div_term[94:])\n",
" fig, ax = plt.subplots(1, 1, sharex='col', figsize=(9, 2))\n",
" ax.set_title('埋め込み次元数の半分の長さ(100)まで指数的に減衰する成分')\n",
" ax.plot(div_term)\n",
" plt.show()\n",
" pe[:, 0::2] = torch.sin(position * div_term)\n",
" pe[:, 1::2] = torch.cos(position * div_term)\n",
" if debug:\n",
" print('pe というテンソルの i 個目の偶数番目と奇数番目に減衰する正弦波と余弦波が用意される')\n",
" print(pe.size())\n",
" print(pe[:3,:4])\n",
" print(pe[4997:,:4]) \n",
" fig, ax = plt.subplots(2, 1, sharex='col', figsize=(9, 4))\n",
" ax[0].set_title('pe の i 個目の偶数番目(本当は5000個まで用意される)')\n",
" ax[1].set_title('pe の i 個目の奇数番目(本当は5000個まで用意される)')\n",
" for i in range(10):\n",
" ax[0].plot(pe[i, 0::2])\n",
" ax[1].plot(pe[i, 1::2])\n",
" plt.show()\n",
" pe = pe.unsqueeze(0).transpose(0, 1)\n",
" if debug:\n",
" print('pe というテンソルはこうなる')\n",
" print(pe.size())\n",
" print(pe[:3,:,:4])\n",
" print(pe[4997:,:,:4]) \n",
" self.register_buffer('pe', pe)\n",
"\n",
" def forward(self, x):\n",
" x = x + self.pe[:x.size(0), :]\n",
" return self.dropout(x)\n",
"\n",
"pos_encoder = PositionalEncoding(d_model=200, dropout=0.0, debug=True)\n",
"\n",
"print('\\n◆ PositionalEncoder に単語列を埋め込んだテンソルを流してみる')\n",
"encoder = nn.Embedding(28785, 200)\n",
"print('◇ 訓練データの 0~34 個目の単語列')\n",
"print(data.size())\n",
"data_encoded = encoder(data)\n",
"print('◇ 訓練データの 0~34 個目の単語列を200次元に埋め込み')\n",
"print(' (各単語列長は本当は20だが、5だけ表示)')\n",
"print(' (各単語の次元は本当は200だが、4だけ表示)')\n",
"print(data_encoded.size())\n",
"print(data_encoded[0, :5, :4], '単語列0')\n",
"print(data_encoded[1, :5, :4], '単語列1')\n",
"print(data_encoded[2, :5, :4], '単語列2')\n",
"data_pos_encoded = pos_encoder.forward(data_encoded)\n",
"print('◇ 訓練データの 0~34 個目の単語列を200次元に埋め込みして PositionalEncoding')\n",
"print(data_pos_encoded.size())\n",
"print(data_pos_encoded[0, :5, :4], '単語列0')\n",
"print(data_pos_encoded[1, :5, :4], '単語列1')\n",
"print(data_pos_encoded[2, :5, :4], '単語列2')"
]
},
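{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the table built in `PositionalEncoding` follows the sinusoidal encoding of *Attention Is All You Need*, where $d$ is `d_model`:\n",
"\n",
"$PE(pos, 2k) = \\sin(pos / 10000^{2k/d})$ and $PE(pos, 2k+1) = \\cos(pos / 10000^{2k/d})$.\n",
"\n",
"The sketch below recomputes the table entry by entry with plain loops and compares it with the vectorized construction used in the class. That construction is re-derived here so the cell runs on its own, and the table sizes are toy values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"\n",
"def pe_loop(max_len, d_model):\n",
"    # Direct transcription of the formula, one entry at a time\n",
"    pe = torch.zeros(max_len, d_model)\n",
"    for pos in range(max_len):\n",
"        for k in range(d_model // 2):\n",
"            angle = pos / (10000 ** (2 * k / d_model))\n",
"            pe[pos, 2 * k] = math.sin(angle)\n",
"            pe[pos, 2 * k + 1] = math.cos(angle)\n",
"    return pe\n",
"\n",
"def pe_vectorized(max_len, d_model):\n",
"    # Same construction as in PositionalEncoding.__init__\n",
"    pe = torch.zeros(max_len, d_model)\n",
"    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
"    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
"    pe[:, 0::2] = torch.sin(position * div_term)\n",
"    pe[:, 1::2] = torch.cos(position * div_term)\n",
"    return pe\n",
"\n",
"a = pe_loop(50, 8)\n",
"b = pe_vectorized(50, 8)\n",
"print(torch.allclose(a, b, atol=1e-4))  # True\n",
"print(b[1, :2])  # sin(1), cos(1): the same first two entries as pe[1] printed above"
]
},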
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"単語列0の各単語には [ 0.0000, 1.0000, 0.0000, 1.0000, ...] というベクトルが足される.\n",
"単語列1の各単語には [ 0.8415, 0.5403, 0.7907, 0.6122, ...] というベクトルが足される.\n",
"単語列2の各単語には [ 0.9093, -0.4161, 0.9681, -0.2505, ...] というベクトルが足される.\n",
"バッチ内で何番目の単語列かによって、後の方ほど振動数が大きい正弦波/余弦波を単語に上乗せするのが PositionalEncoding.\n",
"※ バッチ内で何番目の単語列かという情報を振動数が違う正弦波/余弦波で表現するのはあくまでこのチュートリアルが採用した方法であり他の方法でもよい.\n"
]
}
],
"source": [
"print('単語列0の各単語には [ 0.0000, 1.0000, 0.0000, 1.0000, ...] というベクトルが足される.')\n",
"print('単語列1の各単語には [ 0.8415, 0.5403, 0.7907, 0.6122, ...] というベクトルが足される.')\n",
"print('単語列2の各単語には [ 0.9093, -0.4161, 0.9681, -0.2505, ...] というベクトルが足される.')\n",
"print('バッチ内で何番目の単語列かによって、後の方ほど振動数が大きい正弦波/余弦波を単語に上乗せするのが PositionalEncoding.')\n",
"print('※ バッチ内で何番目の単語列かという情報を振動数が違う正弦波/余弦波で表現するのはあくまでこのチュートリアルが採用した方法であり他の方法でもよい.')"
]
},
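{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following small table illustrates one consequence of this choice. This is an added illustration; the tutorial does not rely on it explicitly. Because $\\sin a \\sin b + \\cos a \\cos b = \\cos(a - b)$, the dot product $pe[p] \\cdot pe[q] = \\sum_k \\cos((p - q)\\,\\omega_k)$, with $\\omega_k$ = `div_term[k]`, depends only on the offset $p - q$ and not on the absolute positions. The sizes below are toy values, and float64 is used only to make the comparison tight."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"\n",
"d_model, max_len = 16, 100\n",
"position = torch.arange(0, max_len, dtype=torch.float64).unsqueeze(1)\n",
"div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model))\n",
"pe = torch.zeros(max_len, d_model, dtype=torch.float64)\n",
"pe[:, 0::2] = torch.sin(position * div_term)\n",
"pe[:, 1::2] = torch.cos(position * div_term)\n",
"\n",
"# The same offset (3) at different absolute positions gives the same dot product\n",
"d1 = torch.dot(pe[10], pe[13]).item()\n",
"d2 = torch.dot(pe[50], pe[53]).item()\n",
"# Closed form: sum over k of cos(offset * div_term[k])\n",
"closed = torch.cos(3 * div_term).sum().item()\n",
"print(d1, d2, closed)  # the three values agree up to rounding"
]
},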
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ モデル本体を用意する\n",
"TransformerEncoderLayer = 2HeadAttention + 200次元の隠れ層が1層のFFN\n",
"今回のモデル = エンコーダ + PositionalEncoder + TransformerEncoderLayerその0 + TransformerEncoderLayerその1 + デコーダ\n",
"\n",
"◆ 今回学習するパラメータたち(TransformerEncoderLayerその0、TransformerEncoderLayerその1、エンコーダ、デコーダ)\n",
"transformer_encoder.layers.0.self_attn.in_proj_weight torch.Size([600, 200])\n",
"transformer_encoder.layers.0.self_attn.in_proj_bias torch.Size([600])\n",
"transformer_encoder.layers.0.self_attn.out_proj.weight torch.Size([200, 200])\n",
"transformer_encoder.layers.0.self_attn.out_proj.bias torch.Size([200])\n",
"transformer_encoder.layers.0.linear1.weight torch.Size([200, 200])\n",
"transformer_encoder.layers.0.linear1.bias torch.Size([200])\n",
"transformer_encoder.layers.0.linear2.weight torch.Size([200, 200])\n",
"transformer_encoder.layers.0.linear2.bias torch.Size([200])\n",
"transformer_encoder.layers.0.norm1.weight torch.Size([200])\n",
"transformer_encoder.layers.0.norm1.bias torch.Size([200])\n",
"transformer_encoder.layers.0.norm2.weight torch.Size([200])\n",
"transformer_encoder.layers.0.norm2.bias torch.Size([200])\n",
"transformer_encoder.layers.1.self_attn.in_proj_weight torch.Size([600, 200])\n",
"transformer_encoder.layers.1.self_attn.in_proj_bias torch.Size([600])\n",
"transformer_encoder.layers.1.self_attn.out_proj.weight torch.Size([200, 200])\n",
"transformer_encoder.layers.1.self_attn.out_proj.bias torch.Size([200])\n",
"transformer_encoder.layers.1.linear1.weight torch.Size([200, 200])\n",
"transformer_encoder.layers.1.linear1.bias torch.Size([200])\n",
"transformer_encoder.layers.1.linear2.weight torch.Size([200, 200])\n",
"transformer_encoder.layers.1.linear2.bias torch.Size([200])\n",
"transformer_encoder.layers.1.norm1.weight torch.Size([200])\n",
"transformer_encoder.layers.1.norm1.bias torch.Size([200])\n",
"transformer_encoder.layers.1.norm2.weight torch.Size([200])\n",
"transformer_encoder.layers.1.norm2.bias torch.Size([200])\n",
"encoder.weight torch.Size([28785, 200])\n",
"decoder.weight torch.Size([28785, 200])\n",
"decoder.bias torch.Size([28785])\n",
"\n",
"◆ モデルに単語列を流してみる\n",
"◇ 訓練データの 0~34 個目の単語列を流す\n",
"torch.Size([35, 20])\n",
"tensor([ 3, 25, 1849, 570, 7]) 単語列0\n",
"tensor([632, 4, 127, 6, 3]) 単語列34\n",
"----- TransformerEncoder層でこんなマスクをつかうよ -----\n",
"TransformerEncoder層は入力側の単語空間から出力側の単語空間への写像であるはずである.\n",
"1バッチ目の学習で 0~34 個目の単語列を流すが、0個目の単語列を写像するときに\n",
"1~34 個目の単語列を利用しないようにするためにマスクする必要がある.\n",
"具体的に、負の無限大をいれておくことで未来の単語列由来の Attention が発生しないようにしていると思う.\n",
"torch.Size([35, 35])\n",
"tensor([[0., -inf, -inf, -inf, -inf],\n",
" [0., 0., -inf, -inf, -inf],\n",
" [0., 0., 0., -inf, -inf],\n",
" [0., 0., 0., 0., -inf],\n",
" [0., 0., 0., 0., 0.]])\n",
"--------------------------------------------------------\n",
"----- 各単語を埋め込んだよ -----\n",
"torch.Size([35, 20, 200])\n",
"tensor([[-0.6279, -0.6731, -1.0157, 0.6879],\n",
" [ 1.3091, 0.7319, -0.1179, 0.0892],\n",
" [-0.2110, 0.7331, -1.1495, 1.0249],\n",
" [ 0.3094, 1.3982, 1.0791, 0.5063],\n",
" [ 0.1669, 0.1000, -0.5802, 0.2259]], grad_fn=<SliceBackward>) 単語列0\n",
"tensor([[ 0.9145, -0.9085, 1.1073, 0.1067],\n",
" [-1.0746, 0.4230, 0.5442, -1.2081],\n",
" [ 0.7832, 1.1152, -0.8875, -1.2305],\n",
" [ 0.0801, 0.5733, 1.1932, 1.3263],\n",
" [-0.6279, -0.6731, -1.0157, 0.6879]], grad_fn=<SliceBackward>) 単語列34\n",
"--------------------------------\n",
"----- PositionalEncoding したよ -----\n",
"torch.Size([35, 20, 200])\n",
"tensor([[-0.7849, 0.4086, -0.0000, 2.1098],\n",
" [ 0.0000, 0.0000, -0.1474, 1.3615],\n",
" [-0.2637, 2.1664, -1.4369, 2.5311],\n",
" [ 0.3867, 2.9978, 1.3488, 1.8829],\n",
" [ 0.0000, 1.3750, -0.7252, 1.5324]], grad_fn=<SliceBackward>) 単語列0\n",
"tensor([[ 1.8045, -2.1964, 0.8887, 1.2809],\n",
" [-0.6819, -0.5320, 0.1848, -0.3626],\n",
" [ 1.6403, 0.0000, -1.6049, -0.3905],\n",
" [ 0.7614, -0.3440, 0.9961, 2.8055],\n",
" [-0.1235, -1.9021, -1.7650, 2.0074]], grad_fn=<SliceBackward>) 単語列34\n",
"-------------------------------------\n",
"----- TransformerEncoder層に通したよ -----\n",
"torch.Size([35, 20, 200])\n",
"tensor([[-0.1566, 0.1365, -0.2389, -1.1040],\n",
" [-0.8418, -0.6476, -0.3431, 1.2286],\n",
" [-1.0668, 0.8735, -1.5316, 1.1909],\n",
" [ 0.0748, 0.6588, -1.0968, 0.5347],\n",
" [-0.4703, -0.2646, -1.3832, 0.0422]], grad_fn=<SliceBackward>) 単語列0につづくことが期待される単語列(埋め込み版)\n",
"tensor([[ 1.1223, -2.0096, 0.2688, 0.4822],\n",
" [-0.6019, -0.7603, -0.3600, -0.4463],\n",
" [ 1.2852, -0.4725, -1.2806, -0.6672],\n",
" [ 0.1246, -0.1841, 0.5772, 1.0651],\n",
" [ 0.8009, -2.0562, -1.0535, 1.3696]], grad_fn=<SliceBackward>) 単語列34につづくことが期待される単語列(埋め込み版)\n",
"------------------------------------------\n",
"----- 200次元の単語空間から28785個の単語上の確率分布に戻すよ(モデル内でSoftmaxまではしない) -----\n",
"torch.Size([35, 20, 28785])\n",
"tensor([[ 0.7065, 0.0610, 0.6539, 0.3053],\n",
" [-0.0994, -1.2828, -0.5313, 0.8495],\n",
" [ 0.2614, 0.1559, 1.2580, 0.2775],\n",
" [ 0.1664, -0.2835, -0.3288, 1.1031],\n",
" [-0.2702, 0.2202, 1.8064, -0.3899]], grad_fn=<SliceBackward>) 単語列0につづくことが期待される単語列(単語上の分布版)\n",
"tensor([[ 0.4869, -0.3151, 0.5294, 0.1446],\n",
" [-0.0557, -0.9696, -0.8336, 0.3696],\n",
" [ 0.5967, -0.5712, 0.8610, 0.3025],\n",
" [ 0.2908, 0.0891, 0.0876, -0.0476],\n",
" [ 1.3940, -0.4024, 0.5641, 0.7353]], grad_fn=<SliceBackward>) 単語列34につづくことが期待される単語列(単語上の分布版)\n",
"---------------------------------------------------------------------------------------------------\n",
"\n",
"◇ 単語列0につづくことが期待される予測単語列\n",
"[26292, 24561, 17166, 1744, 10627, 4781, 12237, 21242, 27323, 4728, 27100, 25329, 21128, 25945, 12959, 25505, 9062, 4781, 22295, 2187]\n",
"単語列に翻訳すると\n",
"['kreutzer', 'campanian', 'henriette', 'onto', 'pageant', 'numbered', 'berardi', 'delineated', 'pyramids', '1903', 'pentominoes', 'erwin', 'conformational', 'hxg4', '560', 'fleas', 'allegedly', 'numbered', 'mira', 'jane']\n"
]
}
],
"source": [
"print('◆ モデル本体を用意する')\n",
"\n",
"print('TransformerEncoderLayer = 2HeadAttention + 200次元の隠れ層が1層のFFN')\n",
"print('今回のモデル = エンコーダ + PositionalEncoder + TransformerEncoderLayerその0 + TransformerEncoderLayerその1 + デコーダ')\n",
"\n",
"class TransformerModel(nn.Module):\n",
" # ntoken : 語彙数\n",
" # ninp : 埋め込む次元\n",
" def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):\n",
" super(TransformerModel, self).__init__()\n",
" from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
" self.model_type = 'Transformer'\n",
" self.src_mask = None\n",
" self.pos_encoder = PositionalEncoding(ninp, dropout) # さっきの PositionalEncoder\n",
" encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)\n",
" self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
" self.encoder = nn.Embedding(ntoken, ninp)\n",
" self.ninp = ninp\n",
" self.decoder = nn.Linear(ninp, ntoken)\n",
" self.init_weights()\n",
"\n",
" def _generate_square_subsequent_mask(self, sz, debug=False):\n",
" mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n",
" mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
" if debug:\n",
" print('----- TransformerEncoder層でこんなマスクをつかうよ -----')\n",
" print('TransformerEncoder層は入力側の単語空間から出力側の単語空間への写像であるはずである.')\n",
" print('1バッチ目の学習で 0~34 個目の単語列を流すが、0個目の単語列を写像するときに')\n",
" print('1~34 個目の単語列を利用しないようにするためにマスクする必要がある.')\n",
" print('具体的に、負の無限大をいれておくことで未来の単語列由来の Attention が発生しないようにしていると思う.')\n",
" print(mask.size())\n",
" print(mask[:5,:5])\n",
" print('--------------------------------------------------------')\n",
" return mask\n",
"\n",
" def init_weights(self):\n",
" initrange = 0.1\n",
" self.encoder.weight.data.uniform_(-initrange, initrange)\n",
" self.decoder.bias.data.zero_()\n",
" self.decoder.weight.data.uniform_(-initrange, initrange)\n",
"\n",
" def forward(self, src, debug=False):\n",
" if self.src_mask is None or self.src_mask.size(0) != len(src):\n",
" device = src.device\n",
" mask = self._generate_square_subsequent_mask(len(src), debug).to(device)\n",
" self.src_mask = mask\n",
" src = self.encoder(src) * math.sqrt(self.ninp)\n",
" if debug:\n",
" print('----- 各単語を埋め込んだよ -----')\n",
" print(src.size())\n",
" print(src[0, :5, :4], '単語列0')\n",
" print(src[34, :5, :4], '単語列34')\n",
" print('--------------------------------')\n",
" src = self.pos_encoder(src)\n",
" if debug:\n",
" print('----- PositionalEncoding したよ -----')\n",
" print(src.size())\n",
" print(src[0, :5, :4], '単語列0')\n",
" print(src[34, :5, :4], '単語列34')\n",
" print('-------------------------------------')\n",
" output = self.transformer_encoder(src, self.src_mask)\n",
" if debug:\n",
" print('----- TransformerEncoder層に通したよ -----')\n",
" print(output.size())\n",
" print(output[0, :5, :4], '単語列0につづくことが期待される単語列(埋め込み版)')\n",
" print(output[34, :5, :4], '単語列34につづくことが期待される単語列(埋め込み版)')\n",
" print('------------------------------------------')\n",
" output = self.decoder(output)\n",
" if debug:\n",
" print('----- 200次元の単語空間から28785個の単語上の確率分布に戻すよ(モデル内でSoftmaxまではしない) -----')\n",
" print(output.size())\n",
" print(output[0, :5, :4], '単語列0につづくことが期待される単語列(単語上の分布版)')\n",
" print(output[34, :5, :4], '単語列34につづくことが期待される単語列(単語上の分布版)')\n",
" print('---------------------------------------------------------------------------------------------------')\n",
" return output\n",
" \n",
"ntokens = len(TEXT.vocab.stoi) # 訓練データの語彙数(上述)\n",
"emsize = 200 # 単語を埋め込む次元\n",
"nhid = 200 # TransformerEncoder層の FFN の隠れ層の次元\n",
"nlayers = 2 # TransformerEncoder層を何層重ねるか\n",
"nhead = 2 # TransformerEncoder層内のMultiHeadAttention層のヘッド数\n",
"dropout = 0.2\n",
"\n",
"model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to('cpu')\n",
"\n",
"print('\\n◆ 今回学習するパラメータたち(TransformerEncoderLayerその0、TransformerEncoderLayerその1、エンコーダ、デコーダ)')\n",
"for name, param in model.named_parameters():\n",
" print(name.ljust(14), param.size())\n",
" \n",
"print('\\n◆ モデルに単語列を流してみる')\n",
"print('◇ 訓練データの 0~34 個目の単語列を流す')\n",
"print(data.size())\n",
"print(data[0, :5], '単語列0')\n",
"print(data[34, :5], '単語列34')\n",
"output = model.forward(data, debug=True)\n",
"model.zero_grad()\n",
"\n",
"print('\\n◇ 単語列0につづくことが期待される予測単語列')\n",
"pred0 = [torch.max(output[:1, i, :], 1)[1].item() for i in range(20)] # max index をとる\n",
"print(pred0)\n",
"print('単語列に翻訳すると')\n",
"print([TEXT.vocab.itos[i] for i in pred0])"
]
},
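{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following small check shows how the mask from `_generate_square_subsequent_mask` acts inside attention, using random scores. The scores, the size 5 and the seed are arbitrary. Adding the mask before the softmax turns every future position into a weight of exactly 0, while each row still sums to 1. The mask is also rebuilt with `torch.triu(..., diagonal=1)`, which we expect to give the same matrix as the method in the cell above; the comparison checks that assumption."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"def square_subsequent_mask(sz):\n",
"    # Same construction as TransformerModel._generate_square_subsequent_mask\n",
"    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n",
"    return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
"\n",
"sz = 5\n",
"mask = square_subsequent_mask(sz)\n",
"# Alternative construction: -inf strictly above the diagonal, 0 elsewhere\n",
"alt = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)\n",
"print(torch.equal(mask, alt))  # True\n",
"\n",
"torch.manual_seed(0)\n",
"scores = torch.randn(sz, sz)  # stand-in for query-key attention scores\n",
"weights = torch.softmax(scores + mask, dim=-1)  # an additive float mask is applied before the softmax\n",
"print(weights)\n",
"# Row t puts weight only on positions 0..t\n",
"print(torch.all(weights.triu(diagonal=1) == 0).item())  # True"
]
},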
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ 損失は交差エントロピー\n",
"◆ 訓練する\n",
"| epoch 1 | 200/ 2981 batches | lr 5.00 | ms/batch 339.02 | loss 7.99 | ppl 2964.75\n",
"| epoch 1 | 400/ 2981 batches | lr 5.00 | ms/batch 331.12 | loss 6.78 | ppl 880.85\n",
"| epoch 1 | 600/ 2981 batches | lr 5.00 | ms/batch 330.54 | loss 6.36 | ppl 578.69\n",
"| epoch 1 | 800/ 2981 batches | lr 5.00 | ms/batch 330.58 | loss 6.22 | ppl 504.51\n",
"| epoch 1 | 1000/ 2981 batches | lr 5.00 | ms/batch 332.90 | loss 6.11 | ppl 451.92\n",
"| epoch 1 | 1200/ 2981 batches | lr 5.00 | ms/batch 331.87 | loss 6.09 | ppl 440.12\n",
"| epoch 1 | 1400/ 2981 batches | lr 5.00 | ms/batch 337.76 | loss 6.05 | ppl 422.34\n",
"| epoch 1 | 1600/ 2981 batches | lr 5.00 | ms/batch 341.21 | loss 6.05 | ppl 424.50\n",
"| epoch 1 | 1800/ 2981 batches | lr 5.00 | ms/batch 340.24 | loss 5.97 | ppl 389.90\n",
"| epoch 1 | 2000/ 2981 batches | lr 5.00 | ms/batch 344.83 | loss 5.96 | ppl 386.73\n",
"| epoch 1 | 2200/ 2981 batches | lr 5.00 | ms/batch 345.09 | loss 5.85 | ppl 347.05\n",
"| epoch 1 | 2400/ 2981 batches | lr 5.00 | ms/batch 346.63 | loss 5.90 | ppl 363.87\n",
"| epoch 1 | 2600/ 2981 batches | lr 5.00 | ms/batch 349.56 | loss 5.91 | ppl 366.96\n",
"| epoch 1 | 2800/ 2981 batches | lr 5.00 | ms/batch 354.65 | loss 5.81 | ppl 332.75\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 1 | time: 1061.99s | valid loss 5.70 | valid ppl 300.33\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 2 | 200/ 2981 batches | lr 4.51 | ms/batch 364.47 | loss 5.81 | ppl 333.28\n",
"| epoch 2 | 400/ 2981 batches | lr 4.51 | ms/batch 361.44 | loss 5.77 | ppl 321.51\n",
"| epoch 2 | 600/ 2981 batches | lr 4.51 | ms/batch 360.41 | loss 5.60 | ppl 271.70\n",
"| epoch 2 | 800/ 2981 batches | lr 4.51 | ms/batch 363.80 | loss 5.64 | ppl 281.70\n",
"| epoch 2 | 1000/ 2981 batches | lr 4.51 | ms/batch 368.53 | loss 5.59 | ppl 267.25\n",
"| epoch 2 | 1200/ 2981 batches | lr 4.51 | ms/batch 362.11 | loss 5.62 | ppl 276.96\n",
"| epoch 2 | 1400/ 2981 batches | lr 4.51 | ms/batch 365.81 | loss 5.63 | ppl 278.81\n",
"| epoch 2 | 1600/ 2981 batches | lr 4.51 | ms/batch 365.77 | loss 5.66 | ppl 288.42\n",
"| epoch 2 | 1800/ 2981 batches | lr 4.51 | ms/batch 363.54 | loss 5.59 | ppl 266.53\n",
"| epoch 2 | 2000/ 2981 batches | lr 4.51 | ms/batch 365.22 | loss 5.61 | ppl 272.91\n",
"| epoch 2 | 2200/ 2981 batches | lr 4.51 | ms/batch 365.55 | loss 5.51 | ppl 246.89\n",
"| epoch 2 | 2400/ 2981 batches | lr 4.51 | ms/batch 365.03 | loss 5.57 | ppl 261.77\n",
"| epoch 2 | 2600/ 2981 batches | lr 4.51 | ms/batch 361.33 | loss 5.58 | ppl 265.62\n",
"| epoch 2 | 2800/ 2981 batches | lr 4.51 | ms/batch 362.20 | loss 5.51 | ppl 247.59\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 2 | time: 1131.27s | valid loss 5.55 | valid ppl 256.85\n",
"-----------------------------------------------------------------------------------------\n",
"| epoch 3 | 200/ 2981 batches | lr 4.29 | ms/batch 365.98 | loss 5.55 | ppl 256.72\n",
"| epoch 3 | 400/ 2981 batches | lr 4.29 | ms/batch 363.26 | loss 5.55 | ppl 258.13\n",
"| epoch 3 | 600/ 2981 batches | lr 4.29 | ms/batch 362.61 | loss 5.36 | ppl 212.86\n",
"| epoch 3 | 800/ 2981 batches | lr 4.29 | ms/batch 365.65 | loss 5.42 | ppl 224.78\n",
"| epoch 3 | 1000/ 2981 batches | lr 4.29 | ms/batch 365.14 | loss 5.38 | ppl 217.48\n",
"| epoch 3 | 1200/ 2981 batches | lr 4.29 | ms/batch 365.68 | loss 5.42 | ppl 225.63\n",
"| epoch 3 | 1400/ 2981 batches | lr 4.29 | ms/batch 360.36 | loss 5.44 | ppl 229.99\n",
"| epoch 3 | 1600/ 2981 batches | lr 4.29 | ms/batch 358.62 | loss 5.48 | ppl 239.10\n",
"| epoch 3 | 1800/ 2981 batches | lr 4.29 | ms/batch 357.22 | loss 5.41 | ppl 223.48\n",
"| epoch 3 | 2000/ 2981 batches | lr 4.29 | ms/batch 354.21 | loss 5.44 | ppl 229.41\n",
"| epoch 3 | 2200/ 2981 batches | lr 4.29 | ms/batch 355.91 | loss 5.32 | ppl 205.04\n",
"| epoch 3 | 2400/ 2981 batches | lr 4.29 | ms/batch 353.65 | loss 5.40 | ppl 221.81\n",
"| epoch 3 | 2600/ 2981 batches | lr 4.29 | ms/batch 354.29 | loss 5.42 | ppl 225.57\n",
"| epoch 3 | 2800/ 2981 batches | lr 4.29 | ms/batch 358.52 | loss 5.35 | ppl 209.82\n",
"-----------------------------------------------------------------------------------------\n",
"| end of epoch 3 | time: 1112.92s | valid loss 5.47 | valid ppl 237.13\n",
"-----------------------------------------------------------------------------------------\n",
"=========================================================================================\n",
"| End of training | test loss 5.38 | test ppl 216.82\n",
"=========================================================================================\n"
]
}
],
"source": [
"print('◆ Loss: cross-entropy')\n",
"criterion = nn.CrossEntropyLoss()\n",
"lr = 5.0  # learning rate\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
"scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)  # lr *= 0.95 after every epoch\n",
"\n",
"import copy\n",
"import math\n",
"import time\n",
"def train(debug=False):\n",
"    model.train()  # Turn on the train mode\n",
"    total_loss = 0.\n",
"    start_time = time.time()\n",
"    ntokens = len(TEXT.vocab.stoi)\n",
"    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
"        data, targets = get_batch(train_data, i)\n",
"        optimizer.zero_grad()\n",
"        output = model(data)\n",
"        loss = criterion(output.view(-1, ntokens), targets)\n",
"        loss.backward()\n",
"        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)  # clip gradients to curb explosion\n",
"        optimizer.step()\n",
"        total_loss += loss.item()\n",
"        if batch % 200 == 0 and batch > 0:  # log every 200 batches\n",
"            cur_loss = total_loss / 200\n",
"            elapsed = time.time() - start_time\n",
"            print('| epoch {:3d} | {:5d}/{:5d} batches | '\n",
"                  'lr {:02.2f} | ms/batch {:5.2f} | '\n",
"                  'loss {:5.2f} | ppl {:8.2f}'.format(\n",
"                      epoch, batch, len(train_data) // bptt,\n",
"                      scheduler.get_last_lr()[0],  # lr in use this epoch; get_lr() here would show the already-decayed next value\n",
"                      elapsed * 1000 / 200,\n",
"                      cur_loss, math.exp(cur_loss)))\n",
"            total_loss = 0\n",
"            start_time = time.time()\n",
"        if debug and batch == 400:  # stop early for a quick check\n",
"            break\n",
"\n",
"def evaluate(eval_model, data_source):\n",
"    eval_model.eval()  # Turn on the evaluation mode\n",
"    total_loss = 0.\n",
"    ntokens = len(TEXT.vocab.stoi)\n",
"    with torch.no_grad():\n",
"        for i in range(0, data_source.size(0) - 1, bptt):\n",
"            data, targets = get_batch(data_source, i)\n",
"            output = eval_model(data)\n",
"            output_flat = output.view(-1, ntokens)\n",
"            total_loss += len(data) * criterion(output_flat, targets).item()  # weight each chunk by its length\n",
"    return total_loss / (len(data_source) - 1)\n",
"\n",
"print('◆ Train')\n",
"best_val_loss = float(\"inf\")\n",
"epochs = 3  # The number of epochs\n",
"best_model = None\n",
"for epoch in range(1, epochs + 1):\n",
"    epoch_start_time = time.time()\n",
"    train(debug=False)\n",
"    val_loss = evaluate(model, val_data)\n",
"    print('-' * 89)\n",
"    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '\n",
"          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss)))\n",
"    print('-' * 89)\n",
"    if val_loss < best_val_loss:\n",
"        best_val_loss = val_loss\n",
"        best_model = copy.deepcopy(model)  # snapshot; 'best_model = model' would only alias the model that keeps training\n",
"    scheduler.step()\n",
"\n",
"test_loss = evaluate(best_model, test_data)\n",
"print('=' * 89)\n",
"print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, math.exp(test_loss)))\n",
"print('=' * 89)"
]
},
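{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, self-contained sketch (not part of the original tutorial). The tiny `nn.Linear` model is a stand-in assumption, not the Transformer trained above. It illustrates three things the training cell relies on: how `StepLR(step_size=1, gamma=0.95)` decays the learning rate once per epoch, with `get_last_lr()` returning the rate actually in use; that perplexity is `exp` of the mean cross-entropy loss; and why keeping the best model needs `copy.deepcopy` rather than plain assignment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import copy\n",
"import math\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"\n",
"# Stand-in model (an assumption for illustration); any nn.Module behaves the same way.\n",
"toy = nn.Linear(4, 4)\n",
"opt = torch.optim.SGD(toy.parameters(), lr=5.0)\n",
"sched = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.95)\n",
"\n",
"# Learning rate in use during each epoch: 5.0, 4.75, then about 4.51\n",
"for epoch in range(1, 4):\n",
"    print(epoch, sched.get_last_lr()[0])\n",
"    opt.step()    # training for one epoch would happen here\n",
"    sched.step()  # decay the lr at the end of the epoch\n",
"\n",
"# Perplexity is the exponential of the mean cross-entropy loss (in nats).\n",
"print(math.exp(5.38))  # about 217, in line with the test ppl logged above\n",
"\n",
"# Plain assignment aliases the live model; deepcopy freezes a snapshot.\n",
"alias = toy\n",
"snapshot = copy.deepcopy(toy)\n",
"with torch.no_grad():\n",
"    toy.weight.add_(1.0)\n",
"print(torch.equal(alias.weight, toy.weight))     # True: same tensor\n",
"print(torch.equal(snapshot.weight, toy.weight))  # False: independent copy"
]
},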
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"◇ Predicted word sequence expected to follow word sequence 0\n",
"[12, 6, 5, 1519, 4, 8, 8, 21, 40, 13, 5, 4, 3, 9, 40, 0, 3, 8, 36, 40]\n",
"Converted to words:\n",
"['=', '.', ',', 'rainfall', 'the', 'and', 'and', 'with', 'first', 'was', ',', 'the', '<eos>', 'in', 'first', '<unk>', '<eos>', 'and', 'be', 'first']\n"
]
}
],
"source": [
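"print('\\n◇ Predicted word sequence expected to follow word sequence 0')\n",
"best_model.eval()  # make sure dropout is off\n",
"with torch.no_grad():  # inference only, so no autograd graph is built\n",
"    output = best_model.forward(data, debug=False)\n",
"pred0 = [torch.max(output[:1, i, :], 1)[1].item() for i in range(20)]  # take the index of the max logit\n",
"print(pred0)\n",
"print('Converted to words:')\n",
"print([TEXT.vocab.itos[i] for i in pred0])"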
]
},
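{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the same greedy (argmax) readout in vectorized form. Like the cell above, it assumes the model output has shape `(seq_len, batch, ntokens)`; the random tensor and the sizes are stand-ins, not values from the trained model. `output[0].argmax(dim=-1)` picks, for every batch column, the highest-scoring token at the first position, which is what the list comprehension above computes one column at a time."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Stand-in sizes (assumptions): sequence length, batch size, vocabulary size.\n",
"seq_len, batch, ntokens = 35, 20, 100\n",
"output = torch.randn(seq_len, batch, ntokens)  # placeholder for the model output\n",
"\n",
"# Loop form used in the cell above: one batch column at a time, first position only.\n",
"pred_loop = [torch.max(output[:1, i, :], 1)[1].item() for i in range(batch)]\n",
"\n",
"# Vectorized form: argmax over the vocabulary dimension for all columns at once.\n",
"pred_vec = output[0].argmax(dim=-1).tolist()\n",
"print(pred_loop == pred_vec)  # True\n",
"\n",
"# Greedy prediction for every position and column has shape (seq_len, batch).\n",
"print(output.argmax(dim=-1).shape)  # torch.Size([35, 20])"
]
},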
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}