Created
September 10, 2020 14:56
-
-
Save CookieBox26/b2254f1125f8849cca3f5c12b0dcaf61 to your computer and use it in GitHub Desktop.
文字レベル Penn Treebank と TrellisNet
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# 文字レベル Penn Treebank と TrellisNet\n", | |
"\n", | |
"### 参考文献\n", | |
"1. [Trellis Networks for Sequence Modeling](https://arxiv.org/abs/1810.06682)：TrellisNet の原論文。\n", | |
"2. [locuslab/trellisnet](https://github.com/locuslab/trellisnet)：TrellisNet のリポジトリ。" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### 文字レベル Penn Treebank 用のデータを読み込む" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ PTBデータをロード\n", | |
"Loading cached data...\n" | |
] | |
} | |
], | |
"source": [ | |
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/char_PTB/utils.py\n", | |
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/char_PTB/data.py\n", | |
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/char_PTB/char_ptb.py\n", | |
"# から適宜コピー\n", | |
"\n", | |
"import torch\n", | |
"import observations\n", | |
"import os\n", | |
"import pickle\n", | |
"from collections import Counter\n", | |
"\n", | |
"\n", | |
"class Dictionary(object):\n", | |
" def __init__(self):\n", | |
" self.char2idx = {}\n", | |
" self.idx2char = []\n", | |
" self.counter = Counter()\n", | |
"\n", | |
" def add_word(self, char):\n", | |
" self.counter[char] += 1\n", | |
"\n", | |
" def prep_dict(self):\n", | |
" for char in self.counter:\n", | |
" if char not in self.char2idx:\n", | |
" self.idx2char.append(char)\n", | |
" self.char2idx[char] = len(self.idx2char) - 1\n", | |
"\n", | |
" def __len__(self):\n", | |
" return len(self.idx2char)\n", | |
"\n", | |
"\n", | |
"class Corpus(object):\n", | |
" def __init__(self, string):\n", | |
" self.dictionary = Dictionary()\n", | |
" for c in string:\n", | |
" self.dictionary.add_word(c)\n", | |
" self.dictionary.prep_dict()\n", | |
"\n", | |
"\n", | |
"def char_tensor(corpus, string):\n", | |
" tensor = torch.zeros(len(string)).long()\n", | |
" for i in range(len(string)):\n", | |
" tensor[i] = corpus.dictionary.char2idx[string[i]]\n", | |
" return tensor\n", | |
"\n", | |
"\n", | |
"def batchify(data, batch_size):\n", | |
" nbatch = data.size(0) // batch_size\n", | |
" data = data.narrow(0, 0, nbatch * batch_size) # バッチサイズの整数倍をはみ出たら切り捨て\n", | |
" return data.view(batch_size, -1).t().contiguous() # Pytorch の RNN 式の次元の順序に転置\n", | |
"\n", | |
"\n", | |
"def get_batch(source, i, seq_len, evaluation=False):\n", | |
" seq_len = min(seq_len, source.size(0) - 1 - i)\n", | |
" data = source[i:i + seq_len]\n", | |
" if evaluation:\n", | |
" data.requires_grad = False\n", | |
" target = source[i + 1:i + 1 + seq_len] # CAUTION: This is un-flattened!\n", | |
" return data, target\n", | |
"\n", | |
"\n", | |
"def data_generator():\n", | |
" file, testfile, valfile = getattr(observations, 'ptb')('data/')\n", | |
" file, testfile, valfile = file.replace('<eos>', chr(255)), testfile.replace('<eos>', chr(255)), valfile.replace(\n", | |
" '<eos>', chr(255)) # Just replace <eos> with another unusual alphabet here (that is not in PTB)\n", | |
" file_len = len(file)\n", | |
" valfile_len = len(valfile)\n", | |
" testfile_len = len(testfile)\n", | |
" pickle_name = \"ptb.corpus\"\n", | |
" if os.path.exists(pickle_name):\n", | |
" print(\"Loading cached data...\")\n", | |
" corpus = pickle.load(open(pickle_name, 'rb'))\n", | |
" else:\n", | |
" corpus = Corpus(file + \" \" + valfile + \" \" + testfile)\n", | |
" pickle.dump(corpus, open(pickle_name, 'wb'))\n", | |
" return file, file_len, valfile, valfile_len, testfile, testfile_len, corpus\n", | |
"\n", | |
"\n", | |
"print('◆ PTBデータをロード')\n", | |
"batch_size = 24\n", | |
"file, file_len, valfile, valfile_len, testfile, testfile_len, corpus = data_generator()\n", | |
"train_data = batchify(char_tensor(corpus, file), batch_size)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ インデックスと文字の対応\n", | |
"0 \n", | |
"1 a\n", | |
"2 e\n", | |
"3 r\n", | |
"4 b\n", | |
"5 n\n", | |
"6 k\n", | |
"7 o\n", | |
"8 t\n", | |
"9 l\n", | |
"\n", | |
"◆ 1バッチを取得\n", | |
"torch.Size([140, 24])\n", | |
"torch.Size([140, 24])\n", | |
"\n", | |
"◆ 1バッチの1つ目のデータを取得\n", | |
"◇ data\n", | |
"tensor([ 0, 1, 2, 3, 0, 4, 1, 5, 6, 5, 7, 8, 2, 0, 4, 2, 3, 9,\n", | |
" 10, 8, 11, 0, 12, 1, 9, 9, 7, 13, 1, 14, 0, 12, 2, 5, 8, 3,\n", | |
" 15, 16, 8, 0, 12, 9, 15, 2, 8, 8, 0, 17, 3, 7, 18, 16, 8, 2,\n", | |
" 10, 5, 0, 19, 10, 8, 1, 5, 7, 0, 19, 15, 8, 2, 3, 18, 1, 5,\n", | |
" 0, 20, 14, 21, 3, 7, 22, 23, 15, 2, 4, 2, 12, 0, 10, 24, 7, 0,\n", | |
" 6, 10, 1, 0, 18, 2, 18, 7, 8, 2, 12, 0, 18, 9, 25, 0, 5, 1,\n", | |
" 20, 4, 0, 24, 15, 5, 8, 16, 0, 3, 1, 6, 2, 0, 3, 2, 19, 1,\n", | |
" 8, 8, 1, 0, 3, 15, 4, 2, 5, 16, 0, 16, 10, 18])\n", | |
"◇ target\n", | |
"tensor([ 1, 2, 3, 0, 4, 1, 5, 6, 5, 7, 8, 2, 0, 4, 2, 3, 9, 10,\n", | |
" 8, 11, 0, 12, 1, 9, 9, 7, 13, 1, 14, 0, 12, 2, 5, 8, 3, 15,\n", | |
" 16, 8, 0, 12, 9, 15, 2, 8, 8, 0, 17, 3, 7, 18, 16, 8, 2, 10,\n", | |
" 5, 0, 19, 10, 8, 1, 5, 7, 0, 19, 15, 8, 2, 3, 18, 1, 5, 0,\n", | |
" 20, 14, 21, 3, 7, 22, 23, 15, 2, 4, 2, 12, 0, 10, 24, 7, 0, 6,\n", | |
" 10, 1, 0, 18, 2, 18, 7, 8, 2, 12, 0, 18, 9, 25, 0, 5, 1, 20,\n", | |
" 4, 0, 24, 15, 5, 8, 16, 0, 3, 1, 6, 2, 0, 3, 2, 19, 1, 8,\n", | |
" 8, 1, 0, 3, 15, 4, 2, 5, 16, 0, 16, 10, 18, 0])\n", | |
"◇ data (アルファベットに翻訳)\n", | |
" aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim\n", | |
"◇ target (アルファベットに翻訳)\n", | |
"aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim \n" | |
] | |
} | |
], | |
"source": [ | |
"print('◆ インデックスと文字の対応')\n", | |
"for i in range(10):\n", | |
" print(i, corpus.dictionary.idx2char[i])\n", | |
"\n", | |
"print('\\n◆ 1バッチを取得')\n", | |
"data, targets = get_batch(train_data, 0, 140)\n", | |
"print(data.size())\n", | |
"print(targets.size())\n", | |
"\n", | |
"print('\\n◆ 1バッチの1つ目のデータを取得')\n", | |
"print('◇ data')\n", | |
"print(data[:,0])\n", | |
"print('◇ target')\n", | |
"print(targets[:,0])\n", | |
"print('◇ data (アルファベットに翻訳)')\n", | |
"print(''.join([corpus.dictionary.idx2char[i.item()] for i in data[:,0]]))\n", | |
"print('◇ target (アルファベットに翻訳)')\n", | |
"print(''.join([corpus.dictionary.idx2char[i.item()] for i in targets[:,0]]))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### TrellisNet にランダムなデータを流して入出力の形状を確認する\n", | |
"\n", | |
"※ ここでは形状確認のため乱数テンソルを入力する。実際の学習では、文字インデックスを200次元に埋め込んでから流す。" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"◆ TrellisNet(入出力が200次元、隠れ状態が1050次元)\n", | |
"Weight normalization applied\n", | |
"TrellisNet(\n", | |
" (full_conv): WeightShareConv1d(\n", | |
" (drop): VariationalHidDropout()\n", | |
" )\n", | |
")\n", | |
"\n", | |
"◆ 適当なデータを流す\n", | |
"◇ 入力\n", | |
"torch.Size([24, 200, 140])\n", | |
"torch.Size([24, 1250, 1])\n", | |
"torch.Size([24, 1250, 1])\n", | |
"◇ 出力\n", | |
"torch.Size([24, 1250, 1])\n", | |
"torch.Size([24, 1250, 1])\n", | |
"torch.Size([24, 2, 140, 200])\n" | |
] | |
} | |
], | |
"source": [ | |
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/optimizations.py\n", | |
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py\n", | |
"# を横に配置\n", | |
"\n", | |
"import torch\n", | |
"from trellisnet import TrellisNet\n", | |
"\n", | |
"\n", | |
"print('◆ TrellisNet(入出力が200次元、隠れ状態が1050次元)')\n", | |
"model = TrellisNet(ninp=200, nhid=1050, nout=200)\n", | |
"print(model)\n", | |
"\n", | |
"print('\\n◆ 適当なデータを流す')\n", | |
"# X has dimension (N, ninp, L)\n", | |
"input_ = torch.rand(24, 200, 140)\n", | |
"h_0 = torch.rand(24, 1250, 1)\n", | |
"c_0 = torch.rand(24, 1250, 1)\n", | |
"print('◇ 入力')\n", | |
"print(input_.size())\n", | |
"print(h_0.size())\n", | |
"print(c_0.size())\n", | |
"out, hc, aux_outs = model(input_, (h_0, c_0))\n", | |
"print('◇ 出力')\n", | |
"print(hc[0].size())\n", | |
"print(hc[1].size())\n", | |
"print(aux_outs.size())" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment