Skip to content

Instantly share code, notes, and snippets.

@CookieBox26
Created September 10, 2020 14:56
Show Gist options
  • Save CookieBox26/b2254f1125f8849cca3f5c12b0dcaf61 to your computer and use it in GitHub Desktop.
Save CookieBox26/b2254f1125f8849cca3f5c12b0dcaf61 to your computer and use it in GitHub Desktop.
文字レベル Penn Treebank と TrellisNet
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 文字レベル Penn Treebank と TrellisNet\n",
"\n",
"### 参考文献\n",
"1. [Trellis Networks for Sequence Modeling](https://arxiv.org/abs/1810.06682) — TrellisNet の原論文。\n",
"2. [locuslab/trellisnet](https://github.com/locuslab/trellisnet) — TrellisNet のリポジトリ（本ノートブックのコードの出典）。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 文字レベル Penn Treebank 用のデータを読み込む"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ PTBデータをロード\n",
"Loading cached data...\n"
]
}
],
"source": [
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/char_PTB/utils.py\n",
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/char_PTB/data.py\n",
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/char_PTB/char_ptb.py\n",
"# から適宜コピー\n",
"\n",
"import torch\n",
"import observations\n",
"import os\n",
"import pickle\n",
"from collections import Counter\n",
"\n",
"\n",
"class Dictionary(object):\n",
"    \"\"\"Character vocabulary: counts characters, then assigns each one an index.\"\"\"\n",
"\n",
"    def __init__(self):\n",
"        self.char2idx = {}  # character -> index\n",
"        self.idx2char = []  # index -> character\n",
"        self.counter = Counter()  # occurrence counts, filled by add_word\n",
"\n",
"    def add_word(self, char):\n",
"        \"\"\"Count one occurrence of `char` (despite the name, it records a single character).\"\"\"\n",
"        self.counter[char] += 1\n",
"\n",
"    def prep_dict(self):\n",
"        \"\"\"Give every counted character that has no index yet the next free index.\n",
"\n",
"        Characters are visited in the Counter's insertion (first-seen) order;\n",
"        characters that already have an index keep it.\n",
"        \"\"\"\n",
"        for char in self.counter:\n",
"            if char in self.char2idx:\n",
"                continue\n",
"            self.char2idx[char] = len(self.idx2char)\n",
"            self.idx2char.append(char)\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Number of indexed characters.\"\"\"\n",
"        return len(self.idx2char)\n",
"\n",
"\n",
"class Corpus(object):\n",
"    \"\"\"Character vocabulary built from every character of `string`.\"\"\"\n",
"\n",
"    def __init__(self, string):\n",
"        vocab = Dictionary()\n",
"        self.dictionary = vocab\n",
"        for ch in string:\n",
"            vocab.add_word(ch)\n",
"        vocab.prep_dict()\n",
"\n",
"\n",
"def char_tensor(corpus, string):\n",
"    \"\"\"Encode `string` as a 1-D LongTensor of character indices from `corpus`.\n",
"\n",
"    Raises KeyError if `string` contains a character missing from the vocabulary.\n",
"    \"\"\"\n",
"    # Build the index list in Python and convert it once: assigning tensor\n",
"    # elements one at a time is very slow for corpus-sized strings.\n",
"    char2idx = corpus.dictionary.char2idx\n",
"    return torch.tensor([char2idx[ch] for ch in string], dtype=torch.long)\n",
"\n",
"\n",
"def batchify(data, batch_size):\n",
"    \"\"\"Split a token tensor into `batch_size` equal columns, time-major.\n",
"\n",
"    Trailing tokens that do not fill a whole column are dropped. The result is\n",
"    transposed to (steps, batch_size), the dimension order PyTorch RNNs use.\n",
"    \"\"\"\n",
"    nbatch = data.size(0) // batch_size\n",
"    if nbatch == 0:\n",
"        # view(batch_size, -1) cannot infer -1 for a 0-element tensor and raises;\n",
"        # return the empty time-major result directly instead.\n",
"        return data.new_empty((0, batch_size))\n",
"    data = data.narrow(0, 0, nbatch * batch_size)  # drop the remainder beyond a multiple of batch_size\n",
"    return data.view(batch_size, -1).t().contiguous()  # transpose to PyTorch RNN dimension order\n",
"\n",
"\n",
"def get_batch(source, i, seq_len, evaluation=False):\n",
"    \"\"\"Return (data, target) windows of up to `seq_len` rows starting at row `i`.\n",
"\n",
"    `target` is `data` shifted one row later (next-character prediction). The\n",
"    window is clipped so that the shifted target still fits inside `source`.\n",
"    Rows are assumed to be time steps, as produced by batchify.\n",
"    \"\"\"\n",
"    length = min(seq_len, source.size(0) - 1 - i)\n",
"    start, stop = i, i + length\n",
"    data = source[start:stop]\n",
"    if evaluation:\n",
"        data.requires_grad = False\n",
"    target = source[start + 1:stop + 1]  # CAUTION: This is un-flattened!\n",
"    return data, target\n",
"\n",
"\n",
"def data_generator(data_dir='data/', pickle_name='ptb.corpus'):\n",
"    \"\"\"Load character-level PTB and build (or load a cached) character Corpus.\n",
"\n",
"    Args:\n",
"        data_dir: directory passed to `observations.ptb` (default 'data/').\n",
"        pickle_name: path of the pickled Corpus cache (default 'ptb.corpus').\n",
"\n",
"    Returns:\n",
"        (file, file_len, valfile, valfile_len, testfile, testfile_len, corpus):\n",
"        the train/valid/test texts with '<eos>' replaced, their lengths, and the\n",
"        character Corpus (built from all three splits, or loaded from the cache).\n",
"    \"\"\"\n",
"    eos = chr(255)  # unusual character (not in PTB) standing in for '<eos>'\n",
"    # observations.ptb returns the splits in (train, test, valid) order.\n",
"    file, testfile, valfile = (\n",
"        split.replace('<eos>', eos) for split in observations.ptb(data_dir))\n",
"    file_len, valfile_len, testfile_len = len(file), len(valfile), len(testfile)\n",
"    if os.path.exists(pickle_name):\n",
"        print(\"Loading cached data...\")\n",
"        # NOTE: the cache is keyed only by file name, so delete it when the data\n",
"        # changes. pickle.load can execute arbitrary code, so only load caches\n",
"        # you created yourself.\n",
"        with open(pickle_name, 'rb') as f:\n",
"            corpus = pickle.load(f)\n",
"    else:\n",
"        corpus = Corpus(file + \" \" + valfile + \" \" + testfile)\n",
"        with open(pickle_name, 'wb') as f:\n",
"            pickle.dump(corpus, f)\n",
"    return file, file_len, valfile, valfile_len, testfile, testfile_len, corpus\n",
"\n",
"\n",
"batch_size = 24  # number of columns batchify splits the training text into\n",
"print('◆ PTBデータをロード')\n",
"(file, file_len,\n",
" valfile, valfile_len,\n",
" testfile, testfile_len,\n",
" corpus) = data_generator()\n",
"train_data = batchify(char_tensor(corpus, file), batch_size)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ インデックスと文字の対応\n",
"0 \n",
"1 a\n",
"2 e\n",
"3 r\n",
"4 b\n",
"5 n\n",
"6 k\n",
"7 o\n",
"8 t\n",
"9 l\n",
"\n",
"◆ 1バッチを取得\n",
"torch.Size([140, 24])\n",
"torch.Size([140, 24])\n",
"\n",
"◆ 1バッチの1つ目のデータを取得\n",
"◇ data\n",
"tensor([ 0, 1, 2, 3, 0, 4, 1, 5, 6, 5, 7, 8, 2, 0, 4, 2, 3, 9,\n",
" 10, 8, 11, 0, 12, 1, 9, 9, 7, 13, 1, 14, 0, 12, 2, 5, 8, 3,\n",
" 15, 16, 8, 0, 12, 9, 15, 2, 8, 8, 0, 17, 3, 7, 18, 16, 8, 2,\n",
" 10, 5, 0, 19, 10, 8, 1, 5, 7, 0, 19, 15, 8, 2, 3, 18, 1, 5,\n",
" 0, 20, 14, 21, 3, 7, 22, 23, 15, 2, 4, 2, 12, 0, 10, 24, 7, 0,\n",
" 6, 10, 1, 0, 18, 2, 18, 7, 8, 2, 12, 0, 18, 9, 25, 0, 5, 1,\n",
" 20, 4, 0, 24, 15, 5, 8, 16, 0, 3, 1, 6, 2, 0, 3, 2, 19, 1,\n",
" 8, 8, 1, 0, 3, 15, 4, 2, 5, 16, 0, 16, 10, 18])\n",
"◇ target\n",
"tensor([ 1, 2, 3, 0, 4, 1, 5, 6, 5, 7, 8, 2, 0, 4, 2, 3, 9, 10,\n",
" 8, 11, 0, 12, 1, 9, 9, 7, 13, 1, 14, 0, 12, 2, 5, 8, 3, 15,\n",
" 16, 8, 0, 12, 9, 15, 2, 8, 8, 0, 17, 3, 7, 18, 16, 8, 2, 10,\n",
" 5, 0, 19, 10, 8, 1, 5, 7, 0, 19, 15, 8, 2, 3, 18, 1, 5, 0,\n",
" 20, 14, 21, 3, 7, 22, 23, 15, 2, 4, 2, 12, 0, 10, 24, 7, 0, 6,\n",
" 10, 1, 0, 18, 2, 18, 7, 8, 2, 12, 0, 18, 9, 25, 0, 5, 1, 20,\n",
" 4, 0, 24, 15, 5, 8, 16, 0, 3, 1, 6, 2, 0, 3, 2, 19, 1, 8,\n",
" 8, 1, 0, 3, 15, 4, 2, 5, 16, 0, 16, 10, 18, 0])\n",
"◇ data (アルファベットに翻訳)\n",
" aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim\n",
"◇ target (アルファベットに翻訳)\n",
"aer banknote berlitz calloway centrust cluett fromstein gitano guterman hydro-quebec ipo kia memotec mlx nahb punts rake regatta rubens sim \n"
]
}
],
"source": [
"idx2char = corpus.dictionary.idx2char  # index -> character lookup\n",
"\n",
"print('◆ インデックスと文字の対応')\n",
"for i in range(10):\n",
"    print(i, idx2char[i])\n",
"\n",
"print('\\n◆ 1バッチを取得')\n",
"data, targets = get_batch(train_data, 0, 140)  # 140 rows starting from row 0\n",
"print(data.size())\n",
"print(targets.size())\n",
"\n",
"# Column 0 of each (time, batch) tensor is the first stream in the batch.\n",
"data_0, target_0 = data[:, 0], targets[:, 0]\n",
"print('\\n◆ 1バッチの1つ目のデータを取得')\n",
"print('◇ data')\n",
"print(data_0)\n",
"print('◇ target')\n",
"print(target_0)\n",
"print('◇ data (アルファベットに翻訳)')\n",
"print(''.join(idx2char[k.item()] for k in data_0))\n",
"print('◇ target (アルファベットに翻訳)')\n",
"print(''.join(idx2char[k.item()] for k in target_0))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TrellisNet にランダムなダミーデータを流す\n",
"\n",
"※ 実際のモデルでは文字を200次元に埋め込んでから流す。ここでは埋め込みを省略し、形状 (N, ninp, L) = (24, 200, 140) の乱数テンソルを直接入力して、入出力の形状を確認する。"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"◆ TrellisNet(入出力が200次元、隠れ状態が1050次元)\n",
"Weight normalization applied\n",
"TrellisNet(\n",
" (full_conv): WeightShareConv1d(\n",
" (drop): VariationalHidDropout()\n",
" )\n",
")\n",
"\n",
"◆ 適当なデータを流す\n",
"◇ 入力\n",
"torch.Size([24, 200, 140])\n",
"torch.Size([24, 1250, 1])\n",
"torch.Size([24, 1250, 1])\n",
"◇ 出力\n",
"torch.Size([24, 1250, 1])\n",
"torch.Size([24, 1250, 1])\n",
"torch.Size([24, 2, 140, 200])\n"
]
}
],
"source": [
"# Source files (placed next to this notebook):\n",
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/optimizations.py\n",
"# https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py\n",
"\n",
"import torch\n",
"from trellisnet import TrellisNet\n",
"\n",
"# Dimensions of the dummy forward pass.\n",
"BATCH_SIZE = 24  # N\n",
"NINP = 200       # input channels (the embedding size used in the real model)\n",
"NHID = 1050\n",
"NOUT = 200\n",
"SEQ_LEN = 140    # L\n",
"# NOTE(review): the initial states use NHID + NOUT = 1250 channels, the value\n",
"# that was hard-coded here before; presumably TrellisNet's state also carries\n",
"# the output channels. Confirm against trellisnet.py.\n",
"STATE_DIM = NHID + NOUT\n",
"\n",
"torch.manual_seed(0)  # make the dummy inputs (and any random init) reproducible\n",
"\n",
"print('◆ TrellisNet(入出力が200次元、隠れ状態が1050次元)')\n",
"model = TrellisNet(ninp=NINP, nhid=NHID, nout=NOUT)\n",
"print(model)\n",
"\n",
"print('\\n◆ 適当なデータを流す')\n",
"# X has dimension (N, ninp, L)\n",
"input_ = torch.rand(BATCH_SIZE, NINP, SEQ_LEN)\n",
"h_0 = torch.rand(BATCH_SIZE, STATE_DIM, 1)\n",
"c_0 = torch.rand(BATCH_SIZE, STATE_DIM, 1)\n",
"print('◇ 入力')\n",
"print(input_.size())\n",
"print(h_0.size())\n",
"print(c_0.size())\n",
"out, hc, aux_outs = model(input_, (h_0, c_0))\n",
"print('◇ 出力')\n",
"print(hc[0].size())\n",
"print(hc[1].size())\n",
"print(aux_outs.size())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment