Model Definition for CLTG (character-level text generation)
import torch
from torch import nn
class RNNModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_dim, n_layers, drop_rate=0.2):
        super(RNNModel, self).__init__()
        # Store the model hyperparameters
        self.hidden_dim = hidden_dim
        self.embedding_size = embedding_size
        self.n_layers = n_layers
        self.vocab_size = vocab_size
        self.drop_rate = drop_rate
        # Character-index mappings, assigned once the vocabulary is built
        self.char2int = None
        self.int2char = None
        # Dropout layer
        self.dropout = nn.Dropout(drop_rate)
        # RNN layer: LSTM expecting batch-first inputs of shape [batch_size, seq_len, embedding_size]
        self.rnn = nn.LSTM(embedding_size, hidden_dim, n_layers, dropout=drop_rate, batch_first=True)
        # Fully connected layer mapping each hidden state to vocabulary logits
        self.decoder = nn.Linear(hidden_dim, vocab_size)
    def forward(self, x, state):
        # x shape: [batch_size, seq_len, embedding_size]
        rnn_out, state = self.rnn(x, state)
        # rnn_out shape: [batch_size, seq_len, hidden_dim]
        # state: tuple (h, c), each of shape [n_layers, batch_size, hidden_dim]
        rnn_out = self.dropout(rnn_out)
        # Stack up the LSTM outputs; contiguous() is needed before view()
        rnn_out = rnn_out.contiguous().view(-1, self.hidden_dim)
        logits = self.decoder(rnn_out)
        # logits shape: [batch_size * seq_len, vocab_size]
        return logits, state
    def init_state(self, device, batch_size=1):
        """
        Initialises the LSTM hidden and cell states to zeros.
        """
        return (torch.zeros(self.n_layers, batch_size, self.hidden_dim).to(device),
                torch.zeros(self.n_layers, batch_size, self.hidden_dim).to(device))
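
A minimal usage sketch (not part of the original gist): since the class defines no embedding layer, the input x is assumed to be one-hot encoded, so embedding_size is set equal to vocab_size; all sizes below are illustrative.

import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Illustrative sizes (assumptions, not taken from the original gist)
vocab_size = 65
batch_size = 4
seq_len = 16

model = RNNModel(vocab_size=vocab_size, embedding_size=vocab_size,
                 hidden_dim=128, n_layers=2, drop_rate=0.2).to(device)

# One-hot encode a batch of integer-encoded character sequences
x = torch.randint(0, vocab_size, (batch_size, seq_len))
x = F.one_hot(x, num_classes=vocab_size).float().to(device)

state = model.init_state(device, batch_size=batch_size)
logits, state = model(x, state)
print(logits.shape)  # torch.Size([64, 65]) == [batch_size * seq_len, vocab_size]

Because forward() flattens its output to [batch_size * seq_len, vocab_size], the logits can be fed directly to nn.CrossEntropyLoss against targets flattened to a single dimension.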