Created
August 20, 2020 03:45
-
-
Save nahidalam/9d76be7c3f7e5a5f418fc87370afd8c3 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Encoder(tf.keras.Model):
    """Sequence encoder: a token embedding followed by a single GRU layer.

    ``call`` returns both the per-timestep GRU outputs and the final GRU
    hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, enc_units, batch_sz):
        """Build the embedding and GRU sub-layers.

        Args:
            vocab_size: Number of distinct token ids the embedding accepts
                (valid ids are 0 .. vocab_size - 1).
            embedding_dim: Length of the embedding vector produced per token.
            enc_units: Number of GRU units, i.e. the size of its hidden state
                and of each per-timestep output.
            batch_sz: Default batch size used by ``initialize_hidden_state``
                to shape the zero initial state. It is not otherwise used.
        """
        super().__init__()
        self.batch_sz = batch_sz
        self.enc_units = enc_units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(
            self.enc_units,
            # Emit the output at every timestep, not only the last one.
            return_sequences=True,
            # Additionally return the final hidden state.
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )

    def call(self, x, hidden=None):
        """Encode a batch of token-id sequences.

        Args:
            x: Integer token ids. Presumably shaped (batch, seq_len); it must
                be rank 2 so the embedding yields the 3-D input a GRU needs.
                TODO confirm against callers.
            hidden: Initial GRU state of shape (batch, enc_units), e.g. from
                ``initialize_hidden_state``. If None, the GRU starts from a
                zero state sized to the actual batch of ``x``.

        Returns:
            A tuple ``(output, state)``:
                output: GRU outputs for every timestep,
                    shape (batch, seq_len, enc_units).
                state: Final GRU hidden state, shape (batch, enc_units).
        """
        embedded = self.embedding(x)
        output, state = self.gru(embedded, initial_state=hidden)
        return output, state

    def initialize_hidden_state(self, batch_sz=None):
        """Return an all-zeros initial GRU state.

        Args:
            batch_sz: Batch size for the state. Defaults to the ``batch_sz``
                given at construction. Pass the real size for a final batch
                that is smaller than the configured one; otherwise the state
                would not match the input batch.

        Returns:
            A float32 tensor of zeros with shape (batch_sz, enc_units).
        """
        if batch_sz is None:
            batch_sz = self.batch_sz
        return tf.zeros((batch_sz, self.enc_units))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment