
@edumunozsala
Created October 26, 2020 18:48
Multi-head attention for Transformer
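The class below assumes TensorFlow 2.x (import tensorflow as tf, from tensorflow.keras import layers) and a scaled_dot_product_attention helper that is not defined in this gist. A minimal sketch of that helper, following the standard softmax(Q K^T / sqrt(d_k)) V formulation; the additive mask convention (1 marks positions to suppress) is an assumption, not taken from the original snippet:

import tensorflow as tf
from tensorflow.keras import layers

def scaled_dot_product_attention(queries, keys, values, mask):
    # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    product = tf.matmul(queries, keys, transpose_b=True)        # (..., seq_len_q, seq_len_k)
    keys_dim = tf.cast(tf.shape(keys)[-1], tf.float32)
    scaled_product = product / tf.math.sqrt(keys_dim)
    if mask is not None:
        # Push masked positions towards -inf so softmax gives them ~0 weight
        scaled_product += mask * -1e9
    attention_weights = tf.nn.softmax(scaled_product, axis=-1)  # (..., seq_len_q, seq_len_k)
    return tf.matmul(attention_weights, values)                 # (..., seq_len_q, d_head)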
class MultiHeadAttention(layers.Layer):
    def __init__(self, n_heads):
        super(MultiHeadAttention, self).__init__()
        self.n_heads = n_heads

    def build(self, input_shape):
        self.d_model = input_shape[-1]
        assert self.d_model % self.n_heads == 0
        # Dimension of every head or projection
        self.d_head = self.d_model // self.n_heads
        # Weight matrices for Q, K and V
        self.query_lin = layers.Dense(units=self.d_model)
        self.key_lin = layers.Dense(units=self.d_model)
        self.value_lin = layers.Dense(units=self.d_model)
        # Weight matrix W0 for the output of the multi-head attention
        self.final_lin = layers.Dense(units=self.d_model)

    def split_proj(self, inputs, batch_size):  # inputs: (batch_size, seq_length, d_model)
        # Target shape of the projections
        shape = (batch_size,
                 -1,
                 self.n_heads,
                 self.d_head)
        # Split the input vectors into the heads
        split_inputs = tf.reshape(inputs, shape=shape)  # (batch_size, seq_length, n_heads, d_head)
        return tf.transpose(split_inputs, perm=[0, 2, 1, 3])  # (batch_size, n_heads, seq_length, d_head)

    def call(self, queries, keys, values, mask):
        # Get the batch size
        batch_size = tf.shape(queries)[0]
        # Project the inputs to the Query, Key and Value matrices
        queries = self.query_lin(queries)
        keys = self.key_lin(keys)
        values = self.value_lin(values)
        # Split Q, K and V between the heads or projections
        queries = self.split_proj(queries, batch_size)
        keys = self.split_proj(keys, batch_size)
        values = self.split_proj(values, batch_size)
        # Apply the scaled dot-product attention to every head
        attention = scaled_dot_product_attention(queries, keys, values, mask)
        # Move the heads back next to the sequence dimension
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])  # (batch_size, seq_length, n_heads, d_head)
        # Concatenate the h heads or projections
        concat_attention = tf.reshape(attention,
                                      shape=(batch_size, -1, self.d_model))
        # Apply W0 to get the output of the multi-head attention
        outputs = self.final_lin(concat_attention)
        return outputs
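A quick way to sanity-check the layer on random tensors; the shapes and hyperparameters below are illustrative, not taken from the original gist:

# Illustrative usage: batch of 2 sequences of length 5, d_model = 64, 8 heads
mha = MultiHeadAttention(n_heads=8)
x = tf.random.uniform((2, 5, 64))
out = mha(x, x, x, None)  # self-attention: queries = keys = values, no mask
print(out.shape)          # (2, 5, 64)

Note that d_model is inferred from the last dimension of the input inside build, so the layer only needs the number of heads at construction time.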