BiLSTM Architecture
#=================================
# Sentence-level QE -- BiRNN model
#=================================
#
## Inputs:
# 1. Sentences in the src language (shape: (mini_batch_size, line_words))
# 2. Parallel machine-translated sentences (shape: (mini_batch_size, line_words))
#
## Output:
# 1. Sentence quality scores (shape: (mini_batch_size,))
#
## Summary of the model:
# The sentence-level representations of both the SRC and the MT are created using two bi-directional RNNs.
# Those representations are then concatenated at the word level, and the sentence representation is a weighted sum of its words.
# We apply the following attention function, computing a normalized weight for each hidden state h_j of an RNN:
#   alpha_j = exp(W_a*h_j) / sum_k exp(W_a*h_k)
# The resulting sentence vector is thus a weighted sum of word vectors:
#   v = sum_j alpha_j*h_j
# Sentence vectors are then directly used for making classification decisions.

# Keras 2.x imports used below. GRU and LSTM are imported so that
# eval(params['ENCODER_RNN_TYPE']) can resolve either recurrent cell by name.
from keras.layers import (Input, Embedding, Bidirectional, GRU, LSTM, Dense,
                          Lambda, Permute, RepeatVector, concatenate, multiply)
from keras.models import Model
from keras.regularizers import l2
from keras import backend as K

# Regularize, NonMasking, reduce_sum and mask_aware_mean_output_shape are
# helpers from the surrounding project (e.g. the deepQuest / keras_wrapper
# utilities) and are assumed to be importable in this module.
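
# A minimal NumPy sketch (not part of the model) of the attention described
# above, for a single sentence. 'h' and 'W_a' are hypothetical placeholders
# standing in for the encoder hidden states and the learned attention projection.
def _attention_weighted_sum_sketch(h, W_a):
    """h: array of shape (time_steps, hidden_dim); W_a: array of shape (hidden_dim,)."""
    import numpy as np
    scores = h.dot(W_a)                        # e_j = W_a * h_j
    alpha = np.exp(scores - scores.max())
    alpha = alpha / alpha.sum()                # alpha_j = exp(e_j) / sum_k exp(e_k)
    return (alpha[:, None] * h).sum(axis=0)    # v = sum_j alpha_j * h_j
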
def EncSent(self, params):
    # Source sentence: a padded sequence of word indices.
    src_words = Input(name=self.ids_inputs[0],
                      batch_shape=tuple([None, params['MAX_INPUT_TEXT_LEN']]), dtype='int32')
    src_embedding = Embedding(params['INPUT_VOCABULARY_SIZE'], params['TARGET_TEXT_EMBEDDING_SIZE'],
                              name='src_word_embedding',
                              embeddings_regularizer=l2(params['WEIGHT_DECAY']),
                              embeddings_initializer=params['INIT_FUNCTION'],
                              trainable=self.trainable,
                              mask_zero=True)(src_words)
    src_embedding = Regularize(src_embedding, params, trainable=self.trainable, name='src_state')
    # Bidirectional encoder over the source; forward and backward states are
    # concatenated, so each word annotation has size 2 * ENCODER_HIDDEN_SIZE.
    src_annotations = Bidirectional(
        eval(params['ENCODER_RNN_TYPE'])(params['ENCODER_HIDDEN_SIZE'],
                                         kernel_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                                         recurrent_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                                         bias_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                                         dropout=params['RECURRENT_INPUT_DROPOUT_P'],
                                         recurrent_dropout=params['RECURRENT_DROPOUT_P'],
                                         kernel_initializer=params['INIT_FUNCTION'],
                                         recurrent_initializer=params['INNER_INIT'],
                                         return_sequences=True,
                                         trainable=self.trainable),
        name='src_bidirectional_encoder_' + params['ENCODER_RNN_TYPE'],
        merge_mode='concat')(src_embedding)
    # Machine-translated (MT) sentence: same structure as the source side.
    trg_words = Input(name=self.ids_inputs[1],
                      batch_shape=tuple([None, params['MAX_INPUT_TEXT_LEN']]), dtype='int32')
    trg_embedding = Embedding(params['OUTPUT_VOCABULARY_SIZE'], params['TARGET_TEXT_EMBEDDING_SIZE'],
                              name='target_word_embedding',
                              embeddings_regularizer=l2(params['WEIGHT_DECAY']),
                              embeddings_initializer=params['INIT_FUNCTION'],
                              trainable=self.trainable,
                              mask_zero=True)(trg_words)
    trg_embedding = Regularize(trg_embedding, params, trainable=self.trainable, name='state')
    trg_annotations = Bidirectional(
        eval(params['ENCODER_RNN_TYPE'])(params['ENCODER_HIDDEN_SIZE'],
                                         kernel_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                                         recurrent_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                                         bias_regularizer=l2(params['RECURRENT_WEIGHT_DECAY']),
                                         dropout=params['RECURRENT_INPUT_DROPOUT_P'],
                                         recurrent_dropout=params['RECURRENT_DROPOUT_P'],
                                         kernel_initializer=params['INIT_FUNCTION'],
                                         recurrent_initializer=params['INNER_INIT'],
                                         return_sequences=True,
                                         trainable=self.trainable),
        name='bidirectional_encoder_' + params['ENCODER_RNN_TYPE'],
        merge_mode='concat')(trg_embedding)
    # Concatenate SRC and MT annotations at the word level and drop the mask
    # (NonMasking) so the attention block can operate on plain tensors.
    annotations = concatenate([src_annotations, trg_annotations], name='anot_seq_concat')
    annotations = NonMasking()(annotations)
    # Apply attention over words at the sentence level, collapsing the
    # sequence of word annotations into a single sentence vector.
    annotations = attention_3d_block(annotations, params, 'sent')
    # Single sentence-level quality score (sigmoid by default).
    out_activation = params.get('OUT_ACTIVATION', 'sigmoid')
    qe_sent = Dense(1, activation=out_activation, name=self.ids_outputs[0])(annotations)
    self.model = Model(inputs=[src_words, trg_words],
                       outputs=[qe_sent])
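
# A minimal training sketch (an assumption, not part of the original file):
# once EncSent has built self.model, it can be compiled and fitted like any
# Keras model. The optimizer, loss and batch size below are illustrative
# choices for regressing sentence-level quality scores in [0, 1]; the
# surrounding project may configure training differently.
def _train_sketch(qe_model, src_batch, trg_batch, scores):
    """qe_model: the Keras Model built above; src_batch/trg_batch: int32 arrays
    of shape (n, MAX_INPUT_TEXT_LEN); scores: float array of shape (n,)."""
    qe_model.compile(optimizer='adam', loss='mse', metrics=['mae'])
    qe_model.fit([src_batch, trg_batch], scores, batch_size=64, epochs=1)
    return qe_model.predict([src_batch, trg_batch])
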
def attention_3d_block(inputs, params, ext):
    '''
    Simple attention: one weight per time step, as in
    https://github.com/philipperemy/keras-attention-mechanism
    '''
    # inputs.shape = (batch_size, time_steps, input_dim)
    time_steps = K.int_shape(inputs)[1]
    input_dim = K.int_shape(inputs)[2]
    # Compute a softmax over time steps for each feature dimension...
    a = Permute((2, 1))(inputs)
    a = Dense(time_steps, activation='softmax', name='soft_att' + ext)(a)
    # ...then average across feature dimensions to get one weight per time step.
    a = Lambda(lambda x: K.mean(x, axis=1), name='dim_reduction' + ext, output_shape=(time_steps,))(a)
    # Broadcast the weights back to (batch_size, time_steps, input_dim).
    a = RepeatVector(input_dim)(a)
    a_probs = Permute((2, 1), name='attention_vec' + ext)(a)
    # Weight the annotations by their attention scores.
    output_attention_mul = multiply([inputs, a_probs], name='attention_mul' + ext)
    # reduce_sum and mask_aware_mean_output_shape are project helpers (not shown
    # here), expected to sum over the time axis and report the resulting
    # (batch_size, input_dim) output shape.
    sum_over_time = Lambda(reduce_sum, mask_aware_mean_output_shape)
    output = sum_over_time(output_attention_mul)
    return output
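
# An illustrative params dictionary: the values below are assumptions, chosen
# only to show which keys EncSent reads; the real configuration lives in the
# project's config file, and Regularize may consume additional keys.
EXAMPLE_PARAMS = {
    'MAX_INPUT_TEXT_LEN': 70,
    'INPUT_VOCABULARY_SIZE': 30000,
    'OUTPUT_VOCABULARY_SIZE': 30000,
    'TARGET_TEXT_EMBEDDING_SIZE': 300,
    'ENCODER_RNN_TYPE': 'GRU',          # resolved with eval(); 'LSTM' also works
    'ENCODER_HIDDEN_SIZE': 50,
    'WEIGHT_DECAY': 1e-4,
    'RECURRENT_WEIGHT_DECAY': 1e-4,
    'RECURRENT_INPUT_DROPOUT_P': 0.0,
    'RECURRENT_DROPOUT_P': 0.0,
    'INIT_FUNCTION': 'glorot_uniform',
    'INNER_INIT': 'orthogonal',
    'OUT_ACTIVATION': 'sigmoid',
}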