@ChunML · Created April 29, 2019 09:07
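This gist is the call method of a multi-head attention layer in TensorFlow: for each of the h heads it projects the query and value inputs, computes scaled dot-product attention, optionally applies a padding mask, then concatenates the per-head outputs and passes them through a shared output projection.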
import tensorflow as tf

def call(self, query, value, mask=None):
    # query has shape (batch, query_len, model_size)
    # value has shape (batch, value_len, model_size)
    heads = []
    for i in range(self.h):
        # Per-head attention scores: (batch, query_len, value_len)
        score = tf.matmul(self.wq[i](query), self.wk[i](value), transpose_b=True)
        # Scale the score by sqrt(key_size) as described in the paper
        score /= tf.math.sqrt(tf.dtypes.cast(self.key_size, tf.float32))
        # mask must be broadcastable to (batch, query_len, value_len),
        # with 1.0 at positions to keep and 0.0 at positions to mask
        if mask is not None:
            # Assign masked positions a large negative score (-1e9)
            # so that their weights after softmax are effectively zero.
            # An additive penalty is used here instead of the original
            # multiply-then-tf.where trick, which would also mask any
            # position whose score happened to be exactly 0.
            score += (1.0 - mask) * -1e9
        alignment = tf.nn.softmax(score, axis=2)
        # alignment has shape (batch, query_len, value_len)
        head = tf.matmul(alignment, self.wv[i](value))
        # head has shape (batch, query_len, value_size)
        heads.append(head)
    # Concatenate all the attention heads so that
    # the last dimension sums to model_size
    heads = tf.concat(heads, axis=2)
    heads = self.wo(heads)
    # heads has shape (batch, query_len, model_size)
    return heads
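
For context, here is a minimal sketch of the layer this method belongs to. The gist only contains call, so the constructor below is an assumption inferred from the attributes the method uses (self.h, self.key_size, and the projection layers self.wq, self.wk, self.wv, self.wo); the class and argument names are hypothetical.

import tensorflow as tf

class MultiHeadAttention(tf.keras.Model):
    # Hypothetical constructor, inferred from call():
    # one Dense projection per head for Q, K and V, plus an output projection.
    def __init__(self, model_size, h):
        super(MultiHeadAttention, self).__init__()
        self.h = h
        # Assumes each head attends in a model_size // h subspace
        self.key_size = model_size // h
        self.wq = [tf.keras.layers.Dense(self.key_size) for _ in range(h)]
        self.wk = [tf.keras.layers.Dense(self.key_size) for _ in range(h)]
        self.wv = [tf.keras.layers.Dense(self.key_size) for _ in range(h)]
        self.wo = tf.keras.layers.Dense(model_size)

    # call(self, query, value, mask=None) as defined above

# Usage sketch with dummy tensors:
# attention = MultiHeadAttention(model_size=128, h=8)
# query = tf.random.normal((2, 5, 128))  # (batch, query_len, model_size)
# value = tf.random.normal((2, 7, 128))  # (batch, value_len, model_size)
# output = attention(query, value)       # (2, 5, 128)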