Created
February 28, 2022 00:19
-
-
Save shihono/646d11d176fef923a806e893706636af to your computer and use it in GitHub Desktop.
torchtext_pentreebank_gt.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "torchtext_pentreebank_gt.ipynb", | |
"provenance": [], | |
"toc_visible": true, | |
"authorship_tag": "ABX9TyMlVa5PZsSW2UIquH1Ob9xu", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
} | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/shihono/646d11d176fef923a806e893706636af/torchtext_pentreebank_gt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import math\n", | |
"from collections import defaultdict\n", | |
"from itertools import islice\n", | |
"from pprint import pprint \n", | |
"\n", | |
"import matplotlib.pyplot as plt\n", | |
"\n", | |
"# pytorch 関連\n", | |
"import torch\n", | |
"import torch.nn as nn\n", | |
"import torch.nn.functional as F\n", | |
"\n", | |
"import torchtext\n", | |
"from torchtext.data.utils import get_tokenizer\n", | |
"from torchtext.vocab import build_vocab_from_iterator" | |
], | |
"metadata": { | |
"id": "0gK3HKyKDGCZ" | |
}, | |
"execution_count": 1, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## 前処理\n", | |
"\n", | |
"### Tokenizer, Vocab\n", | |
"\n" | |
], | |
"metadata": { | |
"id": "RoatFah_-ksB" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tokenizer = get_tokenizer(None)\n", | |
"\n", | |
"def get_vocab(data_iter):\n", | |
" vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials=['<unk>'])\n", | |
" # 未知語は全部 `<unk>` 扱いにする\n", | |
" vocab.set_default_index(vocab['<unk>'])\n", | |
" return vocab" | |
], | |
"metadata": { | |
"id": "3lM26l2FDSRu" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### text2index" | |
], | |
"metadata": { | |
"id": "RnK87wqpFYve" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def text2index(text, vocab):\n", | |
" return vocab(tokenizer(text))" | |
], | |
"metadata": { | |
"id": "Ca0hdopXFU0K" | |
}, | |
"execution_count": 3, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## ヘルドアウト推定\n", | |
"\n", | |
"bigramで考える\n", | |
"\n", | |
"### 準備\n", | |
"\n", | |
"- 訓練データにおける頻度 C1(w1, ..., wn)\n", | |
"- ヘルドアウトデータにおける頻度 C2(w1, ..., wn)\n", | |
"- 頻度 r の n-gramのタイプ数 Nr\n", | |
"- 訓練テキスト中のn-gramの頻度がr回のn-gramがヘルドアウトデータで出現した数 Tr\n", | |
"\n", | |
"\n", | |
"train を半分に分割して 訓練データとヘルドアウトデータとする" | |
], | |
"metadata": { | |
"id": "QJ02t92KE5e3" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"def iter_ngram(items, n=2):\n", | |
" for idx in range(len(items)-n+1):\n", | |
" yield items[idx:idx+n]\n", | |
"\n", | |
"def get_ngram_count(data_iter, vocab, n=2):\n", | |
" ngram_count = defaultdict(int)\n", | |
" for data in data_iter:\n", | |
" indexes = text2index(data, vocab)\n", | |
" for ngram in iter_ngram(indexes, n=n):\n", | |
" ngram_count[tuple(ngram)] +=1\n", | |
" return dict(ngram_count)" | |
], | |
"metadata": { | |
"id": "5VusiyXrF6WW" | |
}, | |
"execution_count": 4, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"tmp = [data.strip() for data in torchtext.datasets.PennTreebank(split='train')]\n", | |
"\n", | |
"train_data = tmp[:len(tmp)//2]\n", | |
"heldout_data = tmp[len(tmp)//2:]\n", | |
"len(train_data), len(heldout_data)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "dkRypAuOQEbq", | |
"outputId": "f8027396-aec9-403b-a87e-ad1253715165" | |
}, | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"5.10MB [00:00, 19.7MB/s]\n" | |
] | |
}, | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"(21034, 21034)" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 5 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 訓練データにおける頻度 c1\n", | |
"c1_vocab = get_vocab(train_data)\n", | |
"print(len(c1_vocab))\n", | |
"c1_bigram = get_ngram_count(train_data, c1_vocab, 2)\n", | |
"print(len(c1_bigram))\n", | |
"\n", | |
"# ヘルドアウトデータにおける頻度 c2\n", | |
"c2_bigram = get_ngram_count(heldout_data, c1_vocab, 2)\n", | |
"print(len(c2_bigram))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "1JY6Les5EwdW", | |
"outputId": "6ee4a258-2c51-4e21-86de-4cefcc66d460" | |
}, | |
"execution_count": 6, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"9661\n", | |
"157629\n", | |
"150070\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 頻度 r の n-gramのタイプ数 Nr\n", | |
"nr_types = defaultdict(int)\n", | |
"for freq in c1_bigram.values():\n", | |
" nr_types[freq] += 1\n", | |
"\n", | |
"# r=0 は存在しないので vocab_sizeを利用\n", | |
"nr_types[0] = len(c1_vocab) * len(c1_vocab)" | |
], | |
"metadata": { | |
"id": "V4_woCaDHErZ" | |
}, | |
"execution_count": 7, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for i in range(10):\n", | |
" print(i, nr_types[i])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "GYgHhpAxIB2H", | |
"outputId": "edfdbc75-5727-4779-c0a0-a9ace7f6742f" | |
}, | |
"execution_count": 8, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0 93334921\n", | |
"1 111266\n", | |
"2 20831\n", | |
"3 8446\n", | |
"4 4404\n", | |
"5 2769\n", | |
"6 2010\n", | |
"7 1304\n", | |
"8 911\n", | |
"9 710\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Nはデータのサンプル数\n", | |
"n_sum = 0\n", | |
"for r, nr in nr_types.items():\n", | |
" n_sum += r*nr\n", | |
"\n", | |
"print(n_sum)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_swx-wQlHl3S", | |
"outputId": "09c9cd00-9bd2-4909-997a-9aab5bedacfd" | |
}, | |
"execution_count": 9, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"426278\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# 訓練テキスト中のn-gramの頻度がr回のn-gramがヘルドアウトデータで出現した数 Tr\n", | |
"\n", | |
"tr_num = defaultdict(int)\n", | |
"t_sum = 0\n", | |
"\n", | |
"for ngram, r in c2_bigram.items():\n", | |
" c1_r = c1_bigram.get(ngram, 0)\n", | |
" tr_num[c1_r] += r\n", | |
" t_sum += r\n", | |
"\n", | |
"print(t_sum)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "HnsJFaVuIJSD", | |
"outputId": "0bb438cb-a2bf-4769-deb3-e2debc14547c" | |
}, | |
"execution_count": 10, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"419175\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for i in range(10):\n", | |
" print(i, tr_num[i])" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "5C2K2DZ0JLyL", | |
"outputId": "7a786e19-1a84-41f2-9d63-817f34ed01dd" | |
}, | |
"execution_count": 11, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0 114926\n", | |
"1 38323\n", | |
"2 23052\n", | |
"3 16426\n", | |
"4 12254\n", | |
"5 10315\n", | |
"6 9390\n", | |
"7 7274\n", | |
"8 5967\n", | |
"9 5315\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### 確率推定値\n", | |
"\n", | |
"訓練データで r 回出現するngramについて、ヘルドアウトデータで合計 Tr 回出現する。\n", | |
"よって1つのngramは、r 回出現するngramのタイプ数で割った Tr/Nr 回出現することになる。\n", | |
"\n", | |
"よって確率推定値はヘルドアウトの合計 T で割った以下\n", | |
"\n", | |
"$ P_{ho}(w_1,...,w_n) = \\frac{T_r}{N_r T} $\n", | |
"\n", | |
"- `w_1,...,w_n` の学習データでの頻度が r \n", | |
"- T は T_r の総和" | |
], | |
"metadata": { | |
"id": "3z5etTO7Jd50" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"for i in range(10):\n", | |
" print(f\"{i}\\t{nr_types[i]}\\t{tr_num[i]}\\t{tr_num[i]/nr_types[i]:3f}\\t{-math.log2(tr_num[i]/(nr_types[i]*t_sum)):.3f}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "Ei0BrJQBVCiB", | |
"outputId": "014df6cd-817c-465a-812d-bf6e243f876c" | |
}, | |
"execution_count": 12, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0\t93334921\t114926\t0.001231\t28.343\n", | |
"1\t111266\t38323\t0.344427\t20.215\n", | |
"2\t20831\t23052\t1.106620\t18.531\n", | |
"3\t8446\t16426\t1.944826\t17.718\n", | |
"4\t4404\t12254\t2.782470\t17.201\n", | |
"5\t2769\t10315\t3.725172\t16.780\n", | |
"6\t2010\t9390\t4.671642\t16.453\n", | |
"7\t1304\t7274\t5.578221\t16.197\n", | |
"8\t911\t5967\t6.549945\t15.966\n", | |
"9\t710\t5315\t7.485915\t15.773\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"vocab_itos = c1_vocab.get_itos()\n", | |
"\n", | |
"for data in islice(torchtext.datasets.PennTreebank(split='valid'), 1):\n", | |
" print(data)\n", | |
" indexes = text2index(data, c1_vocab)\n", | |
" for ngram in iter_ngram(indexes):\n", | |
" r = c1_bigram.get(tuple(ngram), 0)\n", | |
" nr = nr_types[r]\n", | |
" tr = tr_num.get(r, 0)\n", | |
" print(ngram, [vocab_itos[i] for i in ngram], math.log2(tr/(nr*t_sum)))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "x9vUHKvVJZsb", | |
"outputId": "ca66eeb7-5656-487b-e30c-85541c658eab" | |
}, | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"400kB [00:00, 17.1MB/s] " | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
" consumers may want to move their telephones a little closer to the tv set \n", | |
"\n", | |
"[1104, 89] ['consumers', 'may'] -28.342761082872915\n", | |
"[89, 372] ['may', 'want'] -18.53103332348117\n", | |
"[372, 4] ['want', 'to'] -12.312620717862819\n", | |
"[4, 274] ['to', 'move'] -14.307959340492955\n", | |
"[274, 51] ['move', 'their'] -28.342761082872915\n", | |
"[51, 8572] ['their', 'telephones'] -28.342761082872915\n", | |
"[8572, 5] ['telephones', 'a'] -28.342761082872915\n", | |
"[5, 308] ['a', 'little'] -13.547910133213708\n", | |
"[308, 2308] ['little', 'closer'] -28.342761082872915\n", | |
"[2308, 4] ['closer', 'to'] -15.022041334722774\n", | |
"[4, 1] ['to', 'the'] -8.97848848338833\n", | |
"[1, 736] ['the', 'tv'] -15.633596761006274\n", | |
"[736, 396] ['tv', 'set'] -28.342761082872915\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"sample_text = \"this wikipedia is written in english\"\n", | |
"tokens = text2index(sample_text, c1_vocab)\n", | |
"for t in iter_ngram(tokens, 2):\n", | |
" r = c1_bigram.get(tuple(t),0)\n", | |
" nr = nr_types[r]\n", | |
" tr = tr_num.get(r, 0)\n", | |
" print(t, [vocab_itos[i] for i in t], r, -math.log2(tr/(nr*t_sum)))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "_qKx8k3Dk47C", | |
"outputId": "0a1f71c1-66d0-4afb-c92c-bc84fdf7cf0d" | |
}, | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"[36, 0] ['this', '<unk>'] 51 13.057139984969156\n", | |
"[0, 11] ['<unk>', 'is'] 212 10.867692956269513\n", | |
"[11, 1932] ['is', 'written'] 0 28.342761082872915\n", | |
"[1932, 6] ['written', 'in'] 4 17.200826767745713\n", | |
"[6, 2404] ['in', 'english'] 5 16.779886287751758\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"### エントロピー" | |
], | |
"metadata": { | |
"id": "RjRfnvezN7rY" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# エントロピーを求める\n", | |
"word_sum = 0\n", | |
"h = 0\n", | |
"\n", | |
"test_iter = torchtext.datasets.WikiText2(split='test')\n", | |
"for t in test_iter:\n", | |
" tokens = text2index(t, c1_vocab)\n", | |
" word_sum += len(tokens)\n", | |
" for ngram in iter_ngram(indexes):\n", | |
" r = c1_bigram.get(tuple(ngram), 0)\n", | |
" nr = nr_types[r]\n", | |
" tr = tr_num.get(r, 0)\n", | |
" p = tr/(nr * t_sum)\n", | |
" h += -math.log2(p)\n", | |
"\n", | |
"print(f\"word_sum: {word_sum}\\nentropy: {h/word_sum}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "ovPwK7QyN91D", | |
"outputId": "7576e51c-51cc-4cc3-816f-c5f74d86d7f1" | |
}, | |
"execution_count": 15, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stderr", | |
"text": [ | |
"100%|██████████| 4.48M/4.48M [00:01<00:00, 2.72MB/s]\n" | |
] | |
}, | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"word_sum: 241211\n", | |
"entropy: 4.849051510527865\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"## Good-Turing estimator\n", | |
"\n", | |
"確率推定値は\n", | |
"\n", | |
"$ P_{GT} = \\frac{r^{*}}{N} $\n", | |
"\n", | |
"r* は推定頻度\n", | |
"\n", | |
"$ r^* = (r+1) \\frac{N_{r+1}}{N_r} $ \n", | |
"\n" | |
], | |
"metadata": { | |
"id": "9hgzY5eIPGEy" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"train_iter = torchtext.datasets.PennTreebank(split='train')\n", | |
"train_vocab = get_vocab(train_iter)\n", | |
"print(len(train_vocab))\n", | |
"\n", | |
"train_iter = torchtext.datasets.PennTreebank(split='train')\n", | |
"gt_unigram = get_ngram_count(train_iter, train_vocab)\n", | |
"\n", | |
"gt_nr = defaultdict(int)\n", | |
"for r in gt_unigram.values():\n", | |
" gt_nr[r] += 1\n", | |
"\n", | |
"gt_n_sum = sum([r*nr for r, nr in gt_nr.items()])\n", | |
"print(gt_n_sum)" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "zd-ffsSZbE7s", | |
"outputId": "fb89d39b-0fcc-4a9f-cd36-53fa0626438e" | |
}, | |
"execution_count": 16, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"9999\n", | |
"845453\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"max(list(gt_nr.keys()))" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "18JryZY_c0lT", | |
"outputId": "d1f89634-25bb-4e93-f263-fd985e9af1ed" | |
}, | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"7475" | |
] | |
}, | |
"metadata": {}, | |
"execution_count": 17 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"print(f\"{0}\\t{1*gt_nr[1]/ gt_n_sum:.4f}\")\n", | |
"for i in range(1, 10):\n", | |
" print(f\"{i}\\t{(i+1)*gt_nr[i+1]/gt_nr[i]:.4f}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "iGM5CGNYcR6_", | |
"outputId": "1495560f-f4a2-4eb1-ad4d-bff2493b8a43" | |
}, | |
"execution_count": 18, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"0\t0.2050\n", | |
"1\t0.4020\n", | |
"2\t1.2747\n", | |
"3\t2.1785\n", | |
"4\t3.1855\n", | |
"5\t4.0148\n", | |
"6\t5.3132\n", | |
"7\t5.6534\n", | |
"8\t7.1004\n", | |
"9\t8.5488\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# N_r の r が大きくなると、 N_r もしくは N_r+1 がなくなる\n", | |
"\n", | |
"print(\"r, N_r, N_r+1\")\n", | |
"for r in range(1, 300):\n", | |
" if gt_nr.get(r) is not None:\n", | |
" continue\n", | |
" else:\n", | |
" print(f\"{r}, {gt_nr.get(r)}, {gt_nr.get(r+1)}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "u1-uX8GOgBPK", | |
"outputId": "27374e9e-6ad7-4062-8ea0-e5786005e2ea" | |
}, | |
"execution_count": 19, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"r, N_r, N_r+1\n", | |
"125, None, 4\n", | |
"136, None, 2\n", | |
"163, None, 5\n", | |
"171, None, 3\n", | |
"185, None, 2\n", | |
"189, None, 2\n", | |
"200, None, 1\n", | |
"208, None, None\n", | |
"209, None, 4\n", | |
"212, None, 1\n", | |
"215, None, 3\n", | |
"217, None, 1\n", | |
"220, None, 2\n", | |
"227, None, 1\n", | |
"231, None, 1\n", | |
"236, None, 1\n", | |
"238, None, 1\n", | |
"240, None, None\n", | |
"241, None, None\n", | |
"242, None, 1\n", | |
"244, None, 1\n", | |
"248, None, 1\n", | |
"257, None, None\n", | |
"258, None, 2\n", | |
"260, None, None\n", | |
"261, None, 1\n", | |
"265, None, 2\n", | |
"270, None, None\n", | |
"271, None, None\n", | |
"272, None, 1\n", | |
"274, None, 2\n", | |
"280, None, None\n", | |
"281, None, None\n", | |
"282, None, 1\n", | |
"286, None, 1\n", | |
"288, None, None\n", | |
"289, None, None\n", | |
"290, None, None\n", | |
"291, None, 1\n", | |
"294, None, 2\n", | |
"296, None, None\n", | |
"297, None, 1\n" | |
] | |
} | |
] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment