Skip to content

Instantly share code, notes, and snippets.

@NMZivkovic
Created August 30, 2019 13:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save NMZivkovic/63f0d0392255a8a82eff5d70c072ccfc to your computer and use it in GitHub Desktop.
Save NMZivkovic/63f0d0392255a8a82eff5d70c072ccfc to your computer and use it in GitHub Desktop.
class Transformer(Model):
    """Encoder-decoder Transformer with a final linear projection to logits.

    Composes an ``Encoder`` and a ``Decoder`` (defined elsewhere in the
    project) and maps the decoder output onto the target vocabulary with a
    single ``Dense`` layer.
    """

    def __init__(self, num_layers, num_neurons, num_hidden_neurons, num_heads, input_vocabular_size, target_vocabular_size):
        super().__init__()
        # Encoder and decoder stacks share depth, width, and head count;
        # they differ only in which vocabulary embedding they carry.
        self.encoder = Encoder(num_neurons, num_hidden_neurons, num_heads, input_vocabular_size, num_layers)
        self.decoder = Decoder(num_neurons, num_hidden_neurons, num_heads, target_vocabular_size, num_layers)
        # Projects each decoder state to unnormalized target-vocabulary scores.
        self.linear_layer = Dense(target_vocabular_size)

    def call(self, transformer_input, tar, training, encoder_padding_mask, look_ahead_mask, decoder_padding_mask):
        """Forward pass.

        Returns a pair ``(logits, attention_weights)`` where ``logits`` is the
        ``Dense`` projection of the decoder output and ``attention_weights``
        is whatever the decoder reports alongside its output.
        """
        encoded = self.encoder(transformer_input, training, encoder_padding_mask)
        decoded, attention = self.decoder(tar, encoded, training, look_ahead_mask, decoder_padding_mask)
        logits = self.linear_layer(decoded)
        return logits, attention
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment