{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "torchtext_pentreebank_gt.ipynb",
"provenance": [],
"toc_visible": true,
"authorship_tag": "ABX9TyMlVa5PZsSW2UIquH1Ob9xu",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shihono/646d11d176fef923a806e893706636af/torchtext_pentreebank_gt.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"source": [
"import math\n",
"from collections import defaultdict\n",
"from itertools import islice\n",
"from pprint import pprint \n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# pytorch 関連\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchtext\n",
"from torchtext.data.utils import get_tokenizer\n",
"from torchtext.vocab import build_vocab_from_iterator"
],
"metadata": {
"id": "0gK3HKyKDGCZ"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## 前処理\n",
"\n",
"### Tokenizer, Vocab\n",
"\n"
],
"metadata": {
"id": "RoatFah_-ksB"
}
},
{
"cell_type": "code",
"source": [
"tokenizer = get_tokenizer(None)\n",
"\n",
"def get_vocab(data_iter):\n",
" vocab = build_vocab_from_iterator(map(tokenizer, data_iter), specials=['<unk>'])\n",
" # 未知語は全部 `<unk>` 扱いにする\n",
" vocab.set_default_index(vocab['<unk>'])\n",
" return vocab"
],
"metadata": {
"id": "3lM26l2FDSRu"
},
"execution_count": 2,
"outputs": []
},
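{
"cell_type": "markdown",
"source": [
"A quick sanity check on toy input (the sentence below is made up for illustration): any token missing from the vocabulary falls back to the `<unk>` index set above."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"demo_vocab = get_vocab([\"hello world\"])\n",
"# 'never-seen' is not in the vocabulary, so it maps to demo_vocab['<unk>']\n",
"print(demo_vocab(['hello', 'world', 'never-seen']))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},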
{
"cell_type": "markdown",
"source": [
"### text2index"
],
"metadata": {
"id": "RnK87wqpFYve"
}
},
{
"cell_type": "code",
"source": [
"def text2index(text, vocab):\n",
" return vocab(tokenizer(text))"
],
"metadata": {
"id": "Ca0hdopXFU0K"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## ヘルドアウト推定\n",
"\n",
"bigramで考える\n",
"\n",
"### 準備\n",
"\n",
"- 訓練データにおける頻度 C1(w1, ..., wn)\n",
"- ヘルドアウトデータにおける頻度 C2(w1, ..., wn)\n",
"- 頻度 r の n-gramのタイプ数 Nr\n",
"- 訓練テキスト中のn-gramの頻度がr回のn-gramがヘルドアウトデータで出現した数 Tr\n",
"\n",
"\n",
"train を半分に分割して 訓練データとヘルドアウトデータとする"
],
"metadata": {
"id": "QJ02t92KE5e3"
}
},
{
"cell_type": "code",
"source": [
"def iter_ngram(items, n=2):\n",
" for idx in range(len(items)-n+1):\n",
" yield items[idx:idx+n]\n",
"\n",
"def get_ngram_count(data_iter, vocab, n=2):\n",
" ngram_count = defaultdict(int)\n",
" for data in data_iter:\n",
" indexes = text2index(data, vocab)\n",
" for ngram in iter_ngram(indexes, n=n):\n",
" ngram_count[tuple(ngram)] +=1\n",
" return dict(ngram_count)"
],
"metadata": {
"id": "5VusiyXrF6WW"
},
"execution_count": 4,
"outputs": []
},
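{
"cell_type": "markdown",
"source": [
"A toy worked example of the definitions above (the two \"corpora\" below are made up for illustration): $C_1$, $C_2$, $N_r$ and $T_r$ are small enough here to verify by hand."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# hypothetical toy corpora to illustrate C1 / C2 / Nr / Tr\n",
"toy_train = [\"a b a b\", \"a b c\"]\n",
"toy_heldout = [\"a b c\", \"c a b\"]\n",
"toy_vocab = get_vocab(toy_train)\n",
"\n",
"toy_c1 = get_ngram_count(toy_train, toy_vocab, 2)\n",
"toy_c2 = get_ngram_count(toy_heldout, toy_vocab, 2)\n",
"print(\"C1:\", toy_c1)\n",
"print(\"C2:\", toy_c2)\n",
"\n",
"# Nr: how many bigram types occur r times in toy_train\n",
"toy_nr = defaultdict(int)\n",
"for freq in toy_c1.values():\n",
"    toy_nr[freq] += 1\n",
"\n",
"# Tr: held-out count mass of the bigrams whose training frequency is r\n",
"toy_tr = defaultdict(int)\n",
"for ngram, r in toy_c2.items():\n",
"    toy_tr[toy_c1.get(ngram, 0)] += r\n",
"\n",
"print(\"Nr:\", dict(toy_nr))\n",
"print(\"Tr:\", dict(toy_tr))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},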
{
"cell_type": "code",
"source": [
"tmp = [data.strip() for data in torchtext.datasets.PennTreebank(split='train')]\n",
"\n",
"train_data = tmp[:len(tmp)//2]\n",
"heldout_data = tmp[len(tmp)//2:]\n",
"len(train_data), len(heldout_data)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dkRypAuOQEbq",
"outputId": "f8027396-aec9-403b-a87e-ad1253715165"
},
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"5.10MB [00:00, 19.7MB/s]\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(21034, 21034)"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"# 訓練データにおける頻度 c1\n",
"c1_vocab = get_vocab(train_data)\n",
"print(len(c1_vocab))\n",
"c1_bigram = get_ngram_count(train_data, c1_vocab, 2)\n",
"print(len(c1_bigram))\n",
"\n",
"# ヘルドアウトデータにおける頻度 c2\n",
"c2_bigram = get_ngram_count(heldout_data, c1_vocab, 2)\n",
"print(len(c2_bigram))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1JY6Les5EwdW",
"outputId": "6ee4a258-2c51-4e21-86de-4cefcc66d460"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"9661\n",
"157629\n",
"150070\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# 頻度 r の n-gramのタイプ数 Nr\n",
"nr_types = defaultdict(int)\n",
"for freq in c1_bigram.values():\n",
" nr_types[freq] += 1\n",
"\n",
"# r=0 は存在しないので vocab_sizeを利用\n",
"nr_types[0] = len(c1_vocab) * len(c1_vocab)"
],
"metadata": {
"id": "V4_woCaDHErZ"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"for i in range(10):\n",
" print(i, nr_types[i])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GYgHhpAxIB2H",
"outputId": "edfdbc75-5727-4779-c0a0-a9ace7f6742f"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 93334921\n",
"1 111266\n",
"2 20831\n",
"3 8446\n",
"4 4404\n",
"5 2769\n",
"6 2010\n",
"7 1304\n",
"8 911\n",
"9 710\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Nはデータのサンプル数\n",
"n_sum = 0\n",
"for r, nr in nr_types.items():\n",
" n_sum += r*nr\n",
"\n",
"print(n_sum)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_swx-wQlHl3S",
"outputId": "09c9cd00-9bd2-4909-997a-9aab5bedacfd"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"426278\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# 訓練テキスト中のn-gramの頻度がr回のn-gramがヘルドアウトデータで出現した数 Tr\n",
"\n",
"tr_num = defaultdict(int)\n",
"t_sum = 0\n",
"\n",
"for ngram, r in c2_bigram.items():\n",
" c1_r = c1_bigram.get(ngram, 0)\n",
" tr_num[c1_r] += r\n",
" t_sum += r\n",
"\n",
"print(t_sum)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HnsJFaVuIJSD",
"outputId": "0bb438cb-a2bf-4769-deb3-e2debc14547c"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"419175\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"for i in range(10):\n",
" print(i, tr_num[i])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5C2K2DZ0JLyL",
"outputId": "7a786e19-1a84-41f2-9d63-817f34ed01dd"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 114926\n",
"1 38323\n",
"2 23052\n",
"3 16426\n",
"4 12254\n",
"5 10315\n",
"6 9390\n",
"7 7274\n",
"8 5967\n",
"9 5315\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### 確率推定値\n",
"\n",
"訓練データで r 回出現するngramについて、ヘルドアウトデータで合計 Tr 回出現する。\n",
"よって1つのngramは、r 回出現するngramのタイプ数で割った Tr/Nr 回出現することになる。\n",
"\n",
"よって確率推定値はヘルドアウトの合計 T で割った以下\n",
"\n",
"$ P_{ho}(w_1,...,w_n) = \\frac{T_r}{N_r T} $\n",
"\n",
"- `w_1,...,w_n` の学習データでの頻度が r \n",
"- T は T_r の総和"
],
"metadata": {
"id": "3z5etTO7Jd50"
}
},
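{
"cell_type": "markdown",
"source": [
"A small convenience wrapper over the tables computed above (`p_heldout` is a hypothetical helper, not part of the original flow); the cells below inline the same arithmetic."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# hypothetical helper packaging P_ho(w_1,...,w_n) = T_r / (N_r * T)\n",
"def p_heldout(ngram, c1_count, nr, tr, t_total):\n",
"    r = c1_count.get(tuple(ngram), 0)\n",
"    return tr.get(r, 0) / (nr[r] * t_total)\n",
"\n",
"print(p_heldout(text2index(\"to the\", c1_vocab), c1_bigram, nr_types, tr_num, t_sum))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},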
{
"cell_type": "code",
"source": [
"for i in range(10):\n",
" print(f\"{i}\\t{nr_types[i]}\\t{tr_num[i]}\\t{tr_num[i]/nr_types[i]:3f}\\t{-math.log2(tr_num[i]/(nr_types[i]*t_sum)):.3f}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Ei0BrJQBVCiB",
"outputId": "014df6cd-817c-465a-812d-bf6e243f876c"
},
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0\t93334921\t114926\t0.001231\t28.343\n",
"1\t111266\t38323\t0.344427\t20.215\n",
"2\t20831\t23052\t1.106620\t18.531\n",
"3\t8446\t16426\t1.944826\t17.718\n",
"4\t4404\t12254\t2.782470\t17.201\n",
"5\t2769\t10315\t3.725172\t16.780\n",
"6\t2010\t9390\t4.671642\t16.453\n",
"7\t1304\t7274\t5.578221\t16.197\n",
"8\t911\t5967\t6.549945\t15.966\n",
"9\t710\t5315\t7.485915\t15.773\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"vocab_itos = c1_vocab.get_itos()\n",
"\n",
"for data in islice(torchtext.datasets.PennTreebank(split='valid'), 1):\n",
" print(data)\n",
" indexes = text2index(data, c1_vocab)\n",
" for ngram in iter_ngram(indexes):\n",
" r = c1_bigram.get(tuple(ngram), 0)\n",
" nr = nr_types[r]\n",
" tr = tr_num.get(r, 0)\n",
" print(ngram, [vocab_itos[i] for i in ngram], math.log2(tr/(nr*t_sum)))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x9vUHKvVJZsb",
"outputId": "ca66eeb7-5656-487b-e30c-85541c658eab"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"400kB [00:00, 17.1MB/s] "
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
" consumers may want to move their telephones a little closer to the tv set \n",
"\n",
"[1104, 89] ['consumers', 'may'] -28.342761082872915\n",
"[89, 372] ['may', 'want'] -18.53103332348117\n",
"[372, 4] ['want', 'to'] -12.312620717862819\n",
"[4, 274] ['to', 'move'] -14.307959340492955\n",
"[274, 51] ['move', 'their'] -28.342761082872915\n",
"[51, 8572] ['their', 'telephones'] -28.342761082872915\n",
"[8572, 5] ['telephones', 'a'] -28.342761082872915\n",
"[5, 308] ['a', 'little'] -13.547910133213708\n",
"[308, 2308] ['little', 'closer'] -28.342761082872915\n",
"[2308, 4] ['closer', 'to'] -15.022041334722774\n",
"[4, 1] ['to', 'the'] -8.97848848338833\n",
"[1, 736] ['the', 'tv'] -15.633596761006274\n",
"[736, 396] ['tv', 'set'] -28.342761082872915\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"sample_text = \"this wikipedia is written in english\"\n",
"tokens = text2index(sample_text, c1_vocab)\n",
"for t in iter_ngram(tokens, 2):\n",
" r = c1_bigram.get(tuple(t),0)\n",
" nr = nr_types[r]\n",
" tr = tr_num.get(r, 0)\n",
" print(t, [vocab_itos[i] for i in t], r, -math.log2(tr/(nr*t_sum)))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_qKx8k3Dk47C",
"outputId": "0a1f71c1-66d0-4afb-c92c-bc84fdf7cf0d"
},
"execution_count": 14,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[36, 0] ['this', '<unk>'] 51 13.057139984969156\n",
"[0, 11] ['<unk>', 'is'] 212 10.867692956269513\n",
"[11, 1932] ['is', 'written'] 0 28.342761082872915\n",
"[1932, 6] ['written', 'in'] 4 17.200826767745713\n",
"[6, 2404] ['in', 'english'] 5 16.779886287751758\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### エントロピー"
],
"metadata": {
"id": "RjRfnvezN7rY"
}
},
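{
"cell_type": "markdown",
"source": [
"The next cell estimates the per-word entropy as the average negative log probability of the test bigrams, dividing the summed bigram log probabilities by the token count:\n",
"\n",
"$ H \\approx -\\frac{1}{N} \\sum_{i} \\log_2 P_{ho}(w_{i-1}, w_i) $"
],
"metadata": {}
},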
{
"cell_type": "code",
"source": [
"# エントロピーを求める\n",
"word_sum = 0\n",
"h = 0\n",
"\n",
"test_iter = torchtext.datasets.WikiText2(split='test')\n",
"for t in test_iter:\n",
" tokens = text2index(t, c1_vocab)\n",
" word_sum += len(tokens)\n",
" for ngram in iter_ngram(indexes):\n",
" r = c1_bigram.get(tuple(ngram), 0)\n",
" nr = nr_types[r]\n",
" tr = tr_num.get(r, 0)\n",
" p = tr/(nr * t_sum)\n",
" h += -math.log2(p)\n",
"\n",
"print(f\"word_sum: {word_sum}\\nentropy: {h/word_sum}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ovPwK7QyN91D",
"outputId": "7576e51c-51cc-4cc3-816f-c5f74d86d7f1"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"100%|██████████| 4.48M/4.48M [00:01<00:00, 2.72MB/s]\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"word_sum: 241211\n",
"entropy: 4.849051510527865\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"## Good-Turing estimator\n",
"\n",
"確率推定値は\n",
"\n",
"$ P_{GT} = \\frac{r^{*}}{N} $\n",
"\n",
"r* は推定頻度\n",
"\n",
"$ r^* = (r+1) \\frac{N_{r+1}}{N_r} $ \n",
"\n"
],
"metadata": {
"id": "9hgzY5eIPGEy"
}
},
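{
"cell_type": "markdown",
"source": [
"A minimal sketch of the adjusted count as a function (`r_star` is a hypothetical helper over the raw $N_r$ table computed in the next cell); it returns `None` where $N_r$ or $N_{r+1}$ is missing, the problem examined at the end of this section."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# hypothetical helper: Good-Turing adjusted frequency r* = (r+1) * N_{r+1} / N_r\n",
"def r_star(r, nr_table):\n",
"    if nr_table.get(r) and nr_table.get(r + 1):\n",
"        return (r + 1) * nr_table[r + 1] / nr_table[r]\n",
"    return None  # undefined when N_r or N_{r+1} is zero or missing"
],
"metadata": {},
"execution_count": null,
"outputs": []
},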
{
"cell_type": "code",
"source": [
"train_iter = torchtext.datasets.PennTreebank(split='train')\n",
"train_vocab = get_vocab(train_iter)\n",
"print(len(train_vocab))\n",
"\n",
"train_iter = torchtext.datasets.PennTreebank(split='train')\n",
"gt_unigram = get_ngram_count(train_iter, train_vocab)\n",
"\n",
"gt_nr = defaultdict(int)\n",
"for r in gt_unigram.values():\n",
" gt_nr[r] += 1\n",
"\n",
"gt_n_sum = sum([r*nr for r, nr in gt_nr.items()])\n",
"print(gt_n_sum)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "zd-ffsSZbE7s",
"outputId": "fb89d39b-0fcc-4a9f-cd36-53fa0626438e"
},
"execution_count": 16,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"9999\n",
"845453\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"max(list(gt_nr.keys()))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "18JryZY_c0lT",
"outputId": "d1f89634-25bb-4e93-f263-fd985e9af1ed"
},
"execution_count": 17,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"7475"
]
},
"metadata": {},
"execution_count": 17
}
]
},
{
"cell_type": "code",
"source": [
"print(f\"{0}\\t{1*gt_nr[1]/ gt_n_sum:.4f}\")\n",
"for i in range(1, 10):\n",
" print(f\"{i}\\t{(i+1)*gt_nr[i+1]/gt_nr[i]:.4f}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "iGM5CGNYcR6_",
"outputId": "1495560f-f4a2-4eb1-ad4d-bff2493b8a43"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0\t0.2050\n",
"1\t0.4020\n",
"2\t1.2747\n",
"3\t2.1785\n",
"4\t3.1855\n",
"5\t4.0148\n",
"6\t5.3132\n",
"7\t5.6534\n",
"8\t7.1004\n",
"9\t8.5488\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# N_r の r が大きくなると、 N_r もしくは N_r+1 がなくなる\n",
"\n",
"print(\"r, N_r, N_r+1\")\n",
"for r in range(1, 300):\n",
" if gt_nr.get(r) is not None:\n",
" continue\n",
" else:\n",
" print(f\"{r}, {gt_nr.get(r)}, {gt_nr.get(r+1)}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "u1-uX8GOgBPK",
"outputId": "27374e9e-6ad7-4062-8ea0-e5786005e2ea"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"r, N_r, N_r+1\n",
"125, None, 4\n",
"136, None, 2\n",
"163, None, 5\n",
"171, None, 3\n",
"185, None, 2\n",
"189, None, 2\n",
"200, None, 1\n",
"208, None, None\n",
"209, None, 4\n",
"212, None, 1\n",
"215, None, 3\n",
"217, None, 1\n",
"220, None, 2\n",
"227, None, 1\n",
"231, None, 1\n",
"236, None, 1\n",
"238, None, 1\n",
"240, None, None\n",
"241, None, None\n",
"242, None, 1\n",
"244, None, 1\n",
"248, None, 1\n",
"257, None, None\n",
"258, None, 2\n",
"260, None, None\n",
"261, None, 1\n",
"265, None, 2\n",
"270, None, None\n",
"271, None, None\n",
"272, None, 1\n",
"274, None, 2\n",
"280, None, None\n",
"281, None, None\n",
"282, None, 1\n",
"286, None, 1\n",
"288, None, None\n",
"289, None, None\n",
"290, None, None\n",
"291, None, 1\n",
"294, None, 2\n",
"296, None, None\n",
"297, None, 1\n"
]
}
]
}
]
}