@ChunML
Created April 28, 2019 10:36
def call(self, sequence, encoder_output):
    # EMBEDDING AND POSITIONAL ENCODING
    # Embed each target token and add its positional encoding
    # (`pes` is a module-level positional-encoding array).
    embed_out = []
    for i in range(sequence.shape[1]):
        embed = self.embedding(tf.expand_dims(sequence[:, i], axis=1))
        embed_out.append(embed + pes[i, :])
    embed_out = tf.concat(embed_out, axis=1)

    bot_sub_in = embed_out
    for i in range(self.num_layers):
        # BOTTOM MULTIHEAD SUB LAYER (masked self-attention)
        # Position j attends only to positions 0..j, enforcing causal
        # masking via slicing.
        bot_sub_out = []
        for j in range(bot_sub_in.shape[1]):
            values = bot_sub_in[:, :j + 1, :]
            attention = self.attention_bot[i](
                tf.expand_dims(bot_sub_in[:, j, :], axis=1), values)
            bot_sub_out.append(attention)
        bot_sub_out = tf.concat(bot_sub_out, axis=1)

        # Residual connection + layer normalization
        bot_sub_out = bot_sub_in + bot_sub_out
        bot_sub_out = self.attention_bot_norm[i](bot_sub_out)

        # MIDDLE MULTIHEAD SUB LAYER (encoder-decoder attention)
        mid_sub_in = bot_sub_out
        mid_sub_out = []
        for j in range(mid_sub_in.shape[1]):
            attention = self.attention_mid[i](
                tf.expand_dims(mid_sub_in[:, j, :], axis=1), encoder_output)
            mid_sub_out.append(attention)
        mid_sub_out = tf.concat(mid_sub_out, axis=1)

        # Residual connection + layer normalization
        mid_sub_out = mid_sub_out + mid_sub_in
        mid_sub_out = self.attention_mid_norm[i](mid_sub_out)

        # FFN (position-wise feed-forward network)
        ffn_in = mid_sub_out
        ffn_out = self.dense_2[i](self.dense_1[i](ffn_in))

        # Residual connection + layer normalization
        ffn_out = ffn_out + ffn_in
        ffn_out = self.ffn_norm[i](ffn_out)

        # Output of this layer feeds the next layer's self-attention
        bot_sub_in = ffn_out

    # Final projection to vocabulary logits
    logits = self.dense(ffn_out)
    return logits
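
The call() above references a module-level pes array and is intended as the forward pass of a decoder model, neither of which is shown in this gist. Below is a minimal, hypothetical sketch of how pes might be built with the standard sinusoidal positional encoding and how the decoder could be invoked; the array sizes and the Decoder constructor arguments are assumptions, not part of this gist.

import numpy as np
import tensorflow as tf

# Hypothetical sketch: standard sinusoidal positional encoding (Vaswani et al., 2017).
# `pes` here stands in for the module-level array used in call() above.
MAX_LENGTH = 100   # assumed maximum target-sequence length
MODEL_SIZE = 128   # assumed embedding / model size

pes = np.zeros((MAX_LENGTH, MODEL_SIZE), dtype=np.float32)
for pos in range(MAX_LENGTH):
    for i in range(0, MODEL_SIZE, 2):
        pes[pos, i] = np.sin(pos / 10000 ** (i / MODEL_SIZE))
        if i + 1 < MODEL_SIZE:
            pes[pos, i + 1] = np.cos(pos / 10000 ** (i / MODEL_SIZE))
pes = tf.constant(pes, dtype=tf.float32)

# Hypothetical usage, assuming a Decoder(tf.keras.Model) class whose call() is the one above:
# decoder = Decoder(vocab_size=10000, model_size=MODEL_SIZE, num_layers=2, h=8)
# logits = decoder(target_sequence, encoder_output)   # shape: (batch, seq_len, vocab_size)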