Skip to content

Instantly share code, notes, and snippets.

@Koziev
Created June 11, 2021 04:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Koziev/1ce6c9db4d73563ef3274e482c3fd8e0 to your computer and use it in GitHub Desktop.
Save Koziev/1ce6c9db4d73563ef3274e482c3fd8e0 to your computer and use it in GitHub Desktop.
Training the sentence autoencoder on the top of DeepPavlov's BERT token embeddings
"""
Experiment with an asymmetric autoencoder whose encoder is built on top of a pretrained BERT model.
"""
import io
import os
import random
import numpy as np
import sklearn.model_selection
from colorclass import Color
import terminaltables
from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths
from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
import keras
from keras import Model
from keras_bert import get_base_dict, get_model, gen_batch_inputs
from keras_bert import extract_embeddings
from keras_bert import Tokenizer
from keras import layers
from keras.layers import RepeatVector
from keras.layers import Bidirectional, TimeDistributed
# Path to the downloaded pretrained model (DeepPavlov's)
# model_dir = '/media/inkoziev/corpora/EmbeddingModels/BERT_multilingual/model/multi_cased_L-12_H-768_A-12'
model_dir = '/home/inkoziev/corpora/EmbeddingModels/BERT_multilingual/model/rubert_cased_L-12_H-768_A-12_v1'
def vectorize(samples, tokenizer, token_dict, otoken2index, max_seq_len):
    """Convert raw sentences into encoder inputs and decoder target indices.

    Args:
        samples: Sequence of sentence strings.
        tokenizer: Object exposing ``encode(first=..., max_len=...)``, which
            returns a ``(token_ids, segment_ids)`` pair of length ``max_len``,
            and ``tokenize(text)``, which returns the list of token strings.
        token_dict: Not used by this function; kept so that existing callers
            that pass it keep working.
        otoken2index: Mapping from an output token string to its integer
            index in the decoder's output vocabulary.
        max_seq_len: Width of every returned matrix.

    Returns:
        Tuple ``(X_tok, X_seg, y)`` of int32 arrays, each of shape
        ``(len(samples), max_seq_len)``. Positions of ``y`` past the last
        target token of a sample are left at 0.

    Raises:
        KeyError: If a target token is missing from ``otoken2index``.
    """
    nb_samples = len(samples)
    X_tok = np.zeros((nb_samples, max_seq_len), dtype=np.int32)
    X_seg = np.zeros((nb_samples, max_seq_len), dtype=np.int32)
    y = np.zeros((nb_samples, max_seq_len), dtype=np.int32)

    for isample, sample in enumerate(samples):
        indices, segments = tokenizer.encode(first=sample, max_len=max_seq_len)
        X_tok[isample, :] = indices
        X_seg[isample, :] = segments

        # The leading [CLS] token is excluded from the targets. The slice also
        # clips the targets to the row width: tokenize() does not truncate, so
        # a sentence longer than max_seq_len would otherwise index past the
        # end of the row.
        target_tokens = tokenizer.tokenize(sample)[1:max_seq_len + 1]
        y[isample, :len(target_tokens)] = [otoken2index[token] for token in target_tokens]

    return X_tok, X_seg, y
def ngrams(s):
    """Return every run of three consecutive items of ``s``, joined with ``+``.

    For a string this is the list of its character trigrams in order of
    appearance. An input shorter than three items yields an empty list.
    """
    return [s[pos] + s[pos + 1] + s[pos + 2] for pos in range(len(s) - 2)]
def jaccard(s1, s2):
    """Return the Jaccard similarity of the n-gram sets of ``s1`` and ``s2``.

    Both arguments are split with ``ngrams`` and compared as sets. The 1e-8
    term in the denominator avoids a division by zero when neither input
    yields any n-gram; the result is 0.0 in that case.
    """
    grams_a = set(ngrams(s1))
    grams_b = set(ngrams(s2))
    shared = len(grams_a & grams_b)
    total = len(grams_a | grams_b)
    return float(shared) / (1e-8 + total)
