Skip to content

Instantly share code, notes, and snippets.

@isauravmanitripathi
Created October 10, 2022 15:56
Show Gist options
  • Save isauravmanitripathi/5b11d92675022c5d89cce76b0567ee41 to your computer and use it in GitHub Desktop.
from multihead_attention import MultiHeadAttention
from encoder import AddNormalization, FeedForward
class DecoderLayer(Layer):
    """One layer of a Transformer decoder stack.

    Composes three sub-layers, each followed by dropout and an
    add-&-normalize step (residual connection + layer normalization):

      1. multi-head attention over the decoder's own inputs,
      2. multi-head attention attending over the encoder output,
      3. a position-wise feed-forward network.

    NOTE(review): ``Layer`` and ``Dropout`` are used but not imported in
    the visible code — presumably
    ``from tensorflow.keras.layers import Layer, Dropout`` belongs at the
    top of the file; confirm against the full source.
    """

    def __init__(self, h, d_k, d_v, d_model, d_ff, rate, **kwargs):
        """Build the decoder layer's sub-layers.

        Args:
            h: Number of attention heads.
            d_k: Dimensionality of the attention key (and query) vectors.
            d_v: Dimensionality of the attention value vectors.
            d_model: Model (embedding) dimensionality shared by all sub-layers.
            d_ff: Inner dimensionality of the feed-forward sub-layer.
            rate: Dropout rate applied after each sub-layer.
            **kwargs: Forwarded to the Keras ``Layer`` base constructor.
        """
        super(DecoderLayer, self).__init__(**kwargs)
        # First attention sub-layer: decoder attends to its own sequence.
        self.multihead_attention1 = MultiHeadAttention(h, d_k, d_v, d_model)
        self.dropout1 = Dropout(rate)
        self.add_norm1 = AddNormalization()
        # Second attention sub-layer: decoder attends over the encoder output
        # (cross-attention) — presumably; the call() wiring is outside this view.
        self.multihead_attention2 = MultiHeadAttention(h, d_k, d_v, d_model)
        self.dropout2 = Dropout(rate)
        self.add_norm2 = AddNormalization()
        # Position-wise feed-forward sub-layer.
        self.feed_forward = FeedForward(d_ff, d_model)
        self.dropout3 = Dropout(rate)
        self.add_norm3 = AddNormalization()
...
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment