@kanjirz50
Last active November 27, 2019 03:40
Running a BERT model trained on a news corpus. Many thanks to Mr. Morinaga of Stockmark Inc. for making it publicly available.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Running the model from [Introducing a BERT pretrained model (using MeCab) trained on a large-scale Japanese business news corpus](https://qiita.com/mkt3/items/3c1278339ff1bcc0187f)"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T00:54:45.075324Z",
"start_time": "2019-04-11T00:54:45.069771Z"
}
},
"source": [
"### Preparation\n",
"Download the PyTorch version of the files from the [download link](https://drive.google.com/open?id=1iDlmhGgJ54rkVBtZvgMlgbuNwtFQ50V-).\n",
"\n",
"Create a `tar.gz` archive that contains the following three files at its top level:\n",
"- bert_config.json\n",
"- pytorch_model.bin\n",
"- vocab.txt\n",
"\n",
"Keep a separate copy of `vocab.txt` outside the archive as well; the tokenizer loads it directly.\n",
"\n",
"Install the package that provides the PyTorch interface to BERT:\n",
"`pip install pytorch-pretrained-bert`"
]
},
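{
"cell_type": "markdown",
"metadata": {},
"source": [
"The archive can also be created from Python. The next cell is a minimal sketch using the standard-library `tarfile` module. It assumes the three downloaded files are in `./PyTorchVer/` (the directory the tokenizer cell below reads `vocab.txt` from) and writes `./pytorchver.tar.gz`, the path later passed to `from_pretrained`; adjust both paths to your setup.\n",
"\n",
"Passing `arcname=name` stores each file at the top level of the archive. Judging by the extraction log further down, `from_pretrained` unpacks the archive into a temporary directory and reads `bert_config.json` and `pytorch_model.bin` from its root, so files stored under a subdirectory would not be found."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import tarfile\n",
"from pathlib import Path\n",
"\n",
"# Assumption: the downloaded PyTorch files are in ./PyTorchVer/\n",
"SRC_DIR = Path(\"./PyTorchVer\")\n",
"FILES = (\"bert_config.json\", \"pytorch_model.bin\", \"vocab.txt\")\n",
"\n",
"# Write a gzip-compressed tar archive with the three files at its top level\n",
"with tarfile.open(\"./pytorchver.tar.gz\", \"w:gz\") as tar:\n",
"    for name in FILES:\n",
"        tar.add(str(SRC_DIR / name), arcname=name)\n",
"\n",
"# Sanity check: list the archive members\n",
"with tarfile.open(\"./pytorchver.tar.gz\", \"r:gz\") as tar:\n",
"    print(tar.getnames())"
]
},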
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:52.274037Z",
"start_time": "2019-04-11T01:15:51.727863Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.\n"
]
}
],
"source": [
"import collections\n",
"import logging\n",
"import os\n",
"\n",
"import torch\n",
"import MeCab\n",
"from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM, WordpieceTokenizer\n",
"from pytorch_pretrained_bert.tokenization import load_vocab"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:52.891854Z",
"start_time": "2019-04-11T01:15:52.888981Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"logging.basicConfig(level=logging.INFO)"
]
},
{
"cell_type": "markdown",
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-09T06:29:11.357348Z",
"start_time": "2019-04-09T06:29:11.354945Z"
}
},
"source": [
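"Prepare the dictionaries.\n",
"\n",
"Set `DICDIR` to the MeCab IPADic Neologd dictionary directory.\n",
"\n",
"For `USERDIC`, create a CSV file containing the entries below and compile it into a user dictionary.\n",
"This makes MeCab treat each of BERT's special tokens as a single word.\n",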
"\n",
"```\n",
"[UNK],1285,1285,100,名詞,記号,*,*,*,*,[UNK],[UNK],[UNK]\n",
"[SEP],1285,1285,100,名詞,記号,*,*,*,*,[SEP],[SEP],[SEP]\n",
"[PAD],1285,1285,100,名詞,記号,*,*,*,*,[PAD],[PAD],[PAD]\n",
"[CLS],1285,1285,100,名詞,記号,*,*,*,*,[CLS],[CLS],[CLS]\n",
"[MASK],1285,1285,100,名詞,記号,*,*,*,*,[MASK],[MASK],[MASK]\n",
"```\n",
"\n",
"Compile it with a command like the following: `/usr/local/libexec/mecab/mecab-dict-index -d /path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308 -u user.dic -f utf-8 -t utf-8 your_dictionary.csv`"
]
},
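{
"cell_type": "markdown",
"metadata": {},
"source": [
"The CSV entries and the compile step above can also be scripted. The next cell is a sketch under these assumptions: `special_tokens.csv` is a hypothetical file name, and the `mecab-dict-index` binary and dictionary directory are the placeholder paths from the command above, so adjust them to your installation. Each row follows the IPADic CSV layout: surface form, left-context ID, right-context ID, cost, six part-of-speech and conjugation fields, then base form, reading and pronunciation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import csv\n",
"import subprocess\n",
"\n",
"SPECIAL_TOKENS = [\"[UNK]\", \"[SEP]\", \"[PAD]\", \"[CLS]\", \"[MASK]\"]\n",
"CSV_PATH = \"special_tokens.csv\"  # hypothetical file name\n",
"\n",
"# One IPADic-format row per special token, identical to the entries listed above\n",
"with open(CSV_PATH, \"w\", encoding=\"utf-8\", newline=\"\") as f:\n",
"    writer = csv.writer(f, lineterminator=\"\\n\")\n",
"    for tok in SPECIAL_TOKENS:\n",
"        writer.writerow([tok, 1285, 1285, 100, \"名詞\", \"記号\", \"*\", \"*\", \"*\", \"*\", tok, tok, tok])\n",
"\n",
"# Compile into user.dic; both paths are the placeholders from the command above\n",
"subprocess.run([\n",
"    \"/usr/local/libexec/mecab/mecab-dict-index\",\n",
"    \"-d\", \"/path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308\",\n",
"    \"-u\", \"user.dic\",\n",
"    \"-f\", \"utf-8\",\n",
"    \"-t\", \"utf-8\",\n",
"    CSV_PATH,\n",
"], check=True)"
]
},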
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook uses [PyTorch-pretrained-BERT](https://github.com/huggingface/pytorch-pretrained-BERT), which provides an interface to BERT.\n",
"A few modifications are needed to use it with the BERT model pretrained on Japanese business news."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:54.285906Z",
"start_time": "2019-04-11T01:15:54.278355Z"
},
"code_folding": [
4
],
"collapsed": true
},
"outputs": [],
"source": [
"# Prepare a Japanese tokenizer that BertTokenizer can use\n",
"DICDIR = \"/path/to/mecab-ipadic-neologd/build/mecab-ipadic-2.7.0-20070801-neologd-20180308\"\n",
"USERDIC = \"/path/to/user.dic\"\n",
"\n",
"class MeCabBert:\n",
"    def __init__(self, dicdir=DICDIR, userdic=USERDIC):\n",
"        # -b: input buffer size; -Owakati: space-separated output; -u: user dictionary with the special tokens\n",
"        self.tagger = MeCab.Tagger(f\"-d {dicdir} -b 100000 -Owakati -u {userdic}\")\n",
"        self.tagger.parse(\"\")\n",
"\n",
"    def tokenize(self, text):\n",
"        return self.tagger.parse(text).split()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:54.345452Z",
"start_time": "2019-04-11T01:15:54.309959Z"
},
"code_folding": [
1
],
"collapsed": true
},
"outputs": [],
"source": [
"# Subclass BertTokenizer and modify it to work with MeCabBert\n",
"class BertMeCabTokenizer(BertTokenizer):\n",
"    def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,\n",
"                 never_split=(\"[UNK]\", \"[SEP]\", \"[PAD]\", \"[CLS]\", \"[MASK]\")):\n",
"        \"\"\"Constructs a BertTokenizer that uses MeCab (MeCabBert) for basic tokenization.\n",
"        Args:\n",
"          vocab_file: Path to a one-wordpiece-per-line vocabulary file\n",
"          do_lower_case: Unused here; kept for signature compatibility with BertTokenizer\n",
"          do_basic_tokenize: Whether to run MeCab tokenization before wordpiece.\n",
"          max_len: An artificial maximum length to truncate tokenized sequences to;\n",
"                   Effective maximum length is always the minimum of this\n",
"                   value (if specified) and the underlying BERT model's\n",
"                   sequence length.\n",
"          never_split: Unused here; the MeCab user dictionary keeps special tokens whole instead\n",
"        \"\"\"\n",
"        if not os.path.isfile(vocab_file):\n",
"            raise ValueError(\n",
"                \"Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained \"\n",
"                \"model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`\".format(vocab_file))\n",
"        self.vocab = load_vocab(vocab_file)\n",
"        self.ids_to_tokens = collections.OrderedDict(\n",
"            [(ids, tok) for tok, ids in self.vocab.items()])\n",
"        self.do_basic_tokenize = do_basic_tokenize\n",
"        if do_basic_tokenize:\n",
"            # Use MeCabBert instead of the default BasicTokenizer\n",
"            self.basic_tokenizer = MeCabBert()\n",
"        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)\n",
"        self.max_len = max_len if max_len is not None else int(1e12)\n",
"\n",
"    def convert_tokens_to_ids(self, tokens):\n",
"        \"\"\"Converts a sequence of tokens into ids using the vocab.\"\"\"\n",
"        ids = []\n",
"        for token in tokens:\n",
"            # No SentencePiece/subword vocabulary is used, so unknown words map to the [UNK] id (1 in this vocab)\n",
"            ids.append(self.vocab.get(token, self.vocab[\"[UNK]\"]))\n",
"        if len(ids) > self.max_len:\n",
"            # The notebook never defines a module-level `logger`, so fetch one here\n",
"            logging.getLogger(__name__).warning(\n",
"                \"Token indices sequence length is longer than the specified maximum \"\n",
"                \"sequence length for this BERT model ({} > {}). Running this \"\n",
"                \"sequence through BERT will result in indexing errors\".format(len(ids), self.max_len)\n",
"            )\n",
"        return ids"
]
},
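{
"cell_type": "markdown",
"metadata": {},
"source": [
"Why override `convert_tokens_to_ids`: in the pytorch-pretrained-bert version used here, the parent `BertTokenizer.convert_tokens_to_ids` looks each token up with `self.vocab[token]`, so any token missing from `vocab.txt` raises `KeyError`. The next cell is a self-contained illustration with a toy vocabulary (hypothetical IDs, not the real `vocab.txt`) contrasting that strict lookup with the `[UNK]` fallback used in the override."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import collections\n",
"\n",
"# Toy vocabulary for illustration only; real IDs come from vocab.txt\n",
"vocab = collections.OrderedDict([(\"[PAD]\", 0), (\"[UNK]\", 1), (\"[CLS]\", 2), (\"[SEP]\", 3), (\"日本\", 5)])\n",
"tokens = [\"[CLS]\", \"日本\", \"未登録語\", \"[SEP]\"]\n",
"\n",
"# Strict lookup, as in the parent class: fails on the unknown token\n",
"try:\n",
"    strict_ids = [vocab[t] for t in tokens]\n",
"except KeyError as err:\n",
"    print(\"KeyError:\", err)\n",
"\n",
"# Fallback lookup, as in BertMeCabTokenizer: unknown tokens map to the [UNK] id\n",
"fallback_ids = [vocab.get(t, vocab[\"[UNK]\"]) for t in tokens]\n",
"print(fallback_ids)  # [2, 5, 1, 3]"
]
},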
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Tokenizer"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.292660Z",
"start_time": "2019-04-11T01:15:55.214018Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pytorch_pretrained_bert.tokenization:loading vocabulary file ./PyTorchVer/vocab.txt\n"
]
}
],
"source": [
"# Path to vocab.txt\n",
"tokenizer = BertMeCabTokenizer.from_pretrained('./PyTorchVer/vocab.txt')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.300598Z",
"start_time": "2019-04-11T01:15:55.294956Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['[CLS]', '日本', 'で', '有名', 'な', '銀行', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']\n"
]
}
],
"source": [
"text = \"[CLS] 日本で有名な銀行はどこですか? [SEP] それはみずほ銀行です。 [SEP]\"\n",
"tokenized_text = tokenizer.tokenize(text)\n",
"print(tokenized_text)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.493018Z",
"start_time": "2019-04-11T01:15:55.488225Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['[CLS]', '日本', 'で', '有名', 'な', '[MASK]', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']\n"
]
}
],
"source": [
"# For example, to solve a fill-in-the-blank task, replace the target token with [MASK]\n",
"masked_index = 5  # position of '銀行'\n",
"tokenized_text[masked_index] = '[MASK]'\n",
"print(tokenized_text)\n",
"\n",
"assert tokenized_text == ['[CLS]', '日本', 'で', '有名', 'な', '[MASK]', 'は', 'どこ', 'です', 'か', '?', '[SEP]', 'それ', 'は', 'みずほ銀行', 'です', '。', '[SEP]']"
]
},
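{
"cell_type": "markdown",
"metadata": {},
"source": [
"`masked_index = 5` is hard-coded and points at `銀行` in this particular sentence. For other inputs, a small helper can locate the token to mask. The next cell is a hypothetical sketch (`mask_token` is not part of any library); it masks only the first occurrence, raises `ValueError` if the token is absent, and uses its own example variables so the notebook's `tokenized_text` and `masked_index` are left untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def mask_token(tokens, target):\n",
"    \"\"\"Return (copy of tokens with the first `target` replaced by [MASK], its index).\"\"\"\n",
"    index = tokens.index(target)  # raises ValueError if target is not present\n",
"    masked = list(tokens)\n",
"    masked[index] = \"[MASK]\"\n",
"    return masked, index\n",
"\n",
"example = [\"[CLS]\", \"日本\", \"で\", \"有名\", \"な\", \"銀行\", \"は\", \"どこ\", \"です\", \"か\", \"[SEP]\"]\n",
"masked_example, example_index = mask_token(example, \"銀行\")\n",
"print(example_index)  # 5\n",
"print(masked_example)"
]
},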
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.744021Z",
"start_time": "2019-04-11T01:15:55.741324Z"
}
},
"outputs": [],
"source": [
"# Convert tokens to a sequence of IDs\n",
"indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.793482Z",
"start_time": "2019-04-11T01:15:55.788153Z"
},
"collapsed": true
},
"outputs": [],
"source": [
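"SEP_ID = tokenizer.vocab[\"[SEP]\"]  # 3 in this vocabulary\n",
"\n",
"def get_segments_ids(indexed_tokens):\n",
"    # Segment 0 up to and including the first [SEP]; the ID is incremented after each [SEP]\n",
"    segments_ids = []\n",
"    i = 0\n",
"    for indexed_token in indexed_tokens:\n",
"        segments_ids.append(i)\n",
"        if indexed_token == SEP_ID:\n",
"            i = i + 1\n",
"    return segments_ids"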
]
},
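{
"cell_type": "markdown",
"metadata": {},
"source": [
"`get_segments_ids` depends on the numeric ID of `[SEP]`. An equivalent sketch that works on the token strings instead follows (the name `segment_ids_from_tokens` is hypothetical). Each `[SEP]` belongs to the segment it closes, and the segment ID is incremented after it. Note that the model config printed further down has `type_vocab_size` 2, so only segment IDs 0 and 1 are valid; an input with more than two `[SEP]`-terminated segments would produce ID 2 and fail at the embedding lookup."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def segment_ids_from_tokens(tokens, sep_token=\"[SEP]\"):\n",
"    \"\"\"Return one segment ID per token, incrementing the ID after each sep_token.\"\"\"\n",
"    segment_ids, current = [], 0\n",
"    for token in tokens:\n",
"        segment_ids.append(current)\n",
"        if token == sep_token:\n",
"            current += 1\n",
"    return segment_ids\n",
"\n",
"example_tokens = [\"[CLS]\", \"日本\", \"で\", \"有名\", \"な\", \"[MASK]\", \"は\", \"どこ\", \"です\", \"か\", \"?\", \"[SEP]\",\n",
"                  \"それ\", \"は\", \"みずほ銀行\", \"です\", \"。\", \"[SEP]\"]\n",
"print(segment_ids_from_tokens(example_tokens))\n",
"# [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]"
]
},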
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.808830Z",
"start_time": "2019-04-11T01:15:55.806657Z"
}
},
"outputs": [],
"source": [
"segments_ids = get_segments_ids(indexed_tokens)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:55.859403Z",
"start_time": "2019-04-11T01:15:55.856558Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]\n"
]
}
],
"source": [
"print(segments_ids)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:15:56.567868Z",
"start_time": "2019-04-11T01:15:56.564388Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"tokens_tensor = torch.tensor([indexed_tokens])\n",
"segments_tensors = torch.tensor([segments_ids])"
]
},
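{
"cell_type": "markdown",
"metadata": {},
"source": [
"The model config printed below reports `max_position_embeddings` 512, while `BertMeCabTokenizer` defaults `max_len` to `int(1e12)`, so nothing stops an over-long input before it reaches the model. The next cell is a hypothetical helper (`to_batch_tensors` is not a library function) that builds the same batch-of-one tensors as the cell above while checking that both lists have equal length and fit within 512 positions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"MAX_POSITIONS = 512  # max_position_embeddings from bert_config.json\n",
"\n",
"def to_batch_tensors(token_ids, segment_ids, max_positions=MAX_POSITIONS):\n",
"    \"\"\"Wrap one sequence as batch-of-one integer tensors, validating its length.\"\"\"\n",
"    if len(token_ids) != len(segment_ids):\n",
"        raise ValueError(\"token_ids and segment_ids must have the same length\")\n",
"    if len(token_ids) > max_positions:\n",
"        raise ValueError(\"sequence of {} tokens exceeds {} positions\".format(len(token_ids), max_positions))\n",
"    return torch.tensor([token_ids]), torch.tensor([segment_ids])\n",
"\n",
"ids_example, segs_example = to_batch_tensors([2, 5, 3], [0, 0, 0])\n",
"print(ids_example.shape, ids_example.dtype)  # torch.Size([1, 3]) torch.int64"
]
},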
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the model and check the number of layers"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:05.358798Z",
"start_time": "2019-04-11T01:15:56.853471Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pytorch_pretrained_bert.modeling:loading archive file ./pytorchver.tar.gz\n",
"INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pytorchver.tar.gz to temp dir /tmp/tmpsnvwexxt\n",
"INFO:pytorch_pretrained_bert.modeling:Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"max_position_embeddings\": 512,\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"type_vocab_size\": 2,\n",
" \"vocab_size\": 32005\n",
"}\n",
"\n"
]
},
{
"data": {
"text/plain": [
"BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(32005, 768)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
")"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = BertModel.from_pretrained('./pytorchver.tar.gz')\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:05.428229Z",
"start_time": "2019-04-11T01:16:05.360688Z"
}
},
"outputs": [],
"source": [
"with torch.no_grad():\n",
"    encoded_layers, _ = model(tokens_tensor, segments_tensors)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:05.592271Z",
"start_time": "2019-04-11T01:16:05.589357Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"assert len(encoded_layers) == 12"
]
},
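{
"cell_type": "markdown",
"metadata": {},
"source": [
"`encoded_layers` is a list with one tensor per Transformer layer (12 here), because `BertModel` returns all encoded layers by default (`output_all_encoded_layers=True`). Each tensor has shape `[batch_size, sequence_length, hidden_size]`. The next cell is a self-contained sketch that uses random stand-in tensors of the same shapes (18 tokens, hidden size 768) to show how to take the final layer, the `[CLS]` vector, or a mean over tokens; these are two common ways to get a fixed-size vector, and the better choice depends on the downstream task."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Random stand-ins with the shapes BertModel produces for the 18-token example above\n",
"dummy_layers = [torch.randn(1, 18, 768) for _ in range(12)]\n",
"\n",
"last_layer = dummy_layers[-1]            # output of the final (12th) layer\n",
"cls_vector = last_layer[0, 0]            # hidden state at the [CLS] position\n",
"mean_vector = last_layer[0].mean(dim=0)  # average over all token positions\n",
"\n",
"print(last_layer.shape, cls_vector.shape, mean_vector.shape)\n",
"# torch.Size([1, 18, 768]) torch.Size([768]) torch.Size([768])"
]
},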
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the model and solve the fill-in-the-blank task"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:14.059586Z",
"start_time": "2019-04-11T01:16:05.671871Z"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pytorch_pretrained_bert.modeling:loading archive file ./pytorchver.tar.gz\n",
"INFO:pytorch_pretrained_bert.modeling:extracting archive file ./pytorchver.tar.gz to temp dir /tmp/tmpc5utllq6\n",
"INFO:pytorch_pretrained_bert.modeling:Model config {\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"max_position_embeddings\": 512,\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"type_vocab_size\": 2,\n",
" \"vocab_size\": 32005\n",
"}\n",
"\n",
"INFO:pytorch_pretrained_bert.modeling:Weights from pretrained model not used in BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n"
]
},
{
"data": {
"text/plain": [
"BertForMaskedLM(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(32005, 768)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (1): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (2): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (3): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (4): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (5): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (6): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (7): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (8): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (9): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (10): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (11): BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" (dropout): Dropout(p=0.1)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (cls): BertOnlyMLMHead(\n",
" (predictions): BertLMPredictionHead(\n",
" (transform): BertPredictionHeadTransform(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): BertLayerNorm()\n",
" )\n",
" (decoder): Linear(in_features=768, out_features=32005, bias=False)\n",
" )\n",
" )\n",
")"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
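"# Load the masked-LM model from the tar.gz archive prepared above\n",
"model = BertForMaskedLM.from_pretrained('./pytorchver.tar.gz')\n",
"# Switch to evaluation mode (disables dropout)\n",
"model.eval()"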
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:14.127765Z",
"start_time": "2019-04-11T01:16:14.063117Z"
},
"collapsed": true
},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" predictions = model(tokens_tensor, segments_tensors)\n",
"\n",
"# Take the highest-scoring vocabulary entry at the masked position\n",
"predicted_index = torch.argmax(predictions[0, masked_index]).item()\n",
"predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:14.133900Z",
"start_time": "2019-04-11T01:16:14.130762Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"銀行\n"
]
}
],
"source": [
"# The model predicts 「銀行」 (bank) exactly\n",
"print(predicted_token)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:14.154354Z",
"start_time": "2019-04-11T01:16:14.150858Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[UNK]\n"
]
}
],
"source": [
"# Check which token ID 1 maps to (the unknown token [UNK] in this vocabulary)\n",
"print(tokenizer.convert_ids_to_tokens([1])[0])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"ExecuteTime": {
"end_time": "2019-04-11T01:16:14.180584Z",
"start_time": "2019-04-11T01:16:14.170528Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1 銀行\n",
"2 企業\n",
"3 [UNK]\n",
"4 金融機関\n",
"5 地方銀行\n",
"6 ところ\n",
"7 会社\n",
"8 地銀\n",
"9 国\n",
"10 の\n"
]
}
],
"source": [
"# Print the 10 highest-scoring tokens (same order as ranking by softmax probability)\n",
"topn = 10\n",
"for i, idx in enumerate(torch.argsort(predictions[0, masked_index], descending=True)[:topn], start=1):\n",
" print(i, tokenizer.convert_ids_to_tokens([int(idx)])[0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"hide_input": false,
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
},
"toc": {
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"toc_cell": false,
"toc_position": {},
"toc_section_display": "block",
"toc_window_display": false
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
},
"kernels_config": {
"python": {
"delete_cmd_postfix": "",
"delete_cmd_prefix": "del ",
"library": "var_list.py",
"varRefreshCmd": "print(var_dic_list())"
},
"r": {
"delete_cmd_postfix": ") ",
"delete_cmd_prefix": "rm(",
"library": "var_list.r",
"varRefreshCmd": "cat(var_dic_list()) "
}
},
"types_to_exclude": [
"module",
"function",
"builtin_function_or_method",
"instance",
"_Feature"
],
"window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 2
}
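The top-10 cell in the notebook ranks the raw scores that `BertForMaskedLM` returns, i.e. the output of the final `decoder` linear layer (logits), not probabilities. Because softmax is strictly increasing, ranking by logits gives the same order as ranking by probability; converting only makes the scores interpretable. The following is a minimal, stdlib-only sketch of that conversion. The vocabulary and logit values are hypothetical placeholders standing in for `predictions[0, masked_index]`, not output from the model.

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(tokens: Sequence[str], logits: Sequence[float], k: int) -> List[Tuple[str, float]]:
    """Return the k best (token, probability) pairs, highest first.

    Sorting by logits is enough: softmax preserves order, so the
    probabilities are only computed to attach readable scores.
    """
    probs = softmax(logits)
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    return [(tokens[i], probs[i]) for i in order[:k]]


# Hypothetical toy vocabulary and made-up logits (not model output).
vocab = ["token_a", "token_b", "token_c", "token_d", "token_e"]
logits = [9.1, 7.4, 7.0, 6.2, 3.3]

for rank, (tok, p) in enumerate(top_k(vocab, logits, 3), start=1):
    print(rank, tok, f"{p:.3f}")
```

With the notebook's own tensors, the same idea could be written as `probs, ids = torch.topk(torch.softmax(predictions[0, masked_index], dim=-1), k=10)`, passing `ids.tolist()` to `tokenizer.convert_ids_to_tokens`; the token order matches the `torch.argsort` cell, with probabilities in place of logits.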