class VizualizeCallback(keras.callbacks.Callback):
    """
    After every training epoch, decode the held-out samples with the current
    model so that the overall reconstruction quality can be inspected.

    A table with up to 10 random (reference, prediction) pairs is printed, the
    exact-match rate and the mean Jaccard score are appended to
    ``val_history``, and the table is written to disk whenever the mean
    Jaccard score improves.
    """

    def __init__(self, model, test_samples, tokenizer, token_dict, otoken2index, max_seq_len,
                 output_dir=None):
        """
        Args:
            model: Model whose ``predict_on_batch`` is called each epoch.
            test_samples: List of reference sentences to reconstruct.
            tokenizer: Tokenizer forwarded to ``vectorize``.
            token_dict: Vocabulary forwarded to ``vectorize``.
            otoken2index: Mapping from output token to output index.
            max_seq_len: Sequence length forwarded to ``vectorize``.
            output_dir: Directory for the report file. When None, the
                module-level ``tmp_dir`` set by the script entry point is used.
        """
        super().__init__()
        self.model = model
        self.test_samples = test_samples
        self.otoken2index = otoken2index
        # Inverse mapping, used to turn predicted indices back into tokens.
        self.index2otoken = {index: token for token, index in otoken2index.items()}
        self.max_seq_len = max_seq_len
        self.X_tok, self.X_seg, self.y = vectorize(test_samples, tokenizer, token_dict,
                                                   otoken2index, max_seq_len)
        # One (exact_match_rate, mean_jaccard) pair per finished epoch.
        self.val_history = []
        self.best_jaccard_score = 0.0
        self.output_dir = output_dir

    @staticmethod
    def _detokenize(tokens):
        """Join tokens into text: pieces carrying the '##' continuation prefix
        are glued to the preceding piece, all other tokens are separated by a
        single space."""
        return ' '.join(tokens).replace(' ##', '').strip()

    def on_epoch_end(self, batch, logs=None):
        """Decode all test samples, print a sample table and record metrics.

        Args:
            batch: First positional argument supplied by the training loop
                (not used here).
            logs: Metric dictionary supplied by the training loop (not used).
        """
        pred_samples = []
        bs = 64
        for start in range(0, len(self.test_samples), bs):
            stop = start + bs
            y_pred = self.model.predict_on_batch(x=(self.X_tok[start:stop], self.X_seg[start:stop]))
            # Pair each prediction with the reference from the SAME slice;
            # zipping against the whole list would misalign every batch after
            # the first one.
            for true_text, token_probs in zip(self.test_samples[start:stop], y_pred):
                token_ids = np.argmax(token_probs, axis=-1)
                pred_tokens = [self.index2otoken[token_id] for token_id in token_ids]
                if '[SEP]' in pred_tokens:
                    # Everything from the first [SEP] onwards is discarded.
                    pred_tokens = pred_tokens[:pred_tokens.index('[SEP]')]
                pred_samples.append((true_text, self._detokenize(pred_tokens)))

        if not pred_samples:
            # Nothing to report; also avoids dividing by zero below.
            return

        # random.sample raises ValueError when asked for more items than exist.
        r_samples = random.sample(pred_samples, k=min(10, len(pred_samples)))
        rows = [['true_output', 'predicted_output']]
        for true_sample, pred_sample in r_samples:
            if true_sample == pred_sample:
                # the network output is entirely correct
                output2 = Color('{autogreen}' + pred_sample + '{/autogreen}')
            elif jaccard(true_sample, pred_sample) > 0.5:
                # the network output partially matches the expected string
                output2 = Color('{autoyellow}' + pred_sample + '{/autoyellow}')
            else:
                # the network output is wrong
                output2 = Color('{autored}' + pred_sample + '{/autored}')
            rows.append((true_sample, output2))

        table = terminaltables.AsciiTable(rows)
        print(table.table)

        success_rate = sum((true_sample == pred_sample)
                           for true_sample, pred_sample in pred_samples) / float(len(self.test_samples))
        mean_jac = np.mean([jaccard(true_sample, pred_sample)
                            for true_sample, pred_sample in pred_samples])
        self.val_history.append((success_rate, mean_jac))
        print('{}% samples are inferred without loss, mean jaccard score={}'.format(success_rate*100.0, mean_jac))

        if mean_jac > self.best_jaccard_score:
            self.best_jaccard_score = mean_jac
            # Fall back to the script-level ``tmp_dir`` when no directory was
            # passed to the constructor.
            out_dir = self.output_dir if self.output_dir is not None else tmp_dir
            with io.open(os.path.join(out_dir, 'pretrained_bert_autoencoder.output.txt'), 'w', encoding='utf-8') as wrt:
                s = table.table
                # Strip the terminal colour escape sequences before saving.
                for k in ('\x1b[91m', '\x1b[92m', '\x1b[93m', '\x1b[39m'):
                    s = s.replace(k, '')
                wrt.write(s + '\n')
