Skip to content

Instantly share code, notes, and snippets.

@tomonari-masada
Created September 10, 2019 09:48
Show Gist options
  • Save tomonari-masada/1ce8bc1cf594b69bc4d3480d734c4aba to your computer and use it in GitHub Desktop.
Save tomonari-masada/1ce8bc1cf594b69bc4d3480d734c4aba to your computer and use it in GitHub Desktop.
from pyknp import Juman
import torch
from pytorch_transformers import *
# Masked-token prediction demo for a Japanese BERT model whose input is
# Juman++-segmented text split further into BPE subwords. Each subword of the
# sentence is replaced by [MASK] in turn, and the model's top-k candidate
# tokens for that position are printed.
MODEL_DIR = 'Japanese_L-12_H-768_A-12_E-30_BPE'
TOP_K = 5  # number of candidate tokens printed per masked position

config = BertConfig.from_json_file(MODEL_DIR + '/bert_config.json')
model = BertForMaskedLM.from_pretrained(MODEL_DIR + '/pytorch_model.bin',
                                        config=config)
# do_basic_tokenize=False: the text is already segmented by Juman++, so the
# tokenizer should only apply the subword split on whitespace boundaries.
tokenizer = BertTokenizer(MODEL_DIR + '/vocab.txt',
                          do_lower_case=False, do_basic_tokenize=False)
jumanpp = Juman()

text = "僕は友達とサッカーをすることが好きだ。"
result = jumanpp.analysis(text)
# Surface forms (midasi) of the morphemes. They are joined with spaces and run
# through tokenizer.tokenize so that morphemes missing from the subword vocab
# are split into known pieces instead of silently becoming [UNK].
morphemes = [mrph.midasi for mrph in result.mrph_list()]
tokenized_text = tokenizer.tokenize(' '.join(morphemes))
tokenized_text = [tokenizer.cls_token] + tokenized_text + [tokenizer.sep_token]

# Use the GPU when present; fall back to CPU instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device)

# Inference only: disable autograd bookkeeping for the whole loop.
with torch.no_grad():
    # Skip index 0 ([CLS]) and the last index ([SEP]).
    for masked_index in range(1, len(tokenized_text) - 1):
        temp_text = list(tokenized_text)
        temp_text[masked_index] = tokenizer.mask_token
        token_ids = tokenizer.convert_tokens_to_ids(temp_text)
        tokens_tensor = torch.tensor([token_ids]).to(device)
        outputs = model(tokens_tensor)
        # Assumed (batch, seq_len, vocab_size) prediction scores; the
        # indexing below relies on the batch axis coming first.
        predictions = outputs[0]
        _, predicted_indexes = torch.topk(predictions[0, masked_index], k=TOP_K)
        predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_indexes.tolist())
        print(temp_text)
        print(predicted_tokens)
        print('-' * 32)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment