Implementation of a paraphrase detector model with BERT in pytorch transformers
"""
Тренировка модели детектора синонимичности двух фраз с использованием претренированной модели BERT на PyTorch
09.03.2020 Добавлен расчет метрики MRR (mean reciprocal rank)
"""
import pandas as pd
import numpy as np
import random
import tqdm
from sklearn.model_selection import train_test_split
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import transformers
# Dataset with phrase pairs, prepared in the chatbot project.
#dataset_path = r'C:\polygon\chatbot\data\synonymy_dataset.csv'
dataset_path = r'/home/inkoziev/polygon/chatbot/data/synonymy_dataset.csv'
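# The file is expected to be a tab-separated table with (at least) the columns
# 'premise', 'question' and 'relevance' (binary label), as read by load_samples() below.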
class SynonymyDataset(torch.utils.data.Dataset):
    """Holds tokenized phrase pairs and pads them to the length of the longest phrase seen."""
    def __init__(self, device, computed_params):
        self.device = device
        self.computed_params = computed_params
        self.samples = []
        self.max_len = 0
        self.pad_token_id = self.computed_params['bert_tokenizer'].pad_token_id

    def append(self, sample):
        self.max_len = max(self.max_len, len(sample[0]), len(sample[1]))
        self.samples.append(sample)

    def __len__(self):
        return len(self.samples)

    def pad_tokens(self, tokens):
        return tokens + [self.pad_token_id] * (self.max_len - len(tokens))

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        tokens1, tokens2, label = self.samples[idx]

        z1 = torch.tensor(self.pad_tokens(tokens1))
        z2 = torch.tensor(self.pad_tokens(tokens2))

        return z1, z2, torch.FloatTensor([label])
def load_samples(input_path, device, computed_params, max_samples):
    """Read the phrase-pair dataset and tokenize both phrases of every pair with the BERT tokenizer."""
    df = pd.read_csv(input_path, encoding='utf-8', delimiter='\t', quoting=3)
    if df.shape[0] > max_samples:
        df = df[:max_samples]

    bert_tokenizer = computed_params['bert_tokenizer']

    max_len = 0
    samples = SynonymyDataset(device, computed_params)
    for phrase1, phrase2, label in zip(df['premise'].values, df['question'].values, df['relevance'].values):
        tokens1 = bert_tokenizer.encode(phrase1)
        tokens2 = bert_tokenizer.encode(phrase2)
        samples.append((tokens1, tokens2, label))
        max_len = max(max_len, len(tokens1), len(tokens2))

    computed_params['max_len'] = max_len
    return samples
class SynonymyDetector(nn.Module):
    """BERT-based phrase pair classifier: sentence embeddings are obtained by summing
    BERT token embeddings, then the concatenated pair is scored by a small MLP."""
    def __init__(self, computed_params):
        super(SynonymyDetector, self).__init__()
        self.bert_model = computed_params['bert_model']
        sent_emb_size = computed_params['sent_emb_size']
        #self.fc1 = nn.Linear(sent_emb_size*2, 1)
        self.fc1 = nn.Linear(sent_emb_size*2, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x1, x2):
        # Sentence embeddings: sum of BERT outputs over the token dimension.
        #with torch.no_grad():  # uncomment (and indent the two lines below) to freeze BERT
        z1 = self.bert_model(x1)[0].sum(dim=-2)
        z2 = self.bert_model(x2)[0].sum(dim=-2)

        # Alternative feature combinations:
        #merged = torch.cat((z1, z2, torch.abs(z1 - z2)), dim=-1)
        #merged = torch.cat((z1, z2, torch.abs(z1 - z2), z1 * z2), dim=-1)
        merged = torch.cat((z1, z2), dim=-1)

        merged = self.fc1(merged)
        merged = torch.relu(merged)
        merged = self.fc2(merged)
        output = torch.sigmoid(merged)
        return output
def train(model, device, batch_generator, optimizer, epoch):
    """One training epoch; returns the average batch loss."""
    model.train()
    total_loss = 0
    for x1, x2, y in tqdm.tqdm(batch_generator, desc='epoch {}'.format(epoch), total=len(batch_generator)):
        x1 = x1.to(device)
        x2 = x2.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        output = model(x1, x2)
        loss = F.binary_cross_entropy(output, y)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

    avg_train_loss = total_loss / len(batch_generator)
    return avg_train_loss
def test(model, device, batch_generator):
    """Evaluate the model on the held-out test set; returns accuracy."""
    model.eval()
    test_loss = 0
    correct = 0
    nb_records = 0
    with torch.no_grad():
        for x1, x2, y in batch_generator:
            x1 = x1.to(device)
            x2 = x2.to(device)
            y = y.to(device)
            output = model(x1, x2)
            test_loss += F.binary_cross_entropy(output, y, reduction='sum').item()  # sum up batch loss
            pred = output > 0.5
            y2 = y.view_as(pred) > 0.5
            correct += pred.eq(y2).sum().item()
            nb_records += y.shape[0]

    test_loss /= nb_records
    acc = correct / float(nb_records)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(test_loss, acc))
    return acc
def validate(model, device, batch_generator):
    """Compute the F1 score on the validation set after training has finished."""
    model.eval()
    y_true2 = []
    y_pred2 = []
    with torch.no_grad():
        for x1, x2, y_true in batch_generator:
            x1 = x1.to(device)
            x2 = x2.to(device)
            y_pred = model(x1, x2)
            y_true2.append(y_true.cpu())
            y_pred2.append(y_pred.cpu() > 0.5)

    y_true = np.vstack(y_true2) > 0.5
    y_pred = np.vstack(y_pred2).reshape(y_true.shape)
    f1 = sklearn.metrics.f1_score(y_true=y_true, y_pred=y_pred)
    return f1
def calc_rank_metrics(model, device, val_samples, samples):
    """Compute the MRR (mean reciprocal rank) metric after the model has been trained."""
    model.eval()
    if len(val_samples) > 1000:
        val_samples = val_samples[:1000]

    phrase2samples_1 = dict()  # samples relevant to a given phrase
    phrase2samples_0 = dict()  # samples irrelevant to a given phrase
    for sample in samples:
        phrase1 = np.array_str(np.asarray(sample[0]))
        label = sample[2]
        if label == 1:
            if phrase1 in phrase2samples_1:
                phrase2samples_1[phrase1].add(sample)
            else:
                phrase2samples_1[phrase1] = set([sample])
        else:
            if phrase1 in phrase2samples_0:
                phrase2samples_0[phrase1].add(sample)
            else:
                phrase2samples_0[phrase1] = set([sample])

    rank_num = 0
    rank_denom = 0
    for val_sample in tqdm.tqdm(val_samples, desc='Calculating MRR', total=len(val_samples)):
        if val_sample[2] == 1 and not np.array_equal(val_sample[0], val_sample[1]):
            phrase1 = val_sample[0]

            # For phrase1, build the list of candidate phrases to compare against: exactly one
            # of them is relevant, the rest are random irrelevant ones.
            phrases2 = [val_sample[1]]

            # Add the ready-made negative examples for phrase1.
            phrase1_key = np.array_str(np.asarray(phrase1))
            if phrase1_key in phrase2samples_0:
                phrases2.extend(s[1] for s in phrase2samples_0[phrase1_key])

            # Add more random negative samples so that the candidate list
            # reaches the chosen fixed size.
            for s in sorted(samples, key=lambda _: random.random()):
                if s not in phrase2samples_1[phrase1_key] and np.array_str(np.asarray(s[1])) != phrase1_key:
                    phrases2.append(s[1])
                    if len(phrases2) >= 100:
                        break

            rels = []
            # TODO: run all candidates through the model as a single batch for efficiency.
            for isample, phrase2 in enumerate(phrases2):
                y = model(phrase1.unsqueeze(0).to(device), phrase2.unsqueeze(0).to(device))[0].item()
                rels.append((isample, y))

            # rels now holds (candidate index, synonymy score) pairs; index 0 corresponds to the
            # truly synonymous pair, the rest are irrelevant. Sort by descending score and find
            # the position of the relevant pair.
            rels = sorted(rels, key=lambda z: -z[1])
            rank = next(i for i, s in enumerate(rels, start=1) if s[0] == 0)
            rank_num += 1.0 / rank
            rank_denom += 1

    return rank_num / rank_denom
if __name__ == '__main__':
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    print('Loading BERT...')
    bert_tokenizer = transformers.BertTokenizer.from_pretrained('bert-base-multilingual-cased', do_lower_case=False)
    bert_model = transformers.BertModel.from_pretrained('bert-base-multilingual-cased')
    bert_model.to(device)
    bert_model.eval()

    computed_params = {'sent_emb_size': 768,
                       'bert_tokenizer': bert_tokenizer,
                       'bert_model': bert_model
                       }

    print('Loading data from "{}"...'.format(dataset_path))
    samples = load_samples(dataset_path, device, computed_params, 100000)
    print('{} samples'.format(len(samples)))
    print('max_len={}'.format(computed_params['max_len']))

    batch_size = 30
    epochs = 100

    train_samples, testval_samples = train_test_split(samples, test_size=0.2)
    test_samples, val_samples = train_test_split(testval_samples, test_size=0.5)
    train_generator = torch.utils.data.DataLoader(train_samples, batch_size=batch_size)
    test_generator = torch.utils.data.DataLoader(test_samples, batch_size=batch_size)

    model = SynonymyDetector(computed_params).to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=0.01)

    weights_path = "synonymy_model.pt"
    best_acc = 0.0
    nb_bad_epochs = 0

    for epoch in range(1, epochs+1):
        train_loss = train(model, device, train_generator, optimizer, epoch)
        print('Train loss={}'.format(train_loss))
        acc = test(model, device, test_generator)
        if acc > best_acc:
            print('NEW BEST ACC={}'.format(acc))
            best_acc = acc
            nb_bad_epochs = 0
            print('Saving model to "{}"...'.format(weights_path))
            torch.save(model.state_dict(), weights_path)
        else:
            nb_bad_epochs += 1
            print('No improvement over current best_acc={}'.format(best_acc))
            if nb_bad_epochs >= 10:
                print('Early stopping on epoch={} best_acc={}'.format(epoch, best_acc))
                break

    # Restore the best weights and compute the final validation metrics.
    model.load_state_dict(torch.load(weights_path))

    val_generator = torch.utils.data.DataLoader(val_samples, batch_size=batch_size)
    f1 = validate(model, device, val_generator)
    print('Final f1={}'.format(f1))

    # Compute the MRR metric for comparison with other approaches to the task.
    mrr = calc_rank_metrics(model, device, val_samples, samples)
    print('Final MRR={}'.format(mrr))
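    # Illustrative usage sketch (not part of the original training pipeline): scoring an
    # arbitrary pair of phrases with the trained detector. The two example strings are
    # hypothetical; any texts that bert_tokenizer can encode will work the same way.
    model.eval()
    with torch.no_grad():
        probe1 = torch.tensor(bert_tokenizer.encode('how old are you')).unsqueeze(0).to(device)
        probe2 = torch.tensor(bert_tokenizer.encode('what is your age')).unsqueeze(0).to(device)
        synonymy_score = model(probe1, probe2)[0].item()
    print('Example pair synonymy score={}'.format(synonymy_score))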