if __name__ == '__main__':
    # Directory holding the training corpus and receiving every output file.
    # It stays a module-level name because VizualizeCallback.on_epoch_end
    # reads it.
    tmp_dir = '../tmp'

    token_dict = load_vocabulary(os.path.join(model_dir, 'vocab.txt'))
    tokenizer = Tokenizer(token_dict)

    # Collect the unique sentences, the output vocabulary and the longest
    # tokenized length in one pass over the corpus.
    all_sents = set()
    max_seq_len = 0
    all_tokens = set()
    print('Preprocessing the samples...')
    with io.open(os.path.join(tmp_dir, 'assemble_training_corpus_for_bert.txt'), 'r', encoding='utf-8') as rdr:
        for line in rdr:
            s = line.strip()
            if s:
                tokens = tokenizer.tokenize(s)
                all_tokens.update(tokens)
                max_seq_len = max(max_seq_len, len(tokens))
                all_sents.add(s)

    # Set iteration order is not stable between runs, so sort first and then
    # shuffle with a fixed seed: the subset kept below is then reproducible
    # without being biased towards alphabetically early sentences.
    all_sents = sorted(all_sents)
    random.Random(123456).shuffle(all_sents)

    # BEGIN DEBUG
    all_sents = all_sents[:200000]
    # END DEBUG

    print('max_seq_len={}'.format(max_seq_len))

    # Sorted so that the token -> index mapping is identical on every run;
    # otherwise saved output-layer weights could not be matched to it.
    # Index 0 is reserved for padding.
    otoken2index = {t: i for i, t in enumerate(sorted(all_tokens), start=1)}
    otoken2index['[PAD]'] = 0

    model = load_trained_model_from_checkpoint(os.path.join(model_dir, 'bert_config.json'),
                                               os.path.join(model_dir, 'bert_model.ckpt'),
                                               seq_len=max_seq_len,
                                               output_layer_num=1)
    inputs = model.inputs
    bert_output_layer = model.output
    bert_token_dim = bert_output_layer.shape[2]
    print('bert_token_dim={}'.format(bert_token_dim))

    # An LSTM squeezes the per-token outputs of the loaded model into a single
    # emb_size vector; the decoder repeats that vector max_seq_len times and
    # predicts one output-vocabulary token per position.
    emb_size = 64  # bert_token_dim*2
    encoder = layers.LSTM(units=emb_size, return_sequences=False)(bert_output_layer)
    decoder = RepeatVector(max_seq_len)(encoder)
    decoder = layers.LSTM(emb_size, return_sequences=True)(decoder)
    decoder = TimeDistributed(layers.Dense(units=len(otoken2index), activation='softmax'), name='output')(decoder)

    model = keras.Model(inputs, decoder, name="autodecoder")
    model.compile(loss='sparse_categorical_crossentropy', optimizer=keras.optimizers.Adam())
    model.summary()

    with io.open(os.path.join(tmp_dir, 'pretrained_bert_autoencoder.model_summary.txt'), 'w', encoding='utf-8') as wrt:
        model.summary(line_length=120, print_fn=lambda s: wrt.write(s+'\n'))

    # 1000 sentences are held out for the per-epoch visualization callback.
    train_samples, viz_samples = sklearn.model_selection.train_test_split(all_sents, test_size=1000, random_state=123456)

    # Report the number of samples that are actually vectorized here.
    print('Vectorization of {} samples...'.format(len(train_samples)))
    X_tok, X_seg, y = vectorize(train_samples, tokenizer, token_dict, otoken2index, max_seq_len)

    weights_path = os.path.join(tmp_dir, 'pretrained_bert_autoencoder.weights')
    model_checkpoint = keras.callbacks.ModelCheckpoint(weights_path,
                                                       monitor='val_loss',
                                                       verbose=1,
                                                       save_best_only=True,
                                                       save_weights_only=True,
                                                       mode='auto')
    early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss',
                                                   patience=5, verbose=1, mode='auto', restore_best_weights=True)
    viz = VizualizeCallback(model, viz_samples, tokenizer, token_dict, otoken2index, max_seq_len)

    print('Start training...')
    hist = model.fit(x=(X_tok, X_seg), y=y,
                     epochs=50, validation_split=0.1,
                     callbacks=[viz, model_checkpoint, early_stopping],
                     batch_size=64,
                     verbose=2)

    # One row per finished epoch: callback metrics next to the validation loss.
    with io.open(os.path.join(tmp_dir, 'pretrained_bert_autoencoder.learning_curve.tsv'), 'w', encoding='utf-8') as wrt:
        wrt.write('epoch\tacc_rate\tmean_jaccard\tval_loss\n')
        for epoch, ((acc_rate, mean_jacc), val_loss) in enumerate(zip(viz.val_history, hist.history['val_loss']), start=1):
            wrt.write('{}\t{}\t{}\t{}\n'.format(epoch, acc_rate, mean_jacc, val_loss))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment