Created
June 7, 2020 10:30
-
-
Save jcrousse/79ed2feae1b6f8a8256b1d41a4e2854a to your computer and use it in GitHub Desktop.
Sequence-of-sequences (two-level) model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_learned_scores(**kwargs): | |
""" | |
scores each sentence, then multiply by score before next sequence layer. | |
:Keyword Arguments: | |
* sent_len (int) Sentence length | |
* embedding_size (int) word embedding length | |
* seq_len (int) length of overall sequence, equal to number of sentences x number of words per sentence | |
* pre_embedded (bool) True if input is already vectors of word embeddings, false if tokens to be embedded | |
* concat_outputs (bool) True for a model with two similar outputs (2 level sequence model), False for | |
a single output attention model (weighted average of sentences) | |
:param : (int) | |
""" | |
sent_len = kwargs.get('sent_len') | |
embed_size = kwargs.get('embedding_size') | |
seq_len = kwargs.get("seq_len") | |
pre_embedded = kwargs.get("pre_embedded", False) | |
assert seq_len % sent_len == 0, "sequence length must be a multiple of sentence length" | |
sent_per_obs = seq_len // sent_len | |
concat_outputs = kwargs.get("concat_outputs", False) | |
lstm_units_1 = kwargs.get('lstm_units_1', 16) | |
lstm_units_2 = kwargs.get('lstm_cells', 16) | |
if pre_embedded: | |
inputs = tf.keras.layers.Input(shape=(None, ), name="input") | |
embedded = tf.reshape(inputs, (-1, 1200, 768)) | |
else: | |
inputs = tf.keras.layers.Input(shape=(None,), name="input") | |
embedded = tf.keras.layers.Embedding(kwargs.get('vocab_size'), embed_size)(inputs) | |
reshaped = tf.reshape(embedded, (-1, sent_len, embed_size)) | |
lstm_level1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(lstm_units_1))(reshaped) | |
x = tf.keras.layers.Dense(1, activation=None)(lstm_level1) | |
logits = tf.reshape(x, (-1, sent_per_obs)) | |
score = tf.keras.layers.Softmax(name="score")(logits) | |
weighted = tf.multiply(lstm_level1, tf.reshape(score, (-1, 1))) | |
reshaped_level2 = tf.reshape(weighted, (-1, sent_per_obs, lstm_units_1*2)) | |
w_average = tf.keras.layers.GlobalAveragePooling1D()(reshaped_level2) | |
lstm_level2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(lstm_units_2))(reshaped_level2) | |
if concat_outputs: | |
classifier = tf.keras.layers.Dense(1)(lstm_level2) | |
classifier2 = tf.keras.layers.Dense(1)(w_average) | |
outputs = tf.keras.layers.concatenate([classifier, classifier2], name="output") | |
else: | |
classifier = tf.keras.layers.Dense(1, name="output")(lstm_level2) | |
classifier2 = tf.keras.layers.Dense(1, name="output_2")(w_average) | |
outputs = [classifier, classifier2] | |
model = tf.keras.Model(inputs=inputs, outputs=outputs) | |
return model |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment