A binary classifier in Keras on top of BERT embeddings for paraphrase detection
# coding: utf-8
import os
import gc
import logging

import numpy as np
import pandas as pd

import sklearn.preprocessing
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Input
from keras.layers.core import Dense
from keras.models import Model

from experiments.bert_embedder.bert_embedder2 import BERTEmbedder
data_folder = os.path.expanduser('~/polygon/chatbot/data')


def get_params_str(model_params):
    """Format model hyperparameters as a single printable string."""
    return ' '.join('{}={}'.format(k, v) for (k, v) in model_params.items())
def vectorize_data(embedder, sample_strings1, sample_strings2):
    """Embed both sentence lists with BERT and combine the vectors into pair features."""
    X1_data = embedder(sample_strings1)
    X2_data = embedder(sample_strings2)
    use_diff_and_product = True
    if use_diff_and_product:
        # Element-wise difference and product of the two sentence embeddings.
        X3 = np.subtract(X1_data, X2_data)
        X4 = np.multiply(X1_data, X2_data)
        X_data = np.hstack((X3, X4))
    else:
        # Alternative: plain concatenation of the two embeddings.
        X_data = np.hstack((X1_data, X2_data))
    return X_data
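
# Rough shape check (an assumption, not verified here: BERTEmbedder returns one
# pooled 768-dim vector per sentence, matching the H-768 in the rubert model name).
# With the difference+product combination, X_data then has 2*768 = 1536 columns,
# e.g. vectorize_data(embedder, ['кошка спит'], ['кот дремлет']).shape == (1, 1536).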
tmp_dir = '../../../tmp'
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s %(message)s')
lf = logging.FileHandler(os.path.join(tmp_dir, 'synonymy_detector_via_bert.log'), mode='w')
lf.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s %(message)s')
lf.setFormatter(formatter)
logging.getLogger('').addHandler(lf)
max_seq_len = 20

vectorized_path = os.path.join(tmp_dir, 'Xy_synonymy_detector_via_bert3.npz')
if not os.path.exists(vectorized_path):
    # model_path = os.path.expanduser('~/corpora/EmbeddingModels/BERT_multilingual/model/multi_cased_L-12_H-768_A-12')
    model_path = os.path.expanduser('~/corpora/EmbeddingModels/BERT_multilingual/model/rubert_cased_L-12_H-768_A-12_v1')
    embedder = BERTEmbedder(model_path=model_path, seq_len=max_seq_len)

    sample_strings1 = []
    sample_strings2 = []
    sample_ys = []

    df = pd.read_csv(os.path.join(data_folder, 'synonymy_dataset.csv'), encoding='utf-8', delimiter='\t', quoting=3)
    for i, r in df.iterrows():
        sample_strings1.append(r['premise'])
        sample_strings2.append(r['question'])
        sample_ys.append(r['relevance'])
        if len(sample_ys) > 100000:
            # Cap the number of samples to keep BERT vectorization tractable.
            break

    X_data = vectorize_data(embedder, sample_strings1, sample_strings2)
    sklearn.preprocessing.normalize(X_data, norm='l2', copy=False)
    y_data = np.asarray(sample_ys)

    # The BERT embedder is no longer needed; release its memory before training.
    del embedder
    gc.collect()

    X_train, X_test, y_train, y_test = train_test_split(X_data, y_data, test_size=0.2, random_state=123456789)
    print('X_train.shape={}'.format(X_train.shape))

    with open(vectorized_path, 'wb') as f:
        np.savez(f, X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test)

    # Stop after caching the features; re-run the script to train on them.
    exit(0)
else:
    # Load the previously prepared vectorized datasets.
    with open(vectorized_path, 'rb') as f:
        npzfile = np.load(f)
        X_train = npzfile['X_train']
        X_test = npzfile['X_test']
        y_train = npzfile['y_train']
        y_test = npzfile['y_test']
input_layer = Input(shape=(X_train.shape[1],), dtype='float32', name='input')
net = input_layer
net = Dense(units=1000, activation='relu')(net)
#net = Dense(units=10, activation='relu')(net)
# Sigmoid output is required for the binary cross-entropy loss below.
net = Dense(units=1, activation='sigmoid')(net)

model = Model(inputs=[input_layer], outputs=net)
model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['accuracy'])
model.summary()
weights_path = '/home/inkoziev/polygon/chatbot/tmp/synonymy_detector_via_bert3.model'
# NB: on Keras >= 2.3 this metric is named 'val_accuracy' instead of 'val_acc'.
model_checkpoint = ModelCheckpoint(weights_path, monitor='val_acc',
                                   verbose=1,
                                   save_best_only=True,
                                   mode='auto')
early_stopping = EarlyStopping(monitor='val_loss',
                               patience=10,
                               verbose=1,
                               mode='auto')

model.fit(x=X_train, y=y_train, validation_data=(X_test, y_test),
          batch_size=50, epochs=100, verbose=2,
          callbacks=[model_checkpoint, early_stopping])
# Restore the best checkpoint and report ROC AUC on the held-out set.
model.load_weights(weights_path)
y_pred = model.predict(X_test)
print('auc score={}'.format(roc_auc_score(y_true=y_test, y_score=y_pred)))
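
# Below is a minimal inference sketch, not part of the original training run: it
# assumes the same BERTEmbedder and rubert model_path as above are available, and
# scores one new sentence pair with the trained classifier. Flip the guard to True
# to try it after training.
if False:
    model_path = os.path.expanduser('~/corpora/EmbeddingModels/BERT_multilingual/model/rubert_cased_L-12_H-768_A-12_v1')
    embedder = BERTEmbedder(model_path=model_path, seq_len=max_seq_len)
    # Build pair features exactly as during training: diff+product, then L2 norm.
    X_pair = vectorize_data(embedder, ['кошка спит на диване'], ['кот дремлет на софе'])
    sklearn.preprocessing.normalize(X_pair, norm='l2', copy=False)
    p = model.predict(X_pair)[0][0]
    print('paraphrase probability={}'.format(p))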