@shihono
Created April 5, 2022
nltk_lm_examples.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/shihono/b40691d53435ca02d90e9ad12f0366fa/nltk_lm_examples.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qTf4YvkDMolR"
},
"source": [
"## NLTKの言語モデルモジュールを使う\n",
"\n",
"NLTKのバージョンを3.7にする (google colabの初期状態だと3.2)\n",
"\n",
"データはtorchtextの[PennTreebank](https://pytorch.org/text/stable/datasets.html#penntreebank) を使う"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6SSDGbaSM8Fr",
"outputId": "9a6f0af0-8e43-44e5-8778-9b47dd29748c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nltk 3.2.5\n"
]
}
],
"source": [
"!pip list |grep \"nltk\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0V2oAEfH1lnJ",
"outputId": "95bf2846-8c88-4725-e197-213cf66cb2da"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: nltk in /usr/local/lib/python3.7/dist-packages (3.2.5)\n",
"Collecting nltk\n",
" Downloading nltk-3.7-py3-none-any.whl (1.5 MB)\n",
"\u001b[K |████████████████████████████████| 1.5 MB 4.3 MB/s \n",
"\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from nltk) (4.63.0)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from nltk) (1.1.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.7/dist-packages (from nltk) (7.1.2)\n",
"Collecting regex>=2021.8.3\n",
" Downloading regex-2022.3.15-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (749 kB)\n",
"\u001b[K |████████████████████████████████| 749 kB 45.5 MB/s \n",
"\u001b[?25hInstalling collected packages: regex, nltk\n",
" Attempting uninstall: regex\n",
" Found existing installation: regex 2019.12.20\n",
" Uninstalling regex-2019.12.20:\n",
" Successfully uninstalled regex-2019.12.20\n",
" Attempting uninstall: nltk\n",
" Found existing installation: nltk 3.2.5\n",
" Uninstalling nltk-3.2.5:\n",
" Successfully uninstalled nltk-3.2.5\n",
"Successfully installed nltk-3.7 regex-2022.3.15\n"
]
}
],
"source": [
"!pip install -U nltk"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dh4uzxzEM3Jt"
},
"outputs": [],
"source": [
"import math\n",
"from collections import defaultdict\n",
"from itertools import islice\n",
"from pprint import pprint \n",
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# pytorch 関連\n",
"# import torch\n",
"# import torch.nn as nn\n",
"# import torch.nn.functional as F\n",
"\n",
"# import torchtext\n",
"from torchtext.data.utils import get_tokenizer\n",
"from torchtext.vocab import build_vocab_from_iterator\n",
"from torchtext.datasets import PennTreebank"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qTfs_eLhanVm"
},
"source": [
"### 前処理\n",
"\n",
"モデルへの入力のために必要なVocabularyとNgramCounterを準備\n",
"\n",
"- [nltk.lm.vocabulary.Vocabulary](https://www.nltk.org/api/nltk.lm.vocabulary.html#nltk.lm.vocabulary.Vocabulary): vocabulary辞書\n",
" - 初期値で word の iteratorもしくはcollections.Counter を渡す\n",
"- [nltk.lm.counter.NgramCounter](https://www.nltk.org/api/nltk.lm.counter.html#nltk.lm.counter.NgramCounter): 複数のngramに対応したカウント辞書\n",
" - 初期値で ngram_text (`Iterable(Iterable(tuple(str)))`) を渡す\n",
"\n",
"\n",
"sentence から word へ分割する際の処理\n",
"\n",
"- ngramは [nltk.util.everygrams](https://www.nltk.org/api/nltk.util.html#nltk.util.everygrams) で作成\n",
" - 返り値は `List[Tuple[str]]` -> `NgramCounter` に渡せる\n",
"\n",
"```python\n",
"list(everygrams(['a', 'b', 'c']))\n",
"[('a',), ('a', 'b'), ('a', 'b', 'c'), ('b',), ('b', 'c'), ('c',)]\n",
"```\n",
"\n",
"- 単語は nltk.lm.preprocessing.pad_both_ends で paddingする\n",
"- ngramとpaddingをまとめて実行する [nltk.lm.preprocessing.padded_everygrams](https://www.nltk.org/api/nltk.lm.preprocessing.html#nltk.lm.preprocessing.padded_everygrams) もある\n",
"\n"
]
},
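{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the padding step described above (the tokens are illustrative):\n",
"\n",
"```python\n",
"from nltk.lm.preprocessing import pad_both_ends\n",
"\n",
"# pad_both_ends adds n-1 boundary symbols on each side for ngram order n\n",
"list(pad_both_ends(['this', 'is', 'it'], n=2))\n",
"# ['<s>', 'this', 'is', 'it', '</s>']\n",
"```"
]
},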
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pg96Q_-_bE8A",
"outputId": "7a270d74-f3f2-4d2f-97ca-3548e17efa75"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"5.10MB [00:00, 40.0MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['<s>', 'aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano', 'guterman', 'hydro-quebec', 'ipo', 'kia', 'memotec', 'mlx', 'nahb', 'punts', 'rake', 'regatta', 'rubens', 'sim', 'snack-food', 'ssangyong', 'swapo', 'wachter', '</s>']\n",
"[('<s>',), ('<s>', 'aer'), ('aer',), ('aer', 'banknote'), ('banknote',), ('banknote', 'berlitz'), ('berlitz',), ('berlitz', 'calloway'), ('calloway',), ('calloway', 'centrust'), ('centrust',), ('centrust', 'cluett'), ('cluett',), ('cluett', 'fromstein'), ('fromstein',), ('fromstein', 'gitano'), ('gitano',), ('gitano', 'guterman'), ('guterman',), ('guterman', 'hydro-quebec'), ('hydro-quebec',), ('hydro-quebec', 'ipo'), ('ipo',), ('ipo', 'kia'), ('kia',), ('kia', 'memotec'), ('memotec',), ('memotec', 'mlx'), ('mlx',), ('mlx', 'nahb'), ('nahb',), ('nahb', 'punts'), ('punts',), ('punts', 'rake'), ('rake',), ('rake', 'regatta'), ('regatta',), ('regatta', 'rubens'), ('rubens',), ('rubens', 'sim'), ('sim',), ('sim', 'snack-food'), ('snack-food',), ('snack-food', 'ssangyong'), ('ssangyong',), ('ssangyong', 'swapo'), ('swapo',), ('swapo', 'wachter'), ('wachter',), ('wachter', '</s>'), ('</s>',)]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"from collections import Counter\n",
"\n",
"import nltk\n",
"from nltk.util import ngrams, everygrams\n",
"from nltk.lm.vocabulary import Vocabulary\n",
"from nltk.lm.counter import NgramCounter\n",
"from nltk.lm.preprocessing import padded_everygrams, pad_both_ends\n",
"\n",
"tokenizer = get_tokenizer(None)\n",
"ngram_order = 2\n",
"\n",
"# 動作確認\n",
"train_iter = PennTreebank(split='train')\n",
"for data in train_iter:\n",
" words = list(pad_both_ends(tokenizer(data),n=ngram_order))\n",
" ngrams = everygrams(words, max_len=ngram_order)\n",
" print(list(words))\n",
" print(list(ngrams))\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NXP8u4g4XGJI",
"outputId": "46ce64b6-f2cd-4a96-82e8-56bb7ed4338d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1901246\n"
]
}
],
"source": [
"counts = Counter()\n",
"ngram_text = []\n",
"train_iter = PennTreebank(split='train')\n",
"for data in train_iter:\n",
" words = list(pad_both_ends(tokenizer(data),n=ngram_order))\n",
" counts.update(words)\n",
" ngrams = everygrams(words, max_len=ngram_order)\n",
" ngram_text.append(ngrams)\n",
" \n",
"vocab = Vocabulary(counts=counts, unk_cutoff=1)\n",
"counter = NgramCounter(ngram_text)\n",
"print(counter.N())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7sEOzlm1swCu",
"outputId": "8934e06a-bf3a-424d-aaf8-1294f045bc14"
},
"outputs": [
{
"data": {
"text/plain": [
"[('<unk>', 2449),\n",
" ('share', 1120),\n",
" ('year', 652),\n",
" ('N', 538),\n",
" ('$', 421),\n",
" ('new', 397),\n",
" ('few', 187),\n",
" ('major', 181),\n",
" ('lot', 169),\n",
" ('result', 143)]"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# \"a\" に続くword\n",
"sorted(counter[('a',)].items(), key=lambda x:x[-1])[::-1][:10]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k9vPMfF-WArm"
},
"source": [
"## Model\n",
"\n",
"https://www.nltk.org/api/nltk.lm.models.html#module-nltk.lm.models に実装されたモデル一覧がある。\n",
"\n",
"これらのクラスは [nltk.lm.api.LanguageModel](https://www.nltk.org/api/nltk.lm.api.html#nltk.lm.api.LanguageModel) を継承している\n",
"\n",
"初期値\n",
"- order: ngramのnを指定する\n",
"- vocabulary: Vocabulary\n",
"- counter: NgramCounter\n",
"\n",
"fit関数があるが、初期値で vocabulary と counts 渡していれば実行の必要はない (追加データがあれば使う)"
]
},
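{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the constructor/fit trade-off: without a prebuilt Vocabulary and NgramCounter, fit() builds both from the data (MLE and the toy sentence are illustrative):\n",
"\n",
"```python\n",
"from nltk.lm import MLE\n",
"from nltk.lm.preprocessing import padded_everygram_pipeline\n",
"\n",
"train, vocab_iter = padded_everygram_pipeline(2, [['this', 'is', 'it']])\n",
"lm = MLE(2)\n",
"lm.fit(train, vocab_iter)  # builds lm.vocab and lm.counts\n",
"lm.score('is', ['this'])   # P(is | this)\n",
"```"
]
},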
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UtilXoAcbYtt",
"outputId": "edfa79b6-6725-41d2-d224-90de47f85f28"
},
"outputs": [
{
"data": {
"text/plain": [
"<nltk.lm.models.WittenBellInterpolated at 0x7efd9d4e7cd0>"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from nltk.lm import WittenBellInterpolated\n",
"\n",
"lm = WittenBellInterpolated(ngram_order, vocabulary=vocab, counter=counter)\n",
"lm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gdWnUXOQc8hb",
"outputId": "c4a64972-77ea-484e-a0a8-49d2e0caa0db"
},
"outputs": [
{
"data": {
"text/plain": [
"42068"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lm.estimator.counts[\"<s>\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3tAGx4IWb73s",
"outputId": "8902266b-d380-49d9-ac94-924cf2ca3fa2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[('<s>', 'this'), ('this', 'is'), ('is', 'a'), ('a', 'sentence'), ('sentence', '</s>')]\n"
]
},
{
"data": {
"text/plain": [
"7.405532740762131"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from nltk.util import bigrams\n",
"\n",
"sent = \"this is a sentence\"\n",
"sent_pad = list(bigrams(pad_both_ends(tokenizer(sent), n=ngram_order)))\n",
"print(sent_pad)\n",
"lm.entropy(sent_pad)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "k7dT-PuDtqh6"
},
"source": [
"### エントロピー\n",
"\n",
"それぞれのモデルのエントロピーを求める (bigram)\n",
"\n",
"- テストセット `PennTreebank(split='test')` で評価\n",
" - lm.entropy() の引数は `text_ngrams` は (Iterable(tuple(str))) なので chainでつなげる"
]
},
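{
"cell_type": "markdown",
"metadata": {},
"source": [
"lm.entropy is the average negative log2 probability over the ngrams, so the evaluation here is equivalent to this sketch (function and variable names are illustrative):\n",
"\n",
"```python\n",
"def entropy(lm, text_ngrams):\n",
"    # mean of -log2 P(word | context) over all ngrams\n",
"    scores = [lm.logscore(ngram[-1], ngram[:-1]) for ngram in text_ngrams]\n",
"    return -1 * sum(scores) / len(scores)\n",
"```"
]
},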
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NIJTSKMXt7hM"
},
"outputs": [],
"source": [
"from itertools import chain\n",
"from nltk.util import bigrams\n",
"\n",
"def get_text_ngrams(split='test'):\n",
" return chain(*[\n",
" bigrams(pad_both_ends(tokenizer(sent), n=ngram_order))\n",
" for sent in PennTreebank(split=split)\n",
" ])\n",
"\n",
"def cal_entropy(model, **kwargs):\n",
" lm = model(order=ngram_order, vocabulary=vocab, counter=counter, **kwargs)\n",
" return lm.entropy(get_text_ngrams())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pDKr7mqlwbBc",
"outputId": "843028fc-7e15-462c-a636-939806b9a38b"
},
"outputs": [
{
"data": {
"text/plain": [
"[('<s>', 'aer'),\n",
" ('aer', 'banknote'),\n",
" ('banknote', 'berlitz'),\n",
" ('berlitz', 'calloway'),\n",
" ('calloway', 'centrust'),\n",
" ('centrust', 'cluett'),\n",
" ('cluett', 'fromstein'),\n",
" ('fromstein', 'gitano'),\n",
" ('gitano', 'guterman'),\n",
" ('guterman', 'hydro-quebec')]"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(get_text_ngrams(\"train\"))[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0UmbySwFwN3B",
"outputId": "aa6c9679-14a7-4788-d1eb-1d0243e8082a"
},
"outputs": [
{
"data": {
"text/plain": [
"7.662118839369174"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from nltk.lm import WittenBellInterpolated\n",
"\n",
"cal_entropy(WittenBellInterpolated)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "KqpYBOo4xLkJ",
"outputId": "bc78c1f2-4274-4567-d330-d1e777a19593"
},
"outputs": [
{
"data": {
"text/plain": [
"9.291099950667238"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from nltk.lm import Lidstone, Laplace, StupidBackoff, AbsoluteDiscountingInterpolated, KneserNeyInterpolated\n",
"\n",
"cal_entropy(Lidstone, gamma=0.5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true,
"base_uri": "https://localhost:8080/"
},
"id": "p3I2PeoVzJoW",
"outputId": "6707c442-b725-449c-b039-fed4ba5c0e74"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'nltk.lm.models.Laplace'>\n",
"9.686756185613913\n",
"\n",
"<class 'nltk.lm.models.StupidBackoff'>\n",
"7.416354915097871\n",
"\n",
"<class 'nltk.lm.models.AbsoluteDiscountingInterpolated'>\n",
"7.62192640865039\n",
"\n",
"<class 'nltk.lm.models.KneserNeyInterpolated'>\n",
"7.961439762038167\n",
"\n"
]
}
],
"source": [
"for m in [Laplace, StupidBackoff, AbsoluteDiscountingInterpolated, KneserNeyInterpolated]:\n",
" print(m)\n",
" print(cal_entropy(m))\n",
" print()"
]
}
],
"metadata": {
"colab": {
"name": "nltk_lm_examples.ipynb",
"provenance": [],
"authorship_tag": "ABX9TyPX2JAyVHh4Ijj3LZIYmTsb",
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}