nltk_lm_examples.ipynb
Created April 5, 2022 00:15
[Open In Colab](https://colab.research.google.com/gist/shihono/b40691d53435ca02d90e9ad12f0366fa/nltk_lm_examples.ipynb)

## Using NLTK's language model module

Upgrade NLTK to 3.7 (a fresh Google Colab environment ships with 3.2).

The data comes from torchtext's [PennTreebank](https://pytorch.org/text/stable/datasets.html#penntreebank).

```python
!pip list | grep "nltk"
```
```
nltk 3.2.5
```

```python
!pip install -U nltk
```
```
Successfully installed nltk-3.7 regex-2022.3.15
```

```python
import math
from collections import defaultdict
from itertools import islice
from pprint import pprint

import matplotlib.pyplot as plt

# PyTorch-related imports (not needed here)
# import torch
# import torch.nn as nn
# import torch.nn.functional as F

# import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import PennTreebank
```

### Preprocessing

Prepare the Vocabulary and NgramCounter that the model takes as input.

- [nltk.lm.vocabulary.Vocabulary](https://www.nltk.org/api/nltk.lm.vocabulary.html#nltk.lm.vocabulary.Vocabulary): the vocabulary dictionary
  - pass an iterable of words or a `collections.Counter` to the constructor
- [nltk.lm.counter.NgramCounter](https://www.nltk.org/api/nltk.lm.counter.html#nltk.lm.counter.NgramCounter): a count dictionary that handles multiple n-gram orders
  - pass `ngram_text` (`Iterable(Iterable(tuple(str)))`) to the constructor

Splitting sentences into words:

- n-grams are built with [nltk.util.everygrams](https://www.nltk.org/api/nltk.util.html#nltk.util.everygrams)
  - it yields `tuple(str)` n-grams, which can be passed straight to `NgramCounter`

```python
list(everygrams(['a', 'b', 'c']))
# [('a',), ('a', 'b'), ('a', 'b', 'c'), ('b',), ('b', 'c'), ('c',)]
```

- words are padded with `nltk.lm.preprocessing.pad_both_ends`
- there is also [nltk.lm.preprocessing.padded_everygrams](https://www.nltk.org/api/nltk.lm.preprocessing.html#nltk.lm.preprocessing.padded_everygrams), which performs padding and n-gram generation in one step

```python
from collections import Counter

import nltk
from nltk.util import ngrams, everygrams
from nltk.lm.vocabulary import Vocabulary
from nltk.lm.counter import NgramCounter
from nltk.lm.preprocessing import padded_everygrams, pad_both_ends

tokenizer = get_tokenizer(None)
ngram_order = 2

# sanity check on the first sentence
train_iter = PennTreebank(split='train')
for data in train_iter:
    words = list(pad_both_ends(tokenizer(data), n=ngram_order))
    ngrams = everygrams(words, max_len=ngram_order)
    print(words)
    print(list(ngrams))
    break
```
```
['<s>', 'aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter', '</s>']
[('<s>',), ('<s>', 'aer'), ('aer',), ('aer', 'banknote'), ('banknote',), ('banknote', 'berlitz'), ('berlitz',), ('berlitz', 'calloway'), ('calloway',), ('calloway', 'centrust'), ('centrust',), ('centrust', 'cluett'), ('cluett',), ('cluett', 'fromstein'), ('fromstein',), ('fromstein', 'gitano'), ('gitano',), ('gitano', 'guterman'), ('guterman',), ('guterman', 'hydro-quebec'), ('hydro-quebec',), ('hydro-quebec', 'ipo'), ('ipo',), ('ipo', 'kia'), ('kia',), ('kia', 'memotec'), ('memotec',), ('memotec', 'mlx'), ('mlx',), ('mlx', 'nahb'), ('nahb',), ('nahb', 'punts'), ('punts',), ('punts', 'rake'), ('rake',), ('rake', 'regatta'), ('regatta',), ('regatta', 'rubens'), ('rubens',), ('rubens', 'sim'), ('sim',), ('sim', 'snack-food'), ('snack-food',), ('snack-food', 'ssangyong'), ('ssangyong',), ('ssangyong', 'swapo'), ('swapo',), ('swapo', 'wachter'), ('wachter',), ('wachter', '</s>'), ('</s>',)]
```

```python
counts = Counter()
ngram_text = []
train_iter = PennTreebank(split='train')
for data in train_iter:
    words = list(pad_both_ends(tokenizer(data), n=ngram_order))
    counts.update(words)
    ngrams = everygrams(words, max_len=ngram_order)
    ngram_text.append(ngrams)

vocab = Vocabulary(counts=counts, unk_cutoff=1)
counter = NgramCounter(ngram_text)
print(counter.N())
```
```
1901246
```

```python
# words that follow "a"
sorted(counter[('a',)].items(), key=lambda x: x[-1])[::-1][:10]
```
```
[('<unk>', 2449),
 ('share', 1120),
 ('year', 652),
 ('N', 538),
 ('$', 421),
 ('new', 397),
 ('few', 187),
 ('major', 181),
 ('lot', 169),
 ('result', 143)]
```

## Model

https://www.nltk.org/api/nltk.lm.models.html#module-nltk.lm.models lists the implemented models.

These classes all inherit from [nltk.lm.api.LanguageModel](https://www.nltk.org/api/nltk.lm.api.html#nltk.lm.api.LanguageModel).

Constructor arguments:
- order: the n of the n-gram
- vocabulary: a Vocabulary
- counter: an NgramCounter

There is also a fit() method, but if you pass vocabulary and counter to the constructor you don't need to call it (use it when you have additional data).

```python
from nltk.lm import WittenBellInterpolated

lm = WittenBellInterpolated(ngram_order, vocabulary=vocab, counter=counter)
lm
```
```
<nltk.lm.models.WittenBellInterpolated at 0x7efd9d4e7cd0>
```

```python
lm.estimator.counts["<s>"]
```
```
42068
```

```python
from nltk.util import bigrams

sent = "this is a sentence"
sent_pad = list(bigrams(pad_both_ends(tokenizer(sent), n=ngram_order)))
print(sent_pad)
lm.entropy(sent_pad)
```
```
[('<s>', 'this'), ('this', 'is'), ('is', 'a'), ('a', 'sentence'), ('sentence', '</s>')]
7.405532740762131
```

### Entropy

Compute the entropy of each model (bigram).

- Evaluate on the test set `PennTreebank(split='test')`
  - the `text_ngrams` argument of lm.entropy() is an `Iterable(tuple(str))`, so the per-sentence bigrams are joined with chain

```python
from itertools import chain
from nltk.util import bigrams

def get_text_ngrams(split='test'):
    return chain(*[
        bigrams(pad_both_ends(tokenizer(sent), n=ngram_order))
        for sent in PennTreebank(split=split)
    ])

def cal_entropy(model, **kwargs):
    lm = model(order=ngram_order, vocabulary=vocab, counter=counter, **kwargs)
    return lm.entropy(get_text_ngrams())
```

```python
list(get_text_ngrams("train"))[:10]
```
```
[('<s>', 'aer'),
 ('aer', 'banknote'),
 ('banknote', 'berlitz'),
 ('berlitz', 'calloway'),
 ('calloway', 'centrust'),
 ('centrust', 'cluett'),
 ('cluett', 'fromstein'),
 ('fromstein', 'gitano'),
 ('gitano', 'guterman'),
 ('guterman', 'hydro-quebec')]
```

```python
from nltk.lm import WittenBellInterpolated

cal_entropy(WittenBellInterpolated)
```
```
7.662118839369174
```

```python
from nltk.lm import Lidstone, Laplace, StupidBackoff, AbsoluteDiscountingInterpolated, KneserNeyInterpolated

cal_entropy(Lidstone, gamma=0.5)
```
```
9.291099950667238
```

```python
for m in [Laplace, StupidBackoff, AbsoluteDiscountingInterpolated, KneserNeyInterpolated]:
    print(m)
    print(cal_entropy(m))
    print()
```
```
<class 'nltk.lm.models.Laplace'>
9.686756185613913

<class 'nltk.lm.models.StupidBackoff'>
7.416354915097871

<class 'nltk.lm.models.AbsoluteDiscountingInterpolated'>
7.62192640865039

<class 'nltk.lm.models.KneserNeyInterpolated'>
7.961439762038167
```
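As a side note, the manual padding/counting pipeline above can also be expressed with `nltk.lm.preprocessing.padded_everygram_pipeline`, which bundles `pad_both_ends`, `everygrams`, and the flattened word stream for the vocabulary into one call. A minimal self-contained sketch on a hypothetical two-sentence toy corpus (no PTB download needed; sentence data and expected counts here are illustrative, not from the notebook):

```python
from nltk.lm import WittenBellInterpolated
from nltk.lm.preprocessing import padded_everygram_pipeline

# Toy corpus: two pre-tokenized sentences
text = [["this", "is", "a", "sentence"],
        ["this", "is", "another", "sentence"]]

# padded_everygram_pipeline pads each sentence with <s>/</s> and yields
# all 1..2-grams, plus a flat padded word stream for building the vocabulary
train_data, padded_sents = padded_everygram_pipeline(2, text)

lm = WittenBellInterpolated(2)
lm.fit(train_data, padded_sents)

# The bigram "this is" occurs in both sentences
print(lm.counts[("this",)]["is"])            # 2
print(lm.score("is", ["this"]))              # smoothed P(is | this)
print(lm.perplexity([("this", "is"), ("is", "a")]))
```

A fitted model can also sample text with `lm.generate(num_words, text_seed=...)`, which is handy for eyeballing what the n-gram distribution has learned.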