Skip to content

Instantly share code, notes, and snippets.

@Shivam-316
Created November 17, 2020 16:41
Show Gist options
  • Save Shivam-316/f463d11dd5c4128f8d78736c11bac04e to your computer and use it in GitHub Desktop.
class Encoder(keras.Model):
    """Encoder model: a token embedding followed by a single LSTM layer.

    Args:
        vocab_size: Number of rows in the embedding table; input ids are
            expected to lie in ``[0, vocab_size)``.
        emb_dim: Length of each embedding vector.
        units: Dimensionality of the LSTM output and of its hidden/carry states.
        batch_size: Default batch size used by ``init_hidden_state`` when no
            explicit size is requested.
        **kwargs: Forwarded to ``keras.Model.__init__`` (e.g. ``name``).
    """

    def __init__(self, vocab_size=10000, emb_dim=128, units=256, batch_size=64, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.batch = batch_size
        self.emb_layer = keras.layers.Embedding(vocab_size, emb_dim)
        # return_sequences=True yields the output at every timestep, and
        # return_state=True additionally yields the final hidden/carry states.
        # NOTE(review): presumably consumed by a decoder/attention stage —
        # confirm against callers.
        self.lstm = keras.layers.LSTM(
            self.units, return_sequences=True, return_state=True
        )

    def call(self, x, states=None):
        """Embed ``x`` and run it through the LSTM.

        Args:
            x: Integer token ids; assumed shape (batch, seq_len) — TODO confirm.
            states: Optional ``[hidden, carry]`` initial states, each of shape
                (batch, units). When ``None`` the LSTM starts from zeros.

        Returns:
            ``(output, hidden, carry)``. ``output`` holds the per-timestep
            sequence (batch, seq_len, units); ``hidden`` and ``carry`` are the
            final states (batch, units).
        """
        emb = self.emb_layer(x)
        output, hidden, carry = self.lstm(emb, initial_state=states)
        return output, hidden, carry

    def init_hidden_state(self, batch_size=None):
        """Return zero ``(hidden, carry)`` states for the LSTM.

        Args:
            batch_size: Batch dimension of the states. Defaults to the
                ``batch_size`` given at construction. Pass the actual batch
                size when it differs (e.g. a smaller final batch); otherwise
                the states will not match the input.

        Returns:
            A tuple of two zero tensors, each of shape (batch_size, units).
        """
        size = self.batch if batch_size is None else batch_size
        return tf.zeros((size, self.units)), tf.zeros((size, self.units))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment