Last active
November 27, 2019 03:40
Runs a BERT model pre-trained on a news corpus. Thanks to Mr. Morinaga of Stockmark Inc., who published the model.
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "# Running the pre-trained BERT model (MeCab-based) trained on a large Japanese business news corpus, from [大規模日本語ビジネスニュースコーパスを学習したBERT事前学習済(MeCab利用)モデルの紹介](https://qiita.com/mkt3/items/3c1278339ff1bcc0187f)"
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T00:54:45.075324Z", | |
"start_time": "2019-04-11T00:54:45.069771Z" | |
} | |
}, | |
   "source": [
    "### Preliminaries\n",
    "Download the PyTorch version of the model from the [download link](https://drive.google.com/open?id=1iDlmhGgJ54rkVBtZvgMlgbuNwtFQ50V-).\n",
    "\n",
    "Create a `tar.gz` archive that contains the following three files at its top level:\n",
    "- bert_config.json\n",
    "- pytorch_model.bin\n",
    "- vocab.txt\n",
    "\n",
    "`vocab.txt` is also needed on its own, outside the archive.\n",
    "\n",
    "Install the module that provides a PyTorch interface to BERT:\n",
    "`pip install pytorch-pretrained-bert`"
   ]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:52.274037Z", | |
"start_time": "2019-04-11T01:15:51.727863Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n" | |
] | |
} | |
], | |
"source": [ | |
"import collections\n", | |
"import logging\n", | |
"import os\n", | |
"\n", | |
"import torch\n", | |
"import MeCab\n", | |
"from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, WordpieceTokenizer\n", | |
"from pytorch_pretrained_bert.tokenization import load_vocab" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:52.891854Z", | |
"start_time": "2019-04-11T01:15:52.888981Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"logging.basicConfig(level=logging.INFO)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-09T06:29:11.357348Z", | |
"start_time": "2019-04-09T06:29:11.354945Z" | |
} | |
}, | |
   "source": [
    "Prepare the dictionaries.\n",
    "\n",
    "`DICDIR` should point to the MeCab IPADic NEologd dictionary.\n",
    "\n",
    "For `USERDIC`, create a CSV file with the entries below and compile it into a user dictionary.\n",
    "This is needed because each special token used by BERT must be treated as a single word.\n",
    "\n",
    "```\n",
    "[UNK],1285,1285,100,名詞,記号,*,*,*,*,[UNK],[UNK],[UNK]\n",
    "[SEP],1285,1285,100,名詞,記号,*,*,*,*,[SEP],[SEP],[SEP]\n",
    "[PAD],1285,1285,100,名詞,記号,*,*,*,*,[PAD],[PAD],[PAD]\n",
    "[CLS],1285,1285,100,名詞,記号,*,*,*,*,[CLS],[CLS],[CLS]\n",
    "[MASK],1285,1285,100,名詞,記号,*,*,*,*,[MASK],[MASK],[MASK]\n",
    "```\n",
    "\n",
    "Compile it along these lines: `/usr/local/libexec/mecab/mecab-dict-index -d /path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308 -u user.dic -f utf-8 -t utf-8 your_dictionary.csv`"
   ]
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
   "source": [
    "We use [PyTorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT), which provides a PyTorch interface to BERT.\n",
    "A few modifications are needed to use the BERT model pre-trained on Japanese business news."
   ]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:54.285906Z", | |
"start_time": "2019-04-11T01:15:54.278355Z" | |
}, | |
"code_folding": [ | |
4 | |
], | |
"collapsed": true | |
}, | |
"outputs": [], | |
   "source": [
    "# Prepare a Japanese tokenizer that BertTokenizer can use\n",
    "DICDIR = \"/path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308\"\n",
    "USERDIC = \"/path/to/user.dic\"\n",
    "\n",
    "class MeCabBert:\n",
    "    def __init__(self, dicdir=DICDIR, userdic=USERDIC):\n",
    "        self.tagger = MeCab.Tagger(f\"-d {dicdir} -b 100000 -Owakati -u {userdic}\")\n",
    "        # Parse an empty string once to work around a known mecab-python3 issue\n",
    "        self.tagger.parse(\"\")\n",
    "\n",
    "    def tokenize(self, text):\n",
    "        # -Owakati output is a space-separated list of surface forms\n",
    "        return self.tagger.parse(text).split()"
   ]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:54.345452Z", | |
"start_time": "2019-04-11T01:15:54.309959Z" | |
}, | |
"code_folding": [ | |
1 | |
], | |
"collapsed": true | |
}, | |
"outputs": [], | |
   "source": [
    "# Subclass BertTokenizer so that basic tokenization is done by MeCabBert\n",
    "class BertMeCabTokenizer(BertTokenizer):\n",
    "    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,\n",
    "                 never_split=(\"[UNK]\", \"[SEP]\", \"[PAD]\", \"[CLS]\", \"[MASK]\")):\n",
    "        \"\"\"Constructs a BertTokenizer.\n",
    "        Args:\n",
    "            vocab_file: Path to a one-wordpiece-per-line vocabulary file\n",
    "            do_lower_case: Whether to lower case the input\n",
    "                Only has an effect when do_wordpiece_only=False\n",
    "            do_basic_tokenize: Whether to do basic tokenization before wordpiece.\n",
    "            max_len: An artificial maximum length to truncate tokenized sequences to;\n",
    "                Effective maximum length is always the minimum of this\n",
    "                value (if specified) and the underlying BERT model's\n",
    "                sequence length.\n",
    "            never_split: List of tokens which will never be split during tokenization.\n",
    "                Only has an effect when do_wordpiece_only=False\n",
    "        \"\"\"\n",
    "        if not os.path.isfile(vocab_file):\n",
    "            raise ValueError(\n",
    "                \"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained \"\n",
    "                \"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`\".format(vocab_file))\n",
    "        self.vocab = load_vocab(vocab_file)\n",
    "        self.ids_to_tokens = collections.OrderedDict(\n",
    "            [(ids, tok) for tok, ids in self.vocab.items()])\n",
    "        self.do_basic_tokenize = do_basic_tokenize\n",
    "        if do_basic_tokenize:\n",
    "            # Use MeCabBert instead of the default BasicTokenizer\n",
    "            self.basic_tokenizer = MeCabBert()\n",
    "        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)\n",
    "        self.max_len = max_len if max_len is not None else int(1e12)\n",
    "\n",
    "    def convert_tokens_to_ids(self, tokens):\n",
    "        \"\"\"Converts a sequence of tokens into ids using the vocab.\"\"\"\n",
    "        ids = []\n",
    "        for token in tokens:\n",
    "            # No SentencePiece/subword fallback is used, so unknown words map to 1, the id of [UNK]\n",
    "            ids.append(self.vocab.get(token, 1))\n",
    "        if len(ids) > self.max_len:\n",
    "            logging.warning(\n",
    "                \"Token indices sequence length is longer than the specified maximum \"\n",
    "                \"sequence length for this BERT model ({} > {}). Running this \"\n",
    "                \"sequence through BERT will result in indexing errors\".format(len(ids), self.max_len)\n",
    "            )\n",
    "        return ids"
   ]
}, | |
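  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick self-contained sanity check of the `convert_tokens_to_ids` fallback above, using a toy vocabulary (the real ids come from `vocab.txt`; the values below are made up for illustration): any token missing from the vocabulary maps to 1, the id of `[UNK]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy vocabulary standing in for the real vocab.txt (ids here are illustrative only)\n",
    "toy_vocab = {\"[PAD]\": 0, \"[UNK]\": 1, \"[CLS]\": 2, \"[SEP]\": 3, \"日本\": 10, \"銀行\": 11}\n",
    "\n",
    "def toy_convert_tokens_to_ids(tokens, vocab=toy_vocab):\n",
    "    # Same fallback as in the class above: unknown words get the [UNK] id, 1\n",
    "    return [vocab.get(token, 1) for token in tokens]\n",
    "\n",
    "# \"未知語\" (an unknown word) is not in the toy vocabulary, so it maps to 1\n",
    "assert toy_convert_tokens_to_ids([\"[CLS]\", \"日本\", \"未知語\", \"[SEP]\"]) == [2, 10, 1, 3]"
   ]
  },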
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Tokenizer" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.292660Z", | |
"start_time": "2019-04-11T01:15:55.214018Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file ./PyTorchVer/vocab.txt\n" | |
] | |
} | |
], | |
   "source": [
    "# Point the tokenizer at vocab.txt\n",
    "tokenizer = BertMeCabTokenizer.from_pretrained('./PyTorchVer/vocab.txt')"
   ]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.300598Z", | |
"start_time": "2019-04-11T01:15:55.294956Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['[CLS]', '日本', 'で', '有名', 'な', '銀行', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']\n" | |
] | |
} | |
], | |
"source": [ | |
"text = \"[CLS] 日本で有名な銀行はどこですか? [SEP] それはみずほ銀行です。 [SEP]\"\n", | |
"tokenized_text = tokenizer.tokenize(text)\n", | |
"print(tokenized_text)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.493018Z", | |
"start_time": "2019-04-11T01:15:55.488225Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"['[CLS]', '日本', 'で', '有名', 'な', '[MASK]', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']\n" | |
] | |
} | |
], | |
"source": [ | |
    "# For a fill-in-the-blank task, for example, replace the target token with [MASK]\n",
"masked_index = 5\n", | |
"tokenized_text[masked_index] = '[MASK]'\n", | |
"print(tokenized_text)\n", | |
"\n", | |
"assert tokenized_text == ['[CLS]', '日本', 'で', '有名', 'な', '[MASK]', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.744021Z", | |
"start_time": "2019-04-11T01:15:55.741324Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
    "# Convert the tokens to a list of vocabulary ids\n",
"indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.793482Z", | |
"start_time": "2019-04-11T01:15:55.788153Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
   "source": [
    "SEP_ID = 3  # id of [SEP] in this vocabulary\n",
    "def get_segments_ids(indexed_tokens):\n",
    "    # The segment id starts at 0 and increments after every [SEP],\n",
    "    # so sentence A gets 0 and sentence B gets 1\n",
    "    segments_ids = []\n",
    "    i = 0\n",
    "    for indexed_token in indexed_tokens:\n",
    "        segments_ids.append(i)\n",
    "        if indexed_token == SEP_ID:\n",
    "            i += 1\n",
    "    return segments_ids"
   ]
}, | |
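  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see the segment-id logic on a tiny hand-made id sequence (the content ids below are made up; only 3 as the `[SEP]` id matches the real vocabulary): each token up to and including a `[SEP]` keeps the current segment id, and the id increments after the `[SEP]`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Self-contained restatement of get_segments_ids for a quick check\n",
    "def _segments(ids, sep_id=3):\n",
    "    out, i = [], 0\n",
    "    for t in ids:\n",
    "        out.append(i)\n",
    "        if t == sep_id:\n",
    "            i += 1\n",
    "    return out\n",
    "\n",
    "# [CLS]=2, two content ids, [SEP]=3, one content id, [SEP]=3\n",
    "assert _segments([2, 10, 11, 3, 12, 3]) == [0, 0, 0, 0, 1, 1]"
   ]
  },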
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.808830Z", | |
"start_time": "2019-04-11T01:15:55.806657Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"segments_ids = get_segments_ids(indexed_tokens)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:55.859403Z", | |
"start_time": "2019-04-11T01:15:55.856558Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(segments_ids)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:15:56.567868Z", | |
"start_time": "2019-04-11T01:15:56.564388Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"tokens_tensor = torch.tensor([indexed_tokens])\n", | |
"segments_tensors = torch.tensor([segments_ids])" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
   "source": [
    "### Load the model and check the number of layers"
   ]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:05.358798Z", | |
"start_time": "2019-04-11T01:15:56.853471Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pytorch_pretrained_bert.modeling:loading archive file ./pytorchver.tar.gz\n", | |
"INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pytorchver.tar.gz to temp dir /tmp/tmpsnvwexxt\n", | |
"INFO:pytorch_pretrained_bert.modeling:Model config {\n", | |
" \"attention_probs_dropout_prob\": 0.1,\n", | |
" \"hidden_act\": \"gelu\",\n", | |
" \"hidden_dropout_prob\": 0.1,\n", | |
" \"hidden_size\": 768,\n", | |
" \"initializer_range\": 0.02,\n", | |
" \"intermediate_size\": 3072,\n", | |
" \"max_position_embeddings\": 512,\n", | |
" \"num_attention_heads\": 12,\n", | |
" \"num_hidden_layers\": 12,\n", | |
" \"type_vocab_size\": 2,\n", | |
" \"vocab_size\": 32005\n", | |
"}\n", | |
"\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"BertModel(\n", | |
" (embeddings): BertEmbeddings(\n", | |
" (word_embeddings): Embedding(32005, 768)\n", | |
" (position_embeddings): Embedding(512, 768)\n", | |
" (token_type_embeddings): Embedding(2, 768)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (encoder): BertEncoder(\n", | |
" (layer): ModuleList(\n", | |
" (0): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (1): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (2): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (3): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (4): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (5): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (6): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (7): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (8): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (9): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (10): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (11): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (pooler): BertPooler(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (activation): Tanh()\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = BertModel.from_pretrained('./pytorchver.tar.gz')\n", | |
"model.eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:05.428229Z", | |
"start_time": "2019-04-11T01:16:05.360688Z" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"with torch.no_grad():\n", | |
" encoded_layers, _ = model(tokens_tensor, segments_tensors)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:05.592271Z", | |
"start_time": "2019-04-11T01:16:05.589357Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"assert len(encoded_layers) == 12" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
   "source": [
    "### Load the model and solve a fill-in-the-blank task"
   ]
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:14.059586Z", | |
"start_time": "2019-04-11T01:16:05.671871Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"INFO:pytorch_pretrained_bert.modeling:loading archive file ./pytorchver.tar.gz\n", | |
"INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pytorchver.tar.gz to temp dir /tmp/tmpc5utllq6\n", | |
"INFO:pytorch_pretrained_bert.modeling:Model config {\n", | |
" \"attention_probs_dropout_prob\": 0.1,\n", | |
" \"hidden_act\": \"gelu\",\n", | |
" \"hidden_dropout_prob\": 0.1,\n", | |
" \"hidden_size\": 768,\n", | |
" \"initializer_range\": 0.02,\n", | |
" \"intermediate_size\": 3072,\n", | |
" \"max_position_embeddings\": 512,\n", | |
" \"num_attention_heads\": 12,\n", | |
" \"num_hidden_layers\": 12,\n", | |
" \"type_vocab_size\": 2,\n", | |
" \"vocab_size\": 32005\n", | |
"}\n", | |
"\n", | |
"INFO:pytorch_pretrained_bert.modeling:Weights from pretrained model not used in BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"BertForMaskedLM(\n", | |
" (bert): BertModel(\n", | |
" (embeddings): BertEmbeddings(\n", | |
" (word_embeddings): Embedding(32005, 768)\n", | |
" (position_embeddings): Embedding(512, 768)\n", | |
" (token_type_embeddings): Embedding(2, 768)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (encoder): BertEncoder(\n", | |
" (layer): ModuleList(\n", | |
" (0): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (1): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (2): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (3): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (4): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (5): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (6): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (7): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (8): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (9): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (10): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (11): BertLayer(\n", | |
" (attention): BertAttention(\n", | |
" (self): BertSelfAttention(\n", | |
" (query): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (key): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (value): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" (output): BertSelfOutput(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" (intermediate): BertIntermediate(\n", | |
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n", | |
" )\n", | |
" (output): BertOutput(\n", | |
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" (dropout): Dropout(p=0.1)\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" )\n", | |
" (pooler): BertPooler(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (activation): Tanh()\n", | |
" )\n", | |
" )\n", | |
" (cls): BertOnlyMLMHead(\n", | |
" (predictions): BertLMPredictionHead(\n", | |
" (transform): BertPredictionHeadTransform(\n", | |
" (dense): Linear(in_features=768, out_features=768, bias=True)\n", | |
" (LayerNorm): BertLayerNorm()\n", | |
" )\n", | |
" (decoder): Linear(in_features=768, out_features=32005, bias=False)\n", | |
" )\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model = BertForMaskedLM.from_pretrained('./pytorchver.tar.gz')\n", | |
"model.eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:14.127765Z", | |
"start_time": "2019-04-11T01:16:14.063117Z" | |
}, | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"with torch.no_grad():\n", | |
" predictions = model(tokens_tensor, segments_tensors)\n", | |
"\n", | |
"# confirm we were able to predict 'henson'\n", | |
"predicted_index = torch.argmax(predictions[0, masked_index]).item()\n", | |
"predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:14.133900Z", | |
"start_time": "2019-04-11T01:16:14.130762Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"銀行\n" | |
] | |
} | |
], | |
"source": [ | |
"# 「銀行」をどんぴしゃで当ててくる\n", | |
"print(predicted_token)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:14.154354Z", | |
"start_time": "2019-04-11T01:16:14.150858Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"[UNK]\n" | |
] | |
} | |
], | |
"source": [ | |
"print(tokenizer.convert_ids_to_tokens([1])[0])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"ExecuteTime": { | |
"end_time": "2019-04-11T01:16:14.180584Z", | |
"start_time": "2019-04-11T01:16:14.170528Z" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1 銀行\n", | |
"2 企業\n", | |
"3 [UNK]\n", | |
"4 金融機関\n", | |
"5 地方銀行\n", | |
"6 ところ\n", | |
"7 会社\n", | |
"8 地銀\n", | |
"9 国\n", | |
"10 の\n" | |
] | |
} | |
], | |
"source": [ | |
"# 確率上位TOP10の語を出力する\n", | |
"topn = 10\n", | |
"for i, idx in enumerate(torch.argsort(predictions[0, masked_index], descending=True)[:topn], start=1):\n", | |
" print(i, tokenizer.convert_ids_to_tokens([int(idx)])[0])" | |
] | |
}, | |
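{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"As a sketch (not part of the original notebook): `torch.topk` returns the same top-10 ranking in a single call, together with the raw scores." | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Alternative sketch: torch.topk yields scores and indices together\n", | |
"scores, indices = torch.topk(predictions[0, masked_index], k=topn)\n", | |
"for rank, (score, idx) in enumerate(zip(scores.tolist(), indices.tolist()), start=1):\n", | |
"    print(rank, tokenizer.convert_ids_to_tokens([idx])[0], round(score, 2))" | |
] | |
}, | |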
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"hide_input": false, | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.3" | |
}, | |
"toc": { | |
"nav_menu": {}, | |
"number_sections": true, | |
"sideBar": true, | |
"skip_h1_title": false, | |
"toc_cell": false, | |
"toc_position": {}, | |
"toc_section_display": "block", | |
"toc_window_display": false | |
}, | |
"varInspector": { | |
"cols": { | |
"lenName": 16, | |
"lenType": 16, | |
"lenVar": 40 | |
}, | |
"kernels_config": { | |
"python": { | |
"delete_cmd_postfix": "", | |
"delete_cmd_prefix": "del ", | |
"library": "var_list.py", | |
"varRefreshCmd": "print(var_dic_list())" | |
}, | |
"r": { | |
"delete_cmd_postfix": ") ", | |
"delete_cmd_prefix": "rm(", | |
"library": "var_list.r", | |
"varRefreshCmd": "cat(var_dic_list()) " | |
} | |
}, | |
"types_to_exclude": [ | |
"module", | |
"function", | |
"builtin_function_or_method", | |
"instance", | |
"_Feature" | |
], | |
"window_display": false | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |