import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense


class MultiHeadAttentionLayer(Layer):
    def __init__(self, num_neurons, num_heads):
        super(MultiHeadAttentionLayer, self).__init__()

        self.num_heads = num_heads
        self.num_neurons = num_neurons
        self.depth = num_neurons // self.num_heads

        # Scaled dot-product attention layer, defined elsewhere in this series
        self.attention_layer = ScaledDotProductAttentionLayer()

        # Linear projections for queries, keys and values
        self.q_layer = Dense(num_neurons)
        self.k_layer = Dense(num_neurons)
        self.v_layer = Dense(num_neurons)

        # Final linear projection applied after the heads are concatenated
        self.linear_layer = Dense(num_neurons)

    def split(self, x, batch_size):
        # Reshape (batch, seq_len, num_neurons) into
        # (batch, num_heads, seq_len, depth) so every head attends independently
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        # Run through linear layers
        q = self.q_layer(q)
        k = self.k_layer(k)
        v = self.v_layer(v)

        # Split the heads
        q = self.split(q, batch_size)
        k = self.split(k, batch_size)
        v = self.split(v, batch_size)

        # Run through attention
        attention_output, weights = self.attention_layer.calculate_output_weights(q, k, v, mask)

        # Prepare for the rest of processing: merge the heads back together
        output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(output, (batch_size, -1, self.num_neurons))

        # Run through final linear layer
        output = self.linear_layer(concat_attention)

        return output, weights
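
The layer delegates the attention math to ScaledDotProductAttentionLayer, which is not defined in this gist. Below is a minimal sketch of what such a layer could look like, plus a usage example with dummy tensors; the implementation and the calculate_output_weights signature are assumptions based on the call above, not the author's original code, and the shapes (512 neurons, 8 heads, sequence length 60) are purely illustrative.

# Minimal sketch (assumption, not part of the original gist) of the
# ScaledDotProductAttentionLayer interface used above.
class ScaledDotProductAttentionLayer(Layer):
    def calculate_output_weights(self, q, k, v, mask):
        # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
        matmul_qk = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], tf.float32)
        scaled_scores = matmul_qk / tf.math.sqrt(dk)
        if mask is not None:
            # Large negative scores push masked positions to ~0 after softmax
            scaled_scores += (mask * -1e9)
        weights = tf.nn.softmax(scaled_scores, axis=-1)
        output = tf.matmul(weights, v)
        return output, weights


# Usage example: self-attention over a dummy batch of one sequence.
multi_head = MultiHeadAttentionLayer(num_neurons=512, num_heads=8)
dummy_input = tf.random.uniform((1, 60, 512))  # (batch, seq_len, num_neurons)
output, weights = multi_head(dummy_input, dummy_input, dummy_input, mask=None)
print(output.shape)   # (1, 60, 512)
print(weights.shape)  # (1, 8, 60, 60)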