Created
March 13, 2020 11:00
-
-
Save Koziev/52306047949e07f9fa682c5194d72f4c to your computer and use it in GitHub Desktop.
Бинарный классификатор на Keras с BERT для определения перефразировок
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8
"""Binary classifier on Keras with BERT embeddings for paraphrase detection."""
import gc
import logging
import os

import numpy as np
import pandas as pd
import sklearn.preprocessing
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split

import keras.callbacks
import keras.regularizers
from keras import backend as K
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Conv1D, GlobalMaxPooling1D, GlobalAveragePooling1D, AveragePooling1D
from keras.layers import Flatten
from keras.layers import Input
from keras.layers import Lambda
from keras.layers import recurrent
from keras.layers.core import Dense
from keras.layers.merge import concatenate, add, multiply
from keras.layers.normalization import BatchNormalization
from keras.layers.wrappers import Bidirectional
from keras.models import Model
from keras.models import model_from_json

from experiments.bert_embedder.bert_embedder2 import BERTEmbedder

# Root folder with the chatbot training datasets (synonymy_dataset.csv lives here).
data_folder = os.path.expanduser('~/polygon/chatbot/data')
def get_params_str(model_params):
    """Render a params dict as a single 'key=value key=value ...' string for logging."""
    return ' '.join('{}={}'.format(k, v) for k, v in model_params.items())
def vectorize_data(embedder, sample_strings1, sample_strings2, use_diff_mul_features=True):
    """Build a pairwise feature matrix for two parallel lists of sentences.

    :param embedder: callable mapping a list of strings to a 2D array of
                     sentence embeddings, one row per string.
    :param sample_strings1: first sentence of each pair.
    :param sample_strings2: second sentence of each pair.
    :param use_diff_mul_features: if True (default, matches the historical
        behavior of the dead `if True:` branch this replaces), the features
        are the elementwise difference and product of the two embeddings —
        these emphasize (dis)similarity; otherwise the raw embeddings are
        simply concatenated.
    :return: 2D numpy array with one feature row per sentence pair.
    """
    emb1 = embedder(sample_strings1)
    emb2 = embedder(sample_strings2)
    if use_diff_mul_features:
        return np.hstack((np.subtract(emb1, emb2), np.multiply(emb1, emb2)))
    return np.hstack((emb1, emb2))
# ---------------------------------------------------------------------------
# Script body: vectorize the paraphrase dataset with BERT on the first run
# (then exit), and on subsequent runs train + evaluate a dense binary
# classifier on the cached vectors.
# ---------------------------------------------------------------------------
tmp_dir = '../../../tmp'

# Log to stderr at DEBUG and to a file at INFO.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(message)s')
lf = logging.FileHandler(os.path.join(tmp_dir, 'synonymy_detector_via_bert.log'), mode='w')
lf.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s %(message)s')
lf.setFormatter(formatter)
logging.getLogger('').addHandler(lf)

max_seq_len = 20

# Cached vectorized dataset. BERT vectorization is expensive, so it is done
# once and stored. NOTE(review): np.savez writes npz content; the .npy name
# is kept for backward compatibility with existing cache files.
vectorized_path = os.path.join(tmp_dir, 'Xy_synonymy_detector_via_bert3.npy')
if not os.path.exists(vectorized_path):
    # model_path = os.path.expanduser('~/corpora/EmbeddingModels/BERT_multilingual/model/multi_cased_L-12_H-768_A-12')
    model_path = os.path.expanduser('~/corpora/EmbeddingModels/BERT_multilingual/model/rubert_cased_L-12_H-768_A-12_v1')
    embedder = BERTEmbedder(model_path=model_path, seq_len=max_seq_len)

    sample_strings1 = []
    sample_strings2 = []
    sample_ys = []
    df = pd.read_csv(os.path.join(data_folder, 'synonymy_dataset.csv'), encoding='utf-8', delimiter='\t', quoting=3)
    for i, r in df.iterrows():
        sample_strings1.append(r['premise'])
        sample_strings2.append(r['question'])
        sample_ys.append(r['relevance'])
        if len(sample_ys) > 100000:
            # cap the dataset size to keep vectorization time and RAM bounded
            break

    X_data = vectorize_data(embedder, sample_strings1, sample_strings2)
    sklearn.preprocessing.normalize(X_data, norm='l2', copy=False)
    y_data = np.asarray(sample_ys)

    # Release the BERT model before train_test_split copies the arrays.
    del embedder
    gc.collect()

    X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=0.2, random_state=123456789)
    print('X_train.shape={}'.format(X_train.shape))
    with open(vectorized_path, 'wb') as f:
        np.savez(f, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)
    exit(0)
else:
    # Load the previously prepared vectorized datasets.
    with open(vectorized_path, 'rb') as f:
        npzfile = np.load(f)
        X_train = npzfile['X_train']
        X_test = npzfile['X_test']
        y_train = npzfile['y_train']
        y_test = npzfile['y_test']

# Simple MLP head: one hidden ReLU layer, single sigmoid output.
# FIX: the output layer previously had no activation (raw logit) while the
# loss is binary_crossentropy, which in Keras expects probabilities by
# default — add the sigmoid so the loss is computed correctly.
input_layer = Input(shape=(X_train.shape[1],), dtype='float32', name='input')
net = Dense(units=1000, activation='relu')(input_layer)
net = Dense(units=1, activation='sigmoid')(net)

model = Model(inputs=[input_layer], outputs=net)
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.summary()

# FIX: was a hard-coded absolute user path (/home/inkoziev/...); keep the
# checkpoint next to the other artifacts in tmp_dir instead.
weights_path = os.path.join(tmp_dir, 'synonymy_detector_via_bert3.model')
model_checkpoint = ModelCheckpoint(weights_path, monitor='val_acc',
                                   verbose=1,
                                   save_best_only=True,
                                   mode='auto')
early_stopping = EarlyStopping(monitor='val_loss',
                               patience=10,
                               verbose=1,
                               mode='auto')
model.fit(x=X_train, y=y_train, validation_data=(X_test, y_test),
          batch_size=50, epochs=100, verbose=2,
          callbacks=[model_checkpoint, early_stopping])

# Restore the best checkpoint and report ROC AUC on the held-out split.
model.load_weights(weights_path)
y_pred = model.predict(X_test)
print('auc score={}'.format(roc_auc_score(y_true=y_test, y_score=y_pred)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment