import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model

# ELMo from TF Hub, kept frozen (TF 1.x hub.Module API)
elmo_model = hub.Module("https://tfhub.dev/google/elmo/2", trainable=False)

def ElmoEmbedding(x):
    # "default" signature: a batch of untokenized sentences (strings) in,
    # 1024-d mean-pooled sentence embeddings out
    return elmo_model(tf.squeeze(tf.cast(x, tf.string)),
                      signature="default", as_dict=True)["default"]
    
sequence_input = Input(shape=(1,), dtype=tf.string)
embedded_sequences = Lambda(ElmoEmbedding, output_shape=(1024,))(sequence_input)
...
preds = Dense(2, activation="softmax")(l_dense)

model = Model(sequence_input, preds)
model.compile(loss="categorical_crossentropy",  # 2-way softmax output, so categorical (not binary) crossentropy
              optimizer="rmsprop",
              metrics=["acc"])

model.fit()  # placeholder; see the training sketch below
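
Because hub.Module is a TF 1.x graph-mode API, the underlying session needs its variables and lookup tables initialized before training. A minimal training sketch, assuming hypothetical in-memory data (the sentence array and one-hot labels below are illustrative, not from the original gist):

import numpy as np
from tensorflow.keras import backend as K

sess = tf.Session()
K.set_session(sess)
sess.run(tf.global_variables_initializer())  # hub.Module weights
sess.run(tf.tables_initializer())            # hub.Module vocab lookup tables

x_train = np.array([["a great movie"], ["a terrible movie"]], dtype=object)  # hypothetical
y_train = np.array([[0., 1.], [1., 0.]])  # hypothetical one-hot labels
model.fit(x_train, y_train, epochs=2, batch_size=2)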

# Scaled dot-product attention, as in "Attention Is All You Need"
def scaled_dot_product_attention(query, key, value, mask=None):
    matmul_qk = tf.matmul(query, key, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk by sqrt(depth) so the dot products don't grow with the key dimension
    depth = tf.cast(tf.shape(key)[-1], tf.float32)
    logits = matmul_qk / tf.math.sqrt(depth)

    # add a large negative value at masked (padding) positions
    # so their softmax weight is effectively zero
    if mask is not None:
        logits += (mask * -1e9)

    # softmax is normalized on the last axis (seq_len_k)
    attention_weights = tf.nn.softmax(logits, axis=-1)

    output = tf.matmul(attention_weights, value)  # (..., seq_len_q, depth_v)

    return output
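
A quick shape check for the attention function — a sketch assuming TF 2.x eager execution, with arbitrary random tensors (shapes are illustrative):

q = tf.random.normal((2, 5, 64))  # (batch, seq_len_q, depth)
k = tf.random.normal((2, 7, 64))  # (batch, seq_len_k, depth)
v = tf.random.normal((2, 7, 64))  # (batch, seq_len_k, depth_v)
out = scaled_dot_product_attention(q, k, v)
print(out.shape)  # (2, 5, 64): one depth_v-dimensional vector per query position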