Skip to content

Instantly share code, notes, and snippets.

@gautham20
Created June 6, 2020 22:05
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gautham20/8ae66cdcfbbc98bd887e0e57f2bca3c3 to your computer and use it in GitHub Desktop.
Save gautham20/8ae66cdcfbbc98bd887e0e57f2bca3c3 to your computer and use it in GitHub Desktop.
class DecoderCell(nn.Module):
    """Single decoding step: a GRU cell followed by a one-unit linear head.

    The cell advances a recurrent hidden state by one step and projects the
    new state to a single value per row. Dropout is applied only to the
    hidden state handed back to the caller, not to the state that feeds the
    linear head.
    """

    def __init__(self, input_feature_len, hidden_size, dropout=0.2):
        """Build the recurrent cell, the output projection and the dropout.

        Args:
            input_feature_len: Size of the per-step input given to the GRU
                cell.
            hidden_size: Size of the GRU hidden state; also the input width
                of the linear head.
            dropout: Drop probability used on the returned hidden state.
        """
        super(DecoderCell, self).__init__()
        # Plain (non-module) flag; this class only sets it and never reads it.
        self.attention = False
        # Submodules are created in the same order as before so parameter
        # initialisation consumes the RNG identically.
        self.decoder_rnn_cell = nn.GRUCell(input_feature_len, hidden_size)
        self.out = nn.Linear(in_features=hidden_size, out_features=1)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, prev_hidden, y):
        """Run one decoder step.

        Args:
            prev_hidden: Hidden state passed to the GRU cell as its previous
                state (presumably ``(batch, hidden_size)`` - confirm against
                the caller).
            y: Input passed to the GRU cell for this step (presumably
                ``(batch, input_feature_len)`` - confirm against the caller).

        Returns:
            A ``(output, hidden)`` tuple: ``output`` is the linear projection
            of the new hidden state (last dimension of size 1) and ``hidden``
            is the new hidden state after dropout.
        """
        next_hidden = self.decoder_rnn_cell(y, prev_hidden)
        # The projection reads the un-dropped state; dropout only affects the
        # hidden state that is returned.
        prediction = self.out(next_hidden)
        dropped_hidden = self.dropout(next_hidden)
        return prediction, dropped_hidden
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment