Last active
March 23, 2021 16:21
-
-
Save bkaankuguoglu/afabeb96d32b9485ba52dca83274c573 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class RNNModel(nn.Module):
    """Many-to-one vanilla RNN: encodes a sequence and maps the final
    hidden state to an output vector.

    Args:
        input_dim: number of features per timestep.
        hidden_dim: size of the RNN hidden state.
        layer_dim: number of stacked RNN layers.
        output_dim: size of the final output per sample.
        dropout_prob: inter-layer dropout (only applied when layer_dim > 1).
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim, dropout_prob):
        super().__init__()
        # Kept so forward() can size the initial hidden state.
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True -> inputs are (batch, seq_len, input_dim)
        self.rnn = nn.RNN(
            input_dim, hidden_dim, layer_dim, batch_first=True, dropout=dropout_prob
        )
        # Projects the last hidden state to the desired output size.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the RNN over x of shape (batch, seq_len, input_dim) and
        return a tensor of shape (batch, output_dim)."""
        # Fresh zero initial state per batch. Created on x's device/dtype —
        # the original allocated on CPU, which crashes on GPU inputs.
        # No requires_grad_/detach dance: a constant start state needs no grad.
        h0 = torch.zeros(
            self.layer_dim, x.size(0), self.hidden_dim,
            device=x.device, dtype=x.dtype,
        )
        # out: (batch, seq_len, hidden_dim); final hidden state is unused.
        out, _ = self.rnn(x, h0)
        # Keep only the last timestep's output: (batch, hidden_dim).
        out = out[:, -1, :]
        # Project to (batch, output_dim).
        return self.fc(out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment