Skip to content

Instantly share code, notes, and snippets.

@ChunML
Created April 30, 2019 03:31
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ChunML/0621f113d8e0b33f2899328acfc8bab9 to your computer and use it in GitHub Desktop.
Save ChunML/0621f113d8e0b33f2899328acfc8bab9 to your computer and use it in GitHub Desktop.
def call(self, sequence, encoder_output, padding_mask):
    """Run the decoder layers over `sequence` and project the result.

    Every one of the `self.num_layers` layers applies three sub-layers in
    order, each followed by a residual addition and its own normalization:
    a bottom attention of the input over itself using a look-left-only
    mask, a middle attention against `encoder_output` using
    `padding_mask`, and a two-layer feed-forward block.

    Args:
        sequence: Input handed to `embedding`. Its second axis
            (`sequence.shape[1]`) is taken as the sequence length.
        encoder_output: Passed as the second argument to each
            `self.attention_mid[i]`.
        padding_mask: Passed as the third argument to each
            `self.attention_mid[i]`.

    Returns:
        The output of `self.dense` applied to the last layer's output (or
        to the position-added embedding when `self.num_layers` is 0).
    """
    # The sequence length was previously read from an undefined `seq_len`;
    # derive it from the same axis that is used to slice `pes` below.
    # NOTE(review): this relies on the static shape being known -- confirm
    # for graph-mode callers with a dynamic time axis.
    seq_len = sequence.shape[1]

    # EMBEDDING AND POSITIONAL EMBEDDING
    # NOTE(review): `embedding` and `pes` are resolved from the enclosing /
    # module scope, not from `self` -- confirm that is intended.
    embed_out = embedding(sequence)
    embed_out += pes[:seq_len, :]

    # Lower-triangular matrix of ones. It only depends on `seq_len`, so it
    # is built once instead of once per layer.
    look_left_only_mask = tf.linalg.band_part(
        tf.ones((seq_len, seq_len)), -1, 0)

    layer_in = embed_out
    for i in range(self.num_layers):
        # BOTTOM MULTIHEAD SUB LAYER
        bot_sub_out = self.attention_bot[i](
            layer_in, layer_in, look_left_only_mask)
        bot_sub_out = layer_in + bot_sub_out
        bot_sub_out = self.attention_bot_norm[i](bot_sub_out)

        # MIDDLE MULTIHEAD SUB LAYER
        mid_sub_in = bot_sub_out
        mid_sub_out = self.attention_mid[i](
            mid_sub_in, encoder_output, padding_mask)
        mid_sub_out = mid_sub_out + mid_sub_in
        mid_sub_out = self.attention_mid_norm[i](mid_sub_out)

        # FFN
        ffn_in = mid_sub_out
        ffn_out = self.dense_2[i](self.dense_1[i](ffn_in))
        ffn_out = ffn_out + ffn_in
        ffn_out = self.ffn_norm[i](ffn_out)

        # This layer's output is the next layer's input.
        layer_in = ffn_out

    # Using the running variable keeps this well-defined even when the
    # loop body never executes.
    logits = self.dense(layer_in)
    return logits
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment