@bryanlimy · Last active October 15, 2019
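This gist implements scaled dot-product attention from "Attention Is All You Need" (Vaswani et al., 2017): Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where d_k is the depth of the key vectors. The mask marks padding positions with 1; multiplying it by -1e9 and adding it to the logits drives those positions toward negative infinity, so they receive near-zero weight after the softmax.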
import tensorflow as tf

def scaled_dot_product_attention(query, key, value, mask):
    """Compute softmax(QK^T / sqrt(d_k)) V."""
    # Dot product of each query with every key.
    matmul_qk = tf.matmul(query, key, transpose_b=True)
    # Scale by sqrt of the key depth to keep the logits well-conditioned.
    depth = tf.cast(tf.shape(key)[-1], tf.float32)
    logits = matmul_qk / tf.math.sqrt(depth)
    # Add the mask to zero out attention paid to padding tokens.
    if mask is not None:
        logits += (mask * -1e9)
    # Softmax over the last axis, i.e. over the key positions.
    attention_weights = tf.nn.softmax(logits, axis=-1)
    return tf.matmul(attention_weights, value)
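As a quick sanity check (a minimal sketch; the toy tensors below are illustrative values, not part of the gist), a query that aligns with exactly one key should return roughly that key's value row:

temp_k = tf.constant([[10, 0, 0],
                      [0, 10, 0],
                      [0, 0, 10],
                      [0, 0, 10]], dtype=tf.float32)  # (4, 3)
temp_v = tf.constant([[1, 0],
                      [10, 0],
                      [100, 5],
                      [1000, 6]], dtype=tf.float32)   # (4, 2)
temp_q = tf.constant([[0, 10, 0]], dtype=tf.float32)  # (1, 3)

out = scaled_dot_product_attention(temp_q, temp_k, temp_v, mask=None)
print(out)  # approximately [[10., 0.]]: the query attends almost entirely to the second key

The softmax concentrates nearly all of the weight on the second key position, so the output is close to the second value row.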