Skip to content

Instantly share code, notes, and snippets.

@ShikiOkasaka
Last active May 27, 2024 09:25
Show Gist options
  • Save ShikiOkasaka/c87bdf5bcb996658f579b9a8bb23a6bb to your computer and use it in GitHub Desktop.
Save ShikiOkasaka/c87bdf5bcb996658f579b9a8bb23a6bb to your computer and use it in GitHub Desktop.
Hugging Face Transformersでかな漢字変換の実験
#!/usr/bin/env python3
# pip install transformers
# pip install fugashi
# pip install ipadic
# pip install unidic_lite
import torch
from transformers import BertForMaskedLM
from transformers import BertJapaneseTokenizer
# Hugging Face Hub id of the pretrained Japanese BERT checkpoint. It is named
# once so the tokenizer and the model are always loaded from the same
# checkpoint: pick() indexes the model's logits with ids produced by this
# tokenizer, so the two must share a vocabulary.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-v3'

# Tokenizer that converts candidate strings into vocabulary ids.
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
# Masked-language-model used to score tokens at a masked position.
model = BertForMaskedLM.from_pretrained(MODEL_NAME)
def pick(candidates):
    """Return the candidate whose first distinguishing token scores highest.

    Every candidate is tokenized and the first token position at which the
    candidates disagree is located. The shared token prefix, followed by a
    single mask token and a separator token, is fed to the masked language
    model; each candidate is then ranked by the logit the model assigns to
    that candidate's own token at the masked position. Only this first
    differing token is compared, and anything after it is not shown to the
    model.

    The candidates and the ranked scores/indices are printed as a side
    effect.

    Args:
        candidates: Non-empty sequence (tuple or list) of candidate strings.

    Returns:
        The element of ``candidates`` with the highest score. If the
        candidates never diverge at the token level (for example a single
        candidate, or candidates that tokenize identically), the first
        candidate is returned without running the model.

    Raises:
        ValueError: If ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError('candidates must not be empty')
    print('Q: ', candidates)
    encoded_candidates = tokenizer(candidates)
    # Column i holds the i-th token id of every candidate; zip() stops at the
    # shortest tokenization.
    transposed = list(zip(*encoded_candidates.input_ids))
    for mask_token_index, ids in enumerate(transposed):
        if len(set(ids)) != 1:
            break
    else:
        # No position distinguishes the candidates, so there is nothing for
        # the model to decide; masking the last shared token would only
        # yield an arbitrary tie-break.
        return candidates[0]
    # Shared prefix up to (not including) the first differing position,
    # followed by the mask token and the closing separator.
    input_ids = list(encoded_candidates.input_ids[0][:mask_token_index])
    input_ids += [tokenizer.mask_token_id, tokenizer.sep_token_id]
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(input_ids=torch.tensor(input_ids).unsqueeze(0)).logits
    # One token id per candidate at the differing position; indexing the
    # masked position's logits with them gives one score per candidate.
    token_ids = list(transposed[mask_token_index])
    topk = torch.topk(logits[0, mask_token_index][token_ids], k=len(candidates))
    print(' ', topk.values.tolist())
    print(' ', topk.indices.tolist())
    # topk.indices are positions within `candidates`; convert to a plain int.
    return candidates[topk.indices[0].item()]
# Demo runs: each tuple lists alternative written forms for one phrase;
# print the form that pick() selects for each, in order.
for _candidates in (
    ('わたしの生き概論', 'わたしの生きが異論', 'わたしの生きがい論'),
    ('電車に乗って', '電車に載って'),
    ('新聞に乗って', '新聞に載って'),
    ('先生にあって間隙', '先生にあって観劇', '先生にあって感激'),
):
    print('A: ', pick(_candidates))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment