A synonymy detector for two phrases: a Siamese recurrent network in PyTorch
""" Тренировка модели детектора синонимичности двух фраз (сиамская рекуррентная сетка) на PyTorch """ | |
import io | |
import pandas as pd | |
import numpy as np | |
import itertools | |
import random | |
import tqdm | |
from sklearn.model_selection import train_test_split | |
import sklearn.metrics | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
import torch.utils.data | |
#dataset_path = r'C:\polygon\chatbot\data\synonymy_dataset.csv' | |
dataset_path = r'/home/inkoziev/polygon/chatbot/data/synonymy_dataset.csv' | |
PAD_WORD = ''
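
# Convert a token list to a fixed-length vector of word indices,
# right-padded with the PAD index up to max_len.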
def vectorize_words(words, word2index, max_len):
    return [word2index[w] for w in words] + [word2index[PAD_WORD]] * (max_len - len(words))
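
# The dataset stores raw token lists and vectorizes them lazily in __getitem__,
# so padding and index lookup happen on the fly when the DataLoader pulls a sample.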
class SynonymyDataset(torch.utils.data.Dataset):
    def __init__(self, computed_params):
        self.samples = []
        self.word2index = computed_params['word2index']
        self.max_len = computed_params['max_len']

    def append(self, sample):
        self.samples.append(sample)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        words1, words2, label = self.samples[idx]
        idx1 = vectorize_words(words1, self.word2index, self.max_len)
        idx2 = vectorize_words(words2, self.word2index, self.max_len)
        return torch.tensor(idx1, dtype=torch.long), torch.tensor(idx2, dtype=torch.long), torch.FloatTensor([label])
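
# Two passes over the CSV: the first pass builds the lexicon and finds the longest
# surviving phrase (pairs with a phrase longer than MAX_LEN tokens are skipped),
# the second pass fills the dataset with the (words1, words2, label) triples.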
def load_samples(input_path):
    df = pd.read_csv(input_path, encoding='utf-8', delimiter='\t', quoting=3)

    lexicon = set()
    MAX_LEN = 15
    max_len = 0
    for phrase1, phrase2, label in zip(df['premise'].values, df['question'].values, df['relevance'].values):
        words1 = phrase1.split()
        words2 = phrase2.split()
        if len(words1) > MAX_LEN or len(words2) > MAX_LEN:
            continue

        max_len = max(max_len, len(words1), len(words2))
        lexicon.update(words1)
        lexicon.update(words2)

    word2index = dict((w, i) for i, w in enumerate(lexicon, start=1))
    word2index[PAD_WORD] = 0
    nb_words = len(word2index)

    params = {'nb_words': nb_words, 'word2index': word2index, 'max_len': max_len}

    samples = SynonymyDataset(params)
    for phrase1, phrase2, label in zip(df['premise'].values, df['question'].values, df['relevance'].values):
        words1 = phrase1.split()
        words2 = phrase2.split()
        if len(words1) > MAX_LEN or len(words2) > MAX_LEN:
            continue

        samples.append((words1, words2, label))

    del df
    return samples, params
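
# Batch-level vectorization helper. Note that it is not used below:
# vectorization is done per sample in SynonymyDataset.__getitem__ instead.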
def vectorize_samples(batch_samples, computed_params):
    batch_size = len(batch_samples)
    x1 = np.zeros((batch_size, computed_params['max_len']), dtype=np.int32)
    x2 = np.zeros((batch_size, computed_params['max_len']), dtype=np.int32)
    y = np.zeros((batch_size, 1), dtype=bool)  # np.bool is deprecated; use the builtin bool

    for isample, sample in enumerate(batch_samples):
        x1[isample, :] = sample[0]
        x2[isample, :] = sample[1]
        y[isample] = sample[2]

    return torch.tensor(x1, dtype=torch.long), torch.tensor(x2, dtype=torch.long), torch.tensor(y, dtype=torch.float)
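
# Siamese architecture: both phrases pass through the same embedding + BiLSTM
# encoder, and the two phrase vectors u, v are matched with the feature set
# [u; v; |u - v|; u * v] (popularized by InferSent) before a single sigmoid unit.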
class SynonymyDetector(nn.Module):
    def __init__(self, computed_params):
        super(SynonymyDetector, self).__init__()
        self.max_len = computed_params['max_len']
        vocab_size = computed_params['nb_words']
        embedding_dim = 200
        self.embed = nn.Embedding(vocab_size, embedding_dim=embedding_dim, padding_idx=0)

        rnn_hidden_size = 200
        self.rnn = nn.LSTM(input_size=embedding_dim,
                           hidden_size=rnn_hidden_size,
                           bidirectional=True,
                           batch_first=True)

        # 4 matching features, each of bidirectional size rnn_hidden_size*2
        self.fc1 = nn.Linear(rnn_hidden_size*2*4, 1)

    def encode(self, x):
        path = self.embed(x)
        out, (hidden, cell) = self.rnn(path)
        # Take the BiLSTM output at the last timestep as the phrase vector.
        # With right-padded inputs this position can correspond to padding;
        # packing the sequences would avoid that.
        res = out[:, -1, :]
        return res

    def forward(self, x1, x2):
        path1 = self.encode(x1)
        path2 = self.encode(x2)
        merged = torch.cat((path1, path2, torch.abs(path1 - path2), path1 * path2), dim=-1)
        merged = self.fc1(merged)
        output = torch.sigmoid(merged)
        return output
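
# One training epoch: binary cross-entropy over the sigmoid output.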
def train(model, device, batch_generator, optimizer, epoch):
    model.train()
    for x1, x2, y in tqdm.tqdm(batch_generator, desc='epoch {}'.format(epoch), total=len(batch_generator)):
        x1 = x1.to(device)
        x2 = x2.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        output = model(x1, x2)
        loss = F.binary_cross_entropy(output, y)
        loss.backward()
        optimizer.step()
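
# Evaluate mean BCE loss and accuracy at a 0.5 decision threshold.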
def test(model, device, batch_generator):
    model.eval()
    test_loss = 0
    correct = 0
    nb_records = 0
    with torch.no_grad():
        for x1, x2, y in batch_generator:
            x1 = x1.to(device)
            x2 = x2.to(device)
            y = y.to(device)
            output = model(x1, x2)
            test_loss += F.binary_cross_entropy(output, y, reduction='sum').item()  # sum up batch loss
            pred = output > 0.5
            y2 = y.view_as(pred) > 0.5
            correct += pred.eq(y2).sum().item()
            nb_records += y.shape[0]

    test_loss /= nb_records
    acc = correct / float(nb_records)
    # {:.0f} would round the fractional accuracy to 0 or 1, so print 4 decimals
    print('\nTest set: Average loss: {:.4f}, Accuracy: {:.4f}\n'.format(test_loss, acc))
    return acc
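
# Compute F1 on the held-out validation split, again thresholding at 0.5.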
def validate(model, device, batch_generator):
    model.eval()
    y_true2 = []
    y_pred2 = []
    with torch.no_grad():
        for x1, x2, y_true in batch_generator:
            x1 = x1.to(device)
            x2 = x2.to(device)
            y_pred = model(x1, x2)
            y_true2.append(y_true.cpu())
            y_pred2.append(y_pred.cpu() > 0.5)

    y_true = np.vstack(y_true2)
    y_pred = np.vstack(y_pred2)
    f1 = sklearn.metrics.f1_score(y_true=y_true, y_pred=y_pred)
    return f1
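
# Training protocol: 80/10/10 train/test/validation split, early stopping on
# test accuracy with a patience of 10 epochs; the best weights are
# checkpointed and reloaded for the final validation F1.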
if __name__ == '__main__':
    samples, computed_params = load_samples(dataset_path)
    print('{} samples'.format(len(samples)))
    print('max_len={} words'.format(computed_params['max_len']))
    print('{} words in lexicon'.format(computed_params['nb_words']))

    batch_size = 150
    epochs = 100

    train_samples, testval_samples = train_test_split(samples, test_size=0.2)
    test_samples, val_samples = train_test_split(testval_samples, test_size=0.5)

    train_generator = torch.utils.data.DataLoader(train_samples, batch_size=batch_size)
    test_generator = torch.utils.data.DataLoader(test_samples, batch_size=batch_size)
    val_generator = torch.utils.data.DataLoader(val_samples, batch_size=batch_size)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    model = SynonymyDetector(computed_params).to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=0.1)

    weights_path = "synonymy_model.pt"
    best_acc = 0.0
    nb_bad_epochs = 0
    for epoch in range(1, epochs+1):
        train(model, device, train_generator, optimizer, epoch)
        acc = test(model, device, test_generator)
        if acc > best_acc:
            print('NEW BEST ACC={}'.format(acc))
            best_acc = acc
            nb_bad_epochs = 0
            print('Saving model to "{}"...'.format(weights_path))
            torch.save(model.state_dict(), weights_path)
        else:
            nb_bad_epochs += 1
            print('No improvement over current best_acc={}'.format(best_acc))
            if nb_bad_epochs >= 10:
                print('Early stopping on epoch={} best_acc={}'.format(epoch, best_acc))
                break

    model.load_state_dict(torch.load(weights_path))
    # validate expects a batch generator, not the raw list of samples
    f1 = validate(model, device, val_generator)
    print('validation f1={}'.format(f1))
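

# --- Usage sketch (not part of the original training loop) ---
# A minimal example of scoring a new phrase pair with the trained model.
# The helper name `score_pair` and the OOV handling (unknown words mapped to
# the PAD index) are assumptions for illustration; tokenization must match the
# whitespace splitting used during training.
def score_pair(model, computed_params, device, phrase1, phrase2):
    word2index = computed_params['word2index']
    max_len = computed_params['max_len']

    def encode_phrase(phrase):
        words = phrase.split()[:max_len]  # truncate to the training max_len
        idxs = [word2index.get(w, 0) for w in words]  # assumption: OOV -> PAD index
        idxs += [0] * (max_len - len(idxs))  # right-pad, as in vectorize_words
        return torch.tensor([idxs], dtype=torch.long, device=device)  # batch of 1

    model.eval()
    with torch.no_grad():
        y = model(encode_phrase(phrase1), encode_phrase(phrase2))
    return y.item()  # probability that the two phrases are synonymous

# Example (after training, inside the __main__ block above):
# p = score_pair(model, computed_params, device, phrase_a, phrase_b)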