Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Оценка вариантов подстановки прямого дополнения в клаузу SVO с помощью GPT
import io
import itertools
import pickle
import collections
import glob
import os
import tqdm
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import numpy as np
import rutokenizer
def score(tokens_tensor):
    """Return the perplexity of *tokens_tensor* under the global GPT model.

    Passing the tokens as ``labels`` makes the model return the mean
    cross-entropy loss as the first output; ``exp(loss)`` is the perplexity.
    Lower is "more plausible" text.
    """
    # Inference only: no_grad avoids building the autograd graph, which the
    # original version paid for (memory + time) without ever backpropagating.
    with torch.no_grad():
        outputs = gpt_model(tokens_tensor, labels=tokens_tensor)
    loss = outputs[0]
    return np.exp(loss.cpu().detach().numpy())
def normalize_word(word):
    """Normalize a dictionary headword: lowercase, collapse ' - ' to '-',
    and map every 'ё' to 'е' (standard Russian yo-folding).

    Bug fix: the original replaced 'ё' BEFORE lowercasing, so an uppercase
    'Ё' slipped through as 'ё'. Lowercase first so the replacement is total.
    """
    return word.lower().replace(' - ', '-').replace(u'ё', u'е')
def pad(items, size, pad):
    """Right-pad *items* with copies of *pad* so its length reaches *size*.

    Returns a new padded list when padding is needed; otherwise returns the
    original list object unchanged.
    """
    deficit = size - len(items)
    return items + [pad] * deficit if deficit > 0 else items
# Grammar dictionary: tab-separated lines mapping a word to its morphological tags.
gren_path = '/home/inkoziev/polygon/chatbot/data/dict/word2tags.dat'
# Russian GPT-3-style model from SberBank; smaller alternatives kept commented out.
model_name = 'sberbank-ai/rugpt3large_based_on_gpt2'
#model_name = 'sberbank-ai/rugpt3medium_based_on_gpt2'
#model_name = 'gpt2'
# Load the language model and its BPE tokenizer (downloads on first run).
gpt_model = GPT2LMHeadModel.from_pretrained(model_name)
gpt_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Word-level tokenizer for Russian text, used for corpus frequency counting.
tokenizer = rutokenizer.Tokenizer()
# Cache file for the filtered grammar dictionary built below.
word2tags_pkl_path = 'word2tags.pkl'
if not os.path.exists(word2tags_pkl_path):
    # Build the cache: collect words that occur in the corpora, then keep
    # only the grammar-dictionary entries for those words.

    # Соберем список слов, которые встречаются в корпусах
    fnames = []
    #dir1 = '/home/inkoziev/polygon/chatbot/data/SENTx'
    #for filename in glob.iglob(dir1 + '/*.txt'):
    #    fnames.append(os.path.join(dir1, filename))
    dir2 = '/home/inkoziev/polygon/chatbot/data'
    for filename in ['facts5.txt', 'facts6.txt', 'facts7.txt', 'facts8.txt']:
        fnames.append(os.path.join(dir2, filename))

    sents = set()
    for i, p in enumerate(fnames, start=1):
        print('Loading {}/{} file="{}"...'.format(i, len(fnames), p))
        # NOTE(review): the original paste lost the open(...) call and the
        # loop body here; reconstructed as "one sentence per line" — confirm
        # against the source corpora format.
        with open(p, 'r', encoding='utf-8') as rdr:
            for line in rdr:
                sents.add(line.strip())

    # Count word frequencies over all sentences; keep words seen more than twice.
    known_words = collections.Counter()
    for sent in tqdm.tqdm(sents, desc='Токенизация', total=len(sents)):
        for word in tokenizer.tokenize(sent.lower()):
            known_words[word] += 1
    known_words = set(w for w, c in known_words.items() if c > 2)

    # word -> list of tagsets; each tagset is a set of tag strings, POS included.
    word2tags = collections.defaultdict(list)
    print('Loading grammar dictionary from {}...'.format(gren_path))
    # NOTE(review): the paste also lost this open(...) call; restored.
    with open(gren_path, 'r', encoding='utf-8') as rdr:
        for line in rdr:
            tx = line.strip().split('\t')
            if len(tx) == 5:
                word = normalize_word(tx[0])
                if word in known_words:
                    pos = tx[1]
                    tags = tx[3].split(' ')
                    tags = set(itertools.chain([pos], tags))
                    # NOTE(review): the paste dropped the accumulation line —
                    # without it word2tags stays empty and the whole script
                    # produces no probes. Restored here.
                    word2tags[word].append(tags)

    print('{} words in grammar dictionary'.format(len(word2tags)))
    with open(word2tags_pkl_path, 'wb') as f:
        pickle.dump(word2tags, f)
# Load the (possibly just rebuilt) cached dictionary.
# NOTE: pickle.load is safe here only because the file is produced locally
# by this script; never unpickle untrusted data.
with open(word2tags_pkl_path, 'rb') as f:
    word2tags = pickle.load(f)
# Candidate direct objects: every known word that has at least one tagset
# marking it as a noun in the accusative case. A word with several matching
# tagsets is appended once per tagset, same as the original loop.
probe_src = 'кошки любят есть '
probes = [probe_src + word
          for word, tagsets in word2tags.items()
          for tags in tagsets
          if 'СУЩЕСТВИТЕЛЬНОЕ' in tags and 'ПАДЕЖ:ВИН' in tags]
# Score every probe phrase and report the most plausible completions.
# Fixes vs. original: the unused `batch_size` variable is removed, and the
# loop variable in the report loop no longer shadows the score() function.
probe_scores = []
for text in tqdm.tqdm(probes, desc='Scoring', total=len(probes)):
    tokens_tensor = gpt_tokenizer.encode(text, add_special_tokens=False, return_tensors="pt")
    perplexity = score(tokens_tensor)
    probe_scores.append((text, perplexity))

# Lower perplexity = more plausible; print the 10 best candidates.
probe_scores = sorted(probe_scores, key=lambda z: z[1])
for text, perplexity in probe_scores[:10]:
    print('{}\t{}'.format(perplexity, text))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment