Last active
September 25, 2020 18:19
-
-
Save edumunozsala/79afe09b0fad3758ad99dae9adc4e6b4 to your computer and use it in GitHub Desktop.
Model Definition for CLTG
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
from torch import nn | |
from torch.autograd import Variable | |
class RNNModel(nn.Module):
    """Stacked LSTM mapping a sequence of input feature vectors to
    per-timestep logits over a vocabulary.

    This class has no ``nn.Embedding`` layer: ``forward`` passes ``x``
    straight into the LSTM, whose input size is ``embedding_size``. Callers
    must therefore supply float features whose last dimension is
    ``embedding_size`` (NOTE(review): presumably pre-embedded or one-hot
    encoded characters — confirm against the training code).

    Args:
        vocab_size: Number of output classes (width of the decoder output).
        embedding_size: Feature size of each input timestep.
        hidden_dim: LSTM hidden-state size.
        n_layers: Number of stacked LSTM layers.
        drop_rate: Dropout probability applied to the LSTM outputs and, when
            ``n_layers > 1``, between the stacked LSTM layers.
    """

    def __init__(self, vocab_size, embedding_size, hidden_dim, n_layers, drop_rate=0.2):
        super().__init__()
        # Hyper-parameters, kept as attributes (forward/init_state read
        # hidden_dim and n_layers).
        self.hidden_dim = hidden_dim
        self.embedding_size = embedding_size
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.drop_rate = drop_rate
        # Character <-> index lookup tables; not populated here — presumably
        # assigned by the caller once the vocabulary is built (TODO confirm).
        self.char2int = None
        self.int2char = None

        # Dropout applied to the LSTM outputs before decoding.
        self.dropout = nn.Dropout(drop_rate)
        # nn.LSTM only applies `dropout` between stacked layers, so with a
        # single layer it is a no-op and PyTorch warns about it. Pass 0 in
        # that case; the output dropout above still honours drop_rate.
        lstm_dropout = drop_rate if n_layers > 1 else 0.0
        self.rnn = nn.LSTM(
            embedding_size,
            hidden_dim,
            n_layers,
            dropout=lstm_dropout,
            batch_first=True,
        )
        # Projects each hidden vector to vocabulary logits.
        self.decoder = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, state=None):
        """Run the LSTM over ``x`` and decode every timestep.

        Args:
            x: Float tensor of shape ``[batch_size, seq_len, embedding_size]``
                (the LSTM is built with ``batch_first=True``).
            state: ``(h_0, c_0)`` tuple, each ``[n_layers, batch_size,
                hidden_dim]`` (see ``init_state``), or ``None`` for a zero
                initial state.

        Returns:
            ``(logits, state)`` where ``logits`` has shape
            ``[batch_size * seq_len, vocab_size]`` with rows ordered
            batch-major (row ``b * seq_len + t``), and ``state`` is the final
            ``(h_n, c_n)`` tuple to feed into the next call.
        """
        rnn_out, state = self.rnn(x, state)
        # rnn_out: [batch_size, seq_len, hidden_dim]
        rnn_out = self.dropout(rnn_out)
        # Flatten batch and time so each timestep is decoded independently;
        # reshape copies only if the tensor is not already contiguous.
        logits = self.decoder(rnn_out.reshape(-1, self.hidden_dim))
        return logits, state

    def init_state(self, device, batch_size=1):
        """Return zero-initialised ``(h_0, c_0)`` LSTM states.

        Each tensor has shape ``[n_layers, batch_size, hidden_dim]``, is
        allocated directly on ``device``, and uses the model's parameter
        dtype so it stays compatible after ``.double()`` / ``.half()``.
        """
        dtype = self.decoder.weight.dtype
        shape = (self.n_layers, batch_size, self.hidden_dim)
        return (torch.zeros(shape, device=device, dtype=dtype),
                torch.zeros(shape, device=device, dtype=dtype))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment