@edumunozsala, created October 26, 2020
Scaled dot product attention for Transformer
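This function computes the scaled dot-product attention introduced in "Attention Is All You Need" (Vaswani et al., 2017): Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V, where d_k is the dimension of the keys. Scaling by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax into regions with very small gradients.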
import tensorflow as tf

def scaled_dot_product_attention(queries, keys, values, mask):
    # Calculate the dot product, QK_transpose
    product = tf.matmul(queries, keys, transpose_b=True)
    # Get the scale factor
    keys_dim = tf.cast(tf.shape(keys)[-1], tf.float32)
    # Apply the scale factor to the dot product
    scaled_product = product / tf.math.sqrt(keys_dim)
    # Apply masking when it is required
    if mask is not None:
        scaled_product += (mask * -1e9)
    # Softmax over the last axis (the keys), then dot product with the values
    attention = tf.matmul(tf.nn.softmax(scaled_product, axis=-1), values)
    return attention
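A minimal usage sketch, given the definition above and assuming TensorFlow 2.x. The shapes, the random inputs, and the look-ahead mask construction are illustrative additions, not part of the original gist:

# Illustrative shapes: batch of 2, sequence length 4, depth 8
queries = tf.random.normal((2, 4, 8))
keys = tf.random.normal((2, 4, 8))
values = tf.random.normal((2, 4, 8))

# Unmasked attention: every position may attend to every other position
output = scaled_dot_product_attention(queries, keys, values, mask=None)
print(output.shape)  # (2, 4, 8)

# Look-ahead mask: 1s strictly above the diagonal block attention to future
# positions, since the function adds mask * -1e9 before the softmax
look_ahead_mask = 1 - tf.linalg.band_part(tf.ones((4, 4)), -1, 0)
masked_output = scaled_dot_product_attention(queries, keys, values, mask=look_ahead_mask)

The (4, 4) mask broadcasts across the batch dimension of the (2, 4, 4) score matrix, so one mask serves every example in the batch.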
@edumunozsala (Author) commented:
tf is the conventional alias for TensorFlow; it is usually imported in Python as:

import tensorflow as tf
