Training a sentence autoencoder on top of DeepPavlov's BERT token embeddings
""" | |
Эксперимент с моделью несимметричного автоэнкодера с энкодером на базе претренированной модели BERT | |
""" | |
import io
import os
import random

import numpy as np
import sklearn.model_selection
from colorclass import Color
import terminaltables

import keras
from keras import layers
from keras.layers import RepeatVector, TimeDistributed
from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
# Path to the downloaded DeepPavlov model
# model_dir = '/media/inkoziev/corpora/EmbeddingModels/BERT_multilingual/model/multi_cased_L-12_H-768_A-12'
model_dir = '/home/inkoziev/corpora/EmbeddingModels/BERT_multilingual/model/rubert_cased_L-12_H-768_A-12_v1'
def vectorize(samples, tokenizer, token_dict, otoken2index, max_seq_len):
    nb_samples = len(samples)
    X_tok = np.zeros((nb_samples, max_seq_len), dtype=np.int32)
    X_seg = np.zeros((nb_samples, max_seq_len), dtype=np.int32)
    y = np.zeros((nb_samples, max_seq_len), dtype=np.int32)
    for isample, sample in enumerate(samples):
        indices, segments = tokenizer.encode(first=sample, max_len=max_seq_len)
        X_tok[isample, :] = indices
        X_seg[isample, :] = segments
        tokens2 = tokenizer.tokenize(sample)[1:]  # drop the leading [CLS] token
        for itoken, token in enumerate(tokens2):
            y[isample, itoken] = otoken2index[token]

    return X_tok, X_seg, y
def ngrams(s):
    """Character trigrams of the string s."""
    return [s1 + s2 + s3 for (s1, s2, s3) in zip(s, s[1:], s[2:])]
def jaccard(s1, s2):
    """Jaccard similarity of the character-trigram sets of two strings."""
    s1 = set(ngrams(s1))
    s2 = set(ngrams(s2))
    return float(len(s1 & s2)) / float(1e-8 + len(s1 | s2))
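# Illustrative behaviour of the trigram-Jaccard metric (example strings are arbitrary,
# not taken from the training corpus): identical strings score close to 1.0, strings
# sharing no trigrams score 0.0.
#   jaccard('мама мыла раму', 'мама мыла раму')  -> ~1.0
#   jaccard('мама мыла раму', 'кошка спит')      -> 0.0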
class VizualizeCallback(keras.callbacks.Callback):
    """
    After every training epoch, sample reconstructions from the current model
    to monitor the overall quality.
    """
    def __init__(self, model, test_samples, tokenizer, token_dict, otoken2index, max_seq_len):
        super(VizualizeCallback, self).__init__()
        self.model = model
        self.test_samples = test_samples
        self.otoken2index = otoken2index
        self.index2otoken = dict((i, t) for t, i in otoken2index.items())
        self.max_seq_len = max_seq_len
        self.X_tok, self.X_seg, self.y = vectorize(test_samples, tokenizer, token_dict, otoken2index, max_seq_len)
        self.val_history = []
        self.best_jaccard_score = 0.0
    def on_epoch_end(self, epoch, logs=None):
        pred_samples = []
        bs = 64
        for i in range(0, len(self.test_samples), bs):
            bs2 = min(bs, len(self.test_samples) - i)
            y_pred = self.model.predict_on_batch(x=(self.X_tok[i: i+bs2], self.X_seg[i: i+bs2]))
            for true_text, y in zip(self.test_samples[i: i+bs2], y_pred):
                y = np.argmax(y, axis=-1)
                pred_tokens = [self.index2otoken[i] for i in y]
                if '[SEP]' in pred_tokens:
                    pred_tokens = pred_tokens[:pred_tokens.index('[SEP]')]
                pred_text = ''.join(pred_tokens).replace('##', ' ').strip()
                pred_samples.append((true_text, pred_text))

        r_samples = random.sample(pred_samples, k=10)
        table = ['true_output predicted_output'.split()]
        for true_sample, pred_sample in r_samples:
            if true_sample == pred_sample:
                # the model's output is fully correct
                output2 = Color('{autogreen}' + pred_sample + '{/autogreen}')
            elif jaccard(true_sample, pred_sample) > 0.5:
                # the model's output partially matches the reference string
                output2 = Color('{autoyellow}' + pred_sample + '{/autoyellow}')
            else:
                # the model's output is wrong
                output2 = Color('{autored}' + pred_sample + '{/autored}')

            table.append((true_sample, output2))

        table = terminaltables.AsciiTable(table)
        print(table.table)

        success_rate = sum((true_sample == pred_sample) for true_sample, pred_sample in pred_samples) / float(len(self.test_samples))
        mean_jac = np.mean([jaccard(true_sample, pred_sample) for true_sample, pred_sample in pred_samples])
        self.val_history.append((success_rate, mean_jac))
        print('{}% samples are inferred without loss, mean jaccard score={}'.format(success_rate*100.0, mean_jac))

        if mean_jac > self.best_jaccard_score:
            self.best_jaccard_score = mean_jac
            with io.open(os.path.join(tmp_dir, 'pretrained_bert_autoencoder.output.txt'), 'w', encoding='utf-8') as wrt:
                # strip ANSI color codes before writing the table to a plain-text file
                s = table.table
                for k in '\x1b[91m \x1b[92m \x1b[93m \x1b[39m'.split():
                    s = s.replace(k, '')
                wrt.write(s + '\n')
if __name__ == '__main__':
    tmp_dir = '../tmp'

    token_dict = load_vocabulary(os.path.join(model_dir, 'vocab.txt'))
    tokenizer = Tokenizer(token_dict)

    all_sents = set()
    max_seq_len = 0
    all_tokens = set()
    print('Preprocessing the samples...')
    with io.open('../tmp/assemble_training_corpus_for_bert.txt', 'r', encoding='utf-8') as rdr:
        for line in rdr:
            s = line.strip()
            if s:
                tokens = tokenizer.tokenize(s)
                all_tokens.update(tokens)
                max_seq_len = max(max_seq_len, len(tokens))
                all_sents.add(s)

    all_sents = list(all_sents)

    # BEGIN DEBUG: cap the corpus size while debugging
    all_sents = all_sents[:200000]
    # END DEBUG

    print('max_seq_len={}'.format(max_seq_len))

    otoken2index = dict((t, i) for i, t in enumerate(all_tokens, start=1))
    otoken2index['[PAD]'] = 0
    model = load_trained_model_from_checkpoint(os.path.join(model_dir, 'bert_config.json'),
                                               os.path.join(model_dir, 'bert_model.ckpt'),
                                               seq_len=max_seq_len,
                                               output_layer_num=1)

    inputs = model.inputs
    bert_output_layer = model.output
    bert_token_dim = bert_output_layer.shape[2]
    print('bert_token_dim={}'.format(bert_token_dim))

    emb_size = 64  # bert_token_dim*2

    # Encoder: compress the sequence of BERT token embeddings into a single sentence vector.
    encoder = layers.LSTM(units=emb_size, return_sequences=False)(bert_output_layer)

    # Decoder: expand the sentence vector back into a sequence and predict the wordpiece at each position.
    decoder = RepeatVector(max_seq_len)(encoder)
    decoder = layers.LSTM(emb_size, return_sequences=True)(decoder)
    decoder = TimeDistributed(layers.Dense(units=len(otoken2index), activation='softmax'), name='output')(decoder)

    model = keras.Model(inputs, decoder, name='autoencoder')
    model.compile(loss='sparse_categorical_crossentropy', optimizer=keras.optimizers.Adam())
    model.summary()

    with io.open(os.path.join(tmp_dir, 'pretrained_bert_autoencoder.model_summary.txt'), 'w', encoding='utf-8') as wrt:
        model.summary(line_length=120, print_fn=lambda s: wrt.write(s + '\n'))
    train_samples, viz_samples = sklearn.model_selection.train_test_split(all_sents, test_size=1000, random_state=123456)

    print('Vectorization of {} samples...'.format(len(train_samples)))
    X_tok, X_seg, y = vectorize(train_samples, tokenizer, token_dict, otoken2index, max_seq_len)

    weights_path = '../tmp/pretrained_bert_autoencoder.weights'
    model_checkpoint = keras.callbacks.ModelCheckpoint(weights_path,
                                                       monitor='val_loss',
                                                       verbose=1,
                                                       save_best_only=True,
                                                       save_weights_only=True,
                                                       mode='auto')
    early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   patience=5, verbose=1, mode='auto', restore_best_weights=True)
    viz = VizualizeCallback(model, viz_samples, tokenizer, token_dict, otoken2index, max_seq_len)

    print('Start training...')
    hist = model.fit(x=[X_tok, X_seg], y=y,
                     epochs=50, validation_split=0.1,
                     callbacks=[viz, model_checkpoint, early_stopping],
                     batch_size=64,
                     verbose=2)

    with io.open(os.path.join(tmp_dir, 'pretrained_bert_autoencoder.learning_curve.tsv'), 'w', encoding='utf-8') as wrt:
        wrt.write('epoch\tacc_rate\tmean_jaccard\tval_loss\n')
        for epoch, ((acc_rate, mean_jacc), val_loss) in enumerate(zip(viz.val_history, hist.history['val_loss']), start=1):
            wrt.write('{}\t{}\t{}\t{}\n'.format(epoch, acc_rate, mean_jacc, val_loss))
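
    # A minimal sketch (not part of the training loop above) of how the trained
    # bottleneck could be reused as a sentence encoder: the sub-model shares the
    # BERT + LSTM weights of the autoencoder and maps token/segment ids to a
    # fixed-size sentence vector of length emb_size. Variable names reuse the
    # objects defined earlier in this script.
    sentence_encoder = keras.Model(inputs, encoder, name='sentence_encoder')
    viz_X_tok, viz_X_seg, _ = vectorize(viz_samples, tokenizer, token_dict, otoken2index, max_seq_len)
    sentence_vectors = sentence_encoder.predict([viz_X_tok, viz_X_seg], batch_size=64)
    print('sentence_vectors shape: {}'.format(sentence_vectors.shape))  # (len(viz_samples), emb_size)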