import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model

elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=False)

def ElmoEmbedding(x):
    # The "default" signature takes a 1-D batch of raw sentence strings and
    # returns a dict; its "default" entry is a mean-pooled 1024-d embedding.
    return elmo_model(tf.squeeze(tf.cast(x, tf.string)),
                      signature="default", as_dict=True)["default"]
sequence_input = Input(shape=(1,), dtype=tf.string)
embedded_sequences = Lambda(ElmoEmbedding, output_shape=(1024,))(sequence_input)
...
preds = Dense(2, activation="softmax")(l_dense)
model = Model(sequence_input, preds)
model.compile(loss="categorical_crossentropy",  # matches the 2-way softmax output
              optimizer="rmsprop",
              metrics=["acc"])
model.fit()  # needs the data arguments and session setup; see the sketch below
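As written, model.fit() still needs its data, and because hub.Module adds variables and string-lookup tables to the TF 1.x graph, the session Keras runs against must be initialized before training. A minimal sketch, assuming the inputs are an (n, 1) NumPy array of raw sentence strings with one-hot labels; train_text and train_labels are hypothetical names:

import numpy as np

# Hypothetical toy data: raw sentences in, one-hot labels out.
train_text = np.array([["a small example sentence"], ["another example"]], dtype=object)
train_labels = np.array([[1, 0], [0, 1]])

# ELMo's variables and string-lookup tables must be initialized in the
# same session that Keras uses.
sess = tf.Session()
tf.keras.backend.set_session(sess)
sess.run(tf.global_variables_initializer())
sess.run(tf.tables_initializer())

model.fit(train_text, train_labels, epochs=1, batch_size=2)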
def scaled_dot_product_attention(query, key, value, mask=None):
    # attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
    matmul_qk = tf.matmul(query, key, transpose_b=True)

    # scale matmul_qk by the square root of the key depth
    depth = tf.cast(tf.shape(key)[-1], tf.float32)
    logits = matmul_qk / tf.math.sqrt(depth)

    # push masked (padding) positions toward -inf so that softmax assigns
    # them near-zero weight; the mask is 1 at positions to ignore
    if mask is not None:
        logits += (mask * -1e9)

    # softmax is normalized on the last axis (seq_len_k)
    attention_weights = tf.nn.softmax(logits, axis=-1)

    # weighted sum of the values
    output = tf.matmul(attention_weights, value)
    return output
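A quick shape check, sketched with arbitrary toy sizes (the seq token ids below are hypothetical): with query of shape (batch, seq_len_q, depth) and key/value of shape (batch, seq_len_k, depth), the output comes back as (batch, seq_len_q, depth). The padding mask is 1.0 at padded key positions and is shaped to broadcast over the (batch, seq_len_q, seq_len_k) logits.

q = tf.random.uniform((1, 4, 8))   # (batch, seq_len_q, depth)
k = tf.random.uniform((1, 6, 8))   # (batch, seq_len_k, depth)
v = tf.random.uniform((1, 6, 8))   # (batch, seq_len_k, depth)

# Padding mask from hypothetical token ids: 1.0 where the id is 0 (padding),
# reshaped to (batch, 1, seq_len_k) so it broadcasts over the logits.
seq = tf.constant([[5, 7, 9, 2, 0, 0]])
mask = tf.cast(tf.equal(seq, 0), tf.float32)[:, tf.newaxis, :]

out = scaled_dot_product_attention(q, k, v, mask=mask)
print(out.shape)  # (1, 4, 8)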