Last active
August 4, 2019 16:55
-
-
Save samarth-agrawal-86/f2b5bd2708cfb838a8acf0ad717fcbb2 to your computer and use it in GitHub Desktop.
Model Class defined in pytorch framework
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch.nn as nn
class SentimentLSTM(nn.Module):
    """
    The RNN model that will be used to perform Sentiment analysis.

    Pipeline: embedding -> (stacked) LSTM -> dropout -> linear -> sigmoid.
    Only the final sigmoid value of each sequence is returned by ``forward``.
    """

    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):
        """
        Initialize the model by setting up the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            output_size: Number of output units of the final linear layer.
            embedding_dim: Size of each embedding vector (LSTM input size).
            hidden_dim: Size of the LSTM hidden state.
            n_layers: Number of stacked LSTM layers.
            drop_prob: Dropout probability applied between stacked LSTM layers.
        """
        super().__init__()

        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # embedding and LSTM layers
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # nn.LSTM only applies dropout *between* layers; with a single layer a
        # non-zero value has no effect and only triggers a UserWarning, so
        # pass 0.0 in that case.
        lstm_dropout = drop_prob if n_layers > 1 else 0.0
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers,
                            dropout=lstm_dropout, batch_first=True)

        # dropout layer
        self.dropout = nn.Dropout(0.3)

        # linear and sigmoid layers
        self.fc = nn.Linear(hidden_dim, output_size)
        self.sig = nn.Sigmoid()

    def forward(self, x, hidden=None):
        """
        Perform a forward pass of our model on some input and hidden state.

        Args:
            x: Integer token-index tensor laid out batch-first
                (dimension 0 is the batch; the LSTM is built with
                ``batch_first=True``).
            hidden: Tuple ``(h_0, c_0)`` as produced by ``init_hidden``, or
                ``None`` to let the LSTM start from its default zero state.

        Returns:
            Tuple ``(sig_out, hidden)`` where ``sig_out`` has one value per
            batch row (the last sigmoid output of that row's sequence) and
            ``hidden`` is the LSTM's updated ``(h_n, c_n)``.
        """
        batch_size = x.size(0)

        # embeddings and lstm_out
        embeds = self.embedding(x)
        lstm_out, hidden = self.lstm(embeds, hidden)

        # stack up lstm outputs: flatten (batch, seq) into one leading axis
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)

        # dropout and fully-connected layer
        out = self.dropout(lstm_out)
        out = self.fc(out)

        # sigmoid function
        sig_out = self.sig(out)

        # reshape to be batch_size first, then keep only the last value of
        # each row (i.e. the output at the final time step)
        sig_out = sig_out.view(batch_size, -1)
        sig_out = sig_out[:, -1]

        # return last sigmoid output and hidden state
        return sig_out, hidden

    def init_hidden(self, batch_size):
        """
        Initializes hidden state.

        Creates two zero tensors of size n_layers x batch_size x hidden_dim
        for the hidden state and cell state of the LSTM. They are allocated
        from one of the model's own parameters, so they automatically live on
        the same device (CPU or GPU) and use the same dtype as the model; no
        external ``train_on_gpu`` flag is required.

        Args:
            batch_size: Number of sequences in the batch the state is for.

        Returns:
            Tuple ``(h_0, c_0)`` of zero tensors.
        """
        weight = next(self.parameters())
        shape = (self.n_layers, batch_size, self.hidden_dim)
        # new_zeros yields fresh tensors that do not require grad.
        return (weight.new_zeros(shape), weight.new_zeros(shape))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment