@shihono
Created February 14, 2022 12:17
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "torchtext_dataset.ipynb",
"provenance": [],
"toc_visible": true,
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip list |grep \"torch\""
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "b7Cjvb8FLHc6",
"outputId": "726d3b15-14f4-4824-e69a-d884744264f7"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"torch 1.10.0+cu111\n",
"torchaudio 0.10.0+cu111\n",
"torchsummary 1.5.1\n",
"torchtext 0.11.0\n",
"torchvision 0.11.1+cu111\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "WSBtea6YUk2T"
},
"outputs": [],
"source": [
"import math\n",
"from collections import defaultdict\n",
"from itertools import islice\n",
"from pprint import pprint \n",
"\n",
"# pytorch 関連\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"import torchtext\n",
"from torchtext.data.utils import get_tokenizer\n",
"from torchtext.vocab import build_vocab_from_iterator"
]
},
{
"cell_type": "markdown",
"source": [
"## 前処理: tokenizer と vocab 設定\n",
"\n",
"データセットは\n",
"[torchtext.datasets](https://torchtext.readthedocs.io/en/latest/datasets.html) から取得。\n",
"- iteration するとテキストが返却される\n",
"- 参考 https://pytorch.org/tutorials/beginner/transformer_tutorial.html#load-and-batch-data\n",
"\n",
"前処理\n",
"\n",
"- `get_tokenizer` -> tokenizer設定\n",
"- `vocab` -> Vocab クラス `self.vocab` に辞書"
],
"metadata": {
"id": "nsAPoDzLGjmp"
}
},
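{
"cell_type": "markdown",
"source": [
"Before running this on WikiText-2, here is a minimal toy sketch of the same `get_tokenizer` + `build_vocab_from_iterator` flow on a couple of hardcoded sentences (the sentences and the `toy_*` names are made up purely for illustration)."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# toy example (illustrative only): tokenize, build a Vocab, map unknown words to <unk>\n",
"toy_corpus = [\"The cat sat on the mat .\", \"The dog sat .\"]\n",
"toy_tokenizer = get_tokenizer('basic_english')\n",
"toy_vocab = build_vocab_from_iterator(map(toy_tokenizer, toy_corpus), specials=['<unk>'])\n",
"toy_vocab.set_default_index(toy_vocab['<unk>'])\n",
"\n",
"# index -> token list, and an unseen word ('flew') falling back to <unk>\n",
"toy_vocab.get_itos(), toy_vocab(toy_tokenizer(\"the cat flew .\"))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},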
{
"cell_type": "code",
"source": [
"# train data 確認\n",
"\n",
"train_iter = torchtext.datasets.WikiText2(split='train')\n",
"\n",
"for text in islice(train_iter, 10):\n",
" pprint(text)"
],
"metadata": {
"id": "-ftkbMTzOqm7",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "5df34b32-f40f-48fe-f0b3-ee1aeb8c4624"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"' \\n'\n",
"' = Valkyria Chronicles III = \\n'\n",
"' \\n'\n",
"(' Senjō no Valkyria 3 : <unk> Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . '\n",
" 'Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria '\n",
" 'Chronicles III outside Japan , is a tactical role @-@ playing video game '\n",
" 'developed by Sega and Media.Vision for the PlayStation Portable . Released '\n",
" 'in January 2011 in Japan , it is the third game in the Valkyria series . '\n",
" '<unk> the same fusion of tactical and real @-@ time gameplay as its '\n",
" 'predecessors , the story runs parallel to the first game and follows the \" '\n",
" 'Nameless \" , a penal military unit serving the nation of Gallia during the '\n",
" 'Second Europan War who perform secret black operations and are pitted '\n",
" 'against the Imperial unit \" <unk> Raven \" . \\n')\n",
"(' The game began development in 2010 , carrying over a large portion of the '\n",
" 'work done on Valkyria Chronicles II . While it retained the standard '\n",
" 'features of the series , it also underwent multiple adjustments , such as '\n",
" 'making the game more <unk> for series newcomers . Character designer <unk> '\n",
" 'Honjou and composer Hitoshi Sakimoto both returned from previous entries , '\n",
" 'along with Valkyria Chronicles II director Takeshi Ozawa . A large team of '\n",
" \"writers handled the script . The game 's opening theme was sung by May 'n \"\n",
" '. \\n')\n",
"(' It met with positive sales in Japan , and was praised by both Japanese and '\n",
" 'western critics . After release , it received downloadable content , along '\n",
" 'with an expanded edition in November of that year . It was also adapted into '\n",
" 'manga and an original video animation series . Due to low sales of Valkyria '\n",
" 'Chronicles II , Valkyria Chronicles III was not localized , but a fan '\n",
" \"translation compatible with the game 's expanded edition was released in \"\n",
" '2014 . Media.Vision would return to the franchise with the development of '\n",
" 'Valkyria : Azure Revolution for the PlayStation 4 . \\n')\n",
"' \\n'\n",
"' = = Gameplay = = \\n'\n",
"' \\n'\n",
"(' As with previous <unk> Chronicles games , Valkyria Chronicles III is a '\n",
" 'tactical role @-@ playing game where players take control of a military unit '\n",
" 'and take part in missions against enemy forces . Stories are told through '\n",
" 'comic book @-@ like panels with animated character portraits , with '\n",
" 'characters speaking partially through voiced speech bubbles and partially '\n",
" 'through <unk> text . The player progresses through a series of linear '\n",
" 'missions , gradually unlocked as maps that can be freely <unk> through and '\n",
" 'replayed as they are unlocked . The route to each story location on the map '\n",
" \"varies depending on an individual player 's approach : when one option is \"\n",
" 'selected , the other is sealed off to the player . Outside missions , the '\n",
" 'player characters rest in a camp , where units can be customized and '\n",
" 'character growth occurs . Alongside the main story missions are character '\n",
" '@-@ specific sub missions relating to different squad members . After the '\n",
" \"game 's completion , additional episodes are unlocked , some of them having \"\n",
" 'a higher difficulty than those found in the rest of the game . There are '\n",
" \"also love simulation elements related to the game 's two main <unk> , \"\n",
" 'although they take a very minor role . \\n')\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"tokenizer = get_tokenizer('basic_english')"
],
"metadata": {
"id": "QGnmWWmGVdP1"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"tokenizer(\"You can now install TorchText using pip!\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SZAxuTGIWCBy",
"outputId": "7f6975c3-2212-4279-9d85-ee2b6b06f001"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['you', 'can', 'now', 'install', 'torchtext', 'using', 'pip', '!']"
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"source": [
"train_iter = torchtext.datasets.WikiText2(split='train')\n",
"vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])\n",
"# 未知語は全部 `<unk>` 扱いにする\n",
"vocab.set_default_index(vocab['<unk>'])\n",
"\n",
"vocab.vocab[\"the\"], vocab.vocab[\"<unk>\"], vocab.vocab[\"<unk\"]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Mb9TBV69Vr36",
"outputId": "fdd17675-e69b-4de6-cd5c-ccb264bf2f34"
},
"execution_count": 6,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1, 0, 0)"
]
},
"metadata": {},
"execution_count": 6
}
]
},
{
"cell_type": "markdown",
"source": [
"## 前処理: word to index\n",
"\n",
"`vocab` を使って word を 数値に変換\n",
"\n",
"https://pytorch.org/tutorials/beginner/text_sentiment_ngrams_tutorial.html"
],
"metadata": {
"id": "di1LfvNxjGjv"
}
},
{
"cell_type": "code",
"source": [
"def text2index(text):\n",
" return vocab(tokenizer(text))"
],
"metadata": {
"id": "pvXlEBWxjPOj"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"source": [
"text2index(' = Valkyria Chronicles III = ')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "MyJL1NvckQR5",
"outputId": "81cd8347-b75c-46a8-9f25-033e992baf94"
},
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[9, 3849, 3869, 881, 9]"
]
},
"metadata": {},
"execution_count": 8
}
]
},
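{
"cell_type": "markdown",
"source": [
"As a quick sanity check (added for illustration), `Vocab.lookup_tokens` maps the indices back to tokens, i.e. the inverse of `text2index` up to the lowercasing done by `basic_english`."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# indices back to tokens\n",
"vocab.lookup_tokens(text2index(' = Valkyria Chronicles III = '))"
],
"metadata": {},
"execution_count": null,
"outputs": []
},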
{
"cell_type": "code",
"source": [
"v_indexes = vocab.vocab.get_stoi().values()\n",
"max(v_indexes), min(v_indexes)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9BtfzBl2k6oO",
"outputId": "b7b04843-6efb-432b-e866-0451d7084e02"
},
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(28781, 0)"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "markdown",
"source": [
"# n-gram データの作成"
],
"metadata": {
"id": "Gn4hz0WZbY9h"
}
},
{
"cell_type": "code",
"source": [
"def iter_ngram(items, n=2):\n",
" for idx in range(len(items)-n+1):\n",
" yield items[idx:idx+n]\n",
"\n",
"list(iter_ngram([9, 3849, 3869, 881, 9], n=3)), list(iter_ngram([9, 3849, 3869, 881, 9], n=1))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HTSyVabGHbFO",
"outputId": "d315e46b-11ce-4101-9c64-fab8da55194a"
},
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"([[9, 3849, 3869], [3849, 3869, 881], [3869, 881, 9]],\n",
" [[9], [3849], [3869], [881], [9]])"
]
},
"metadata": {},
"execution_count": 10
}
]
},
{
"cell_type": "code",
"source": [
"train_iter = torchtext.datasets.WikiText2(split='train')\n",
"\n",
"# unigram と bigram を取得する\n",
"unigram_counts = defaultdict(int)\n",
"bigram_counts = defaultdict(int)\n",
"\n",
"for t in train_iter:\n",
" t = t.strip()\n",
" if len(t) < 1:\n",
" continue\n",
" tokens = text2index(t)\n",
" for t in iter_ngram(tokens, 1):\n",
" unigram_counts[t[0]] +=1\n",
"\n",
" for t in iter_ngram(tokens, 2):\n",
" bigram_counts[tuple(t)] +=1\n",
"\n",
"\n",
"unigram_counts = dict(unigram_counts)\n",
"bigram_counts = dict(bigram_counts)\n",
"\n",
"len(unigram_counts), len(bigram_counts)"
],
"metadata": {
"id": "OrXCScQfbl66",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2277af4d-fbcc-4510-e3c2-4f2716d19c34"
},
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(28782, 577049)"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"source": [
"list(islice(unigram_counts.items(),10))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5Gy6cVBPKBfJ",
"outputId": "ba0102b7-fc58-4787-c45d-b9c2326d0f15"
},
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[(9, 29570),\n",
" (3849, 54),\n",
" (3869, 53),\n",
" (881, 231),\n",
" (20000, 5),\n",
" (83, 1702),\n",
" (88, 1601),\n",
" (0, 54625),\n",
" (21, 11992),\n",
" (780, 255)]"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "code",
"source": [
"# bigram 確認\n",
"for ngram, count in islice(bigram_counts.items(), 20):\n",
" # prob = count / unigram_counts[ngram[0]]\n",
" word1= vocab.lookup_token(ngram[0])\n",
" word2= vocab.lookup_token(ngram[1])\n",
"\n",
" print(ngram, word1, word2, count)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TQ4lkfqzpNpU",
"outputId": "dbb1bdf0-d9cd-41d3-f3f2-0a8a0fd28677"
},
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"(9, 3849) = valkyria 1\n",
"(3849, 3869) valkyria chronicles 36\n",
"(3869, 881) chronicles iii 15\n",
"(881, 9) iii = 4\n",
"(20000, 83) senjō no 5\n",
"(83, 3849) no valkyria 5\n",
"(3849, 88) valkyria 3 5\n",
"(88, 0) 3 <unk> 21\n",
"(0, 3869) <unk> chronicles 3\n",
"(3869, 21) chronicles ( 3\n",
"(21, 780) ( japanese 11\n",
"(780, 28780) japanese 戦場のヴァルキュリア3 1\n",
"(28780, 2) 戦場のヴァルキュリア3 , 1\n",
"(2, 6182) , lit 10\n",
"(6182, 3) lit . 15\n",
"(3, 3849) . valkyria 4\n",
"(3849, 4) valkyria of 4\n",
"(4, 1) of the 17322\n",
"(1, 5023) the battlefield 31\n",
"(5023, 88) battlefield 3 4\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# bi-gram model\n",
"\n",
"エントロピーを求める。\n",
"\n",
"[An Empirical Study of Smoothing Techniques for Language Modeling](https://arxiv.org/abs/cmp-lg/9606011) のクロスエントロピーの定義に従う。\n",
"\n",
"対数尤度を単語数で割った値\n",
"\n",
"$$ H = \\frac{1}{N_T} \\sum_{i=1}^{l_T} - \\log_2 P_{model}(t_i) $$\n",
"\n",
"ただし\n",
"- テストデータ $ T \\in (t_0, t_1, ..., t_{l_T}) $\n",
"- $ P_{model}(t_i) $ データ `t_i` に対する model の出力\n",
"- テストデータの合計単語数 $ N_T $\n"
],
"metadata": {
"id": "NYKvIxwBg6XV"
}
},
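{
"cell_type": "markdown",
"source": [
"For the bigram model used below, the probability of a sentence $ t_i = (w_1, ..., w_n) $ factorizes into conditional bigram probabilities,\n",
"\n",
"$$ P_{model}(t_i) = \\prod_{k=2}^{n} P(w_k | w_{k-1}) $$\n",
"\n",
"so $ -\\log_2 P_{model}(t_i) $ is a sum of per-bigram terms $ -\\log_2 P(w_k | w_{k-1}) $, which is how `get_entropy` below accumulates $ H $."
],
"metadata": {}
},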
{
"cell_type": "code",
"source": [
"def get_probability(w1, w2, param=0):\n",
" \"\"\" P(w_2 | w_1) を求める\n",
" C(w1, w2) + param / C(w1) + param\n",
" \"\"\"\n",
" return float(bigram_counts.get((w1,w2), 0) + param) / (unigram_counts.get(w1,0) + param)"
],
"metadata": {
"id": "E5bVDwtM4DzK"
},
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"source": [
"text = \"here we are\"\n",
"tokens = text2index(text)\n",
"\n",
"for idx in range(len(tokens)-1):\n",
" print(idx, get_probability(tokens[idx], tokens[idx+1],))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fuVkuXACg_b_",
"outputId": "233894dc-f657-48fc-f0b7-8ecfa332185f"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"0 0.006060606060606061\n",
"1 0.07332293291731669\n"
]
}
]
},
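{
"cell_type": "markdown",
"source": [
"Without smoothing, any word pair that never occurs in the training data has probability 0. The check below (added for illustration) searches for such a pair; this is exactly why the plain `get_entropy()` call further down fails."
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# find a word pair that never occurs as a bigram in the training data;\n",
"# its unsmoothed probability is 0, so -log2(p) would raise an error\n",
"unseen = next(\n",
"    (w1, w2)\n",
"    for w1 in unigram_counts\n",
"    for w2 in unigram_counts\n",
"    if (w1, w2) not in bigram_counts\n",
")\n",
"unseen, get_probability(*unseen)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},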
{
"cell_type": "code",
"source": [
"def get_entropy(lambda_param=0):\n",
" # 単語数 N_T\n",
" word_sum = 0\n",
" # 対数尤度\n",
" h = 0\n",
" test_iter = torchtext.datasets.WikiText2(split='test')\n",
" for t in test_iter:\n",
" t = t.strip()\n",
" if len(t) < 1:\n",
" continue\n",
" tokens = text2index(t)\n",
" word_sum += len(tokens)\n",
" for i in range(1, len(tokens)-1):\n",
" p = get_probability(\n",
" tokens[i-1], tokens[i], param=lambda_param\n",
" )\n",
" h += -math.log2(p)\n",
"\n",
" print(f\"word_sum: {word_sum}\\nentropy: {h/word_sum}\")"
],
"metadata": {
"id": "WGrkQ40I_gTb"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# 未知語があると p が0になるためエラーになる\n",
"get_entropy()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 309
},
"id": "Q00H_fFq_ua1",
"outputId": "6c884aa7-02a0-45a6-f8ba-56fe1121ff79"
},
"execution_count": 17,
"outputs": [
{
"output_type": "error",
"ename": "ValueError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-17-67aaa4b80aa8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# 未知語があると p が0になるためエラーになる\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mget_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-16-7eef11121f72>\u001b[0m in \u001b[0;36mget_entropy\u001b[0;34m(lambda_param)\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0mtokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokens\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparam\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlambda_param\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 16\u001b[0m )\n\u001b[0;32m---> 17\u001b[0;31m \u001b[0mh\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlog2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 18\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"word_sum: {word_sum}\\nentropy: {h/word_sum}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: math domain error"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### add-one\n",
"\n",
"$$ P(w_{i} |w_{i-1})= \\frac{C(w_{i-1},w_{i})+ \\lambda }{C(w_{i-1})+ \\lambda }$$\n",
"\n",
"において、 $ \\lambda = 1 $"
],
"metadata": {
"id": "VgK1iVWc4W6v"
}
},
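{
"cell_type": "markdown",
"source": [
"With $ \\lambda = 1 $ the unseen pair found above gets a small but non-zero probability, while a frequent bigram such as \"of the\" barely changes (again just an illustrative check):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# add-one smoothing: the unseen bigram is no longer zero; compare a frequent bigram with and without smoothing\n",
"get_probability(*unseen, param=1), get_probability(vocab['of'], vocab['the']), get_probability(vocab['of'], vocab['the'], param=1)"
],
"metadata": {},
"execution_count": null,
"outputs": []
},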
{
"cell_type": "code",
"source": [
"get_entropy(1)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "axVzrbNXlnJo",
"outputId": "5a3ce2c6-2a97-43ea-c46d-5eadd4fa3f04"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"word_sum: 241859\n",
"entropy: 6.399829746902186\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"### ELE\n",
"\n",
"$ \\lambda = 1/2 $"
],
"metadata": {
"id": "tSoOYat_41vj"
}
},
{
"cell_type": "code",
"source": [
"get_entropy(0.5)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6Q0Vq_BZT01G",
"outputId": "76b66d9f-bd85-49de-d49c-e834e65240f6"
},
"execution_count": 19,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"word_sum: 241859\n",
"entropy: 6.646151077249645\n"
]
}
]
}
]
}