RNN model architecture for sentiment classification
import torch
import torch.nn as nn

# The original snippet references train_on_gpu without defining it; this
# mirrors the usual setup (an assumption, not part of the original gist).
train_on_gpu = torch.cuda.is_available()

class SentimentRNN(nn.Module):
    """
    The RNN model that will be used to perform sentiment analysis.
    """

    def __init__(self, vocab_size, output_size, embedding_dim, hidden_dim, n_layers, drop_prob=0.5):
        """
        Initialize the model by setting up the layers.
        """
        super(SentimentRNN, self).__init__()

        self.output_size = output_size
        self.n_layers = n_layers
        self.hidden_dim = hidden_dim

        # embedding layer: maps token ids to dense embedding vectors
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # LSTM layer(s): batch_first=True expects input of shape (batch, seq_len, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, n_layers,
                            dropout=drop_prob, batch_first=True)
        # fully-connected layer and sigmoid map the final hidden state to a probability
        self.fc = nn.Linear(hidden_dim, output_size)
        self.sig = nn.Sigmoid()
    def forward(self, x, hidden):
        """
        Perform a forward pass of our model on some input and hidden state.
        """
        batch_size = x.size(0)

        # embeddings and LSTM output
        x = x.long()
        embeds = self.embedding(x)
        lstm_out, hidden = self.lstm(embeds, hidden)

        # stack up LSTM outputs: (batch, seq_len, hidden_dim) -> (batch * seq_len, hidden_dim)
        lstm_out = lstm_out.contiguous().view(-1, self.hidden_dim)

        # fully-connected layer and sigmoid activation
        out = self.fc(lstm_out)
        sig_out = self.sig(out)

        # reshape to (batch, seq_len) and keep only the output of the last time step
        sig_out = sig_out.view(batch_size, -1)
        sig_out = sig_out[:, -1]

        # return last sigmoid output and hidden state
        return sig_out, hidden
    def init_hidden(self, batch_size):
        """ Initializes hidden state """
        # Create two new tensors with shape (n_layers, batch_size, hidden_dim),
        # initialized to zero, for the hidden state and cell state of the LSTM.
        # weight.new(...) allocates them with the same dtype as the model's parameters.
        weight = next(self.parameters()).data

        if train_on_gpu:
            hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda(),
                      weight.new(self.n_layers, batch_size, self.hidden_dim).zero_().cuda())
        else:
            hidden = (weight.new(self.n_layers, batch_size, self.hidden_dim).zero_(),
                      weight.new(self.n_layers, batch_size, self.hidden_dim).zero_())

        return hidden
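
A minimal usage sketch follows. The hyperparameter values and the dummy batch are illustrative assumptions, not values from the original gist; they simply trace the expected tensor shapes end to end.

# Minimal usage sketch -- hyperparameters and dummy data below are
# illustrative assumptions, not values from the original gist.
vocab_size = 5000
output_size = 1        # single sigmoid output per sequence
embedding_dim = 400
hidden_dim = 256
n_layers = 2

net = SentimentRNN(vocab_size, output_size, embedding_dim, hidden_dim, n_layers)

batch_size = 4
seq_length = 20
inputs = torch.randint(0, vocab_size, (batch_size, seq_length))  # dummy token ids

if train_on_gpu:
    net.cuda()
    inputs = inputs.cuda()

hidden = net.init_hidden(batch_size)
sig_out, hidden = net(inputs, hidden)
print(sig_out.shape)  # torch.Size([4]) -- one probability per sequence in the batch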