@Deepayan137
Created April 13, 2018 09:11
SeqGAN model: an RNN/GRU discriminator (DiscNet) and generator (GenNet) in PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
class DiscNet(nn.Module):
    def __init__(self, vocab_size, hidden_size, embedding_size, rnn_type, dropout=0.2):
        super(DiscNet, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn_type = rnn_type
        if rnn_type == 'rnn':
            rnn = nn.RNN
        elif rnn_type == 'gru':
            rnn = nn.GRU
        else:
            raise ValueError("rnn_type must be 'rnn' or 'gru'")
        self.rnn = rnn(embedding_size, hidden_size, batch_first=True)
        self.rnn2hidden = nn.Linear(hidden_size, hidden_size)
        self.dropout_linear = nn.Dropout(p=dropout)
        self.hidden2out = nn.Linear(hidden_size, 1)

    def forward(self, input_, hidden, length):
        # input_: batch_size x max_seq_len (padded token indices)
        # pack_padded_sequence expects batches sorted by length, longest first
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_ = input_[sorted_idx]
        input_embedding = self.embedding(input_)  # batch_size x max_seq_len x embedding_size
        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.tolist(), batch_first=True)
        _, hidden = self.rnn(packed_input, hidden)
        hidden = hidden.squeeze(0)  # 1 x batch_size x hidden_size -> batch_size x hidden_size
        out = self.rnn2hidden(hidden)
        out = torch.tanh(out)
        out = self.dropout_linear(out)
        out = self.hidden2out(out)
        out = torch.sigmoid(out)  # probability that each sequence is real
        # undo the length-based sort so outputs line up with the input batch
        _, reversed_idx = torch.sort(sorted_idx)
        return out[reversed_idx]
class GenNet(nn.Module):
    def __init__(self, vocab_size, hidden_size, embedding_size, rnn_type, dropout=0.2):
        super(GenNet, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.rnn_type = rnn_type
        if rnn_type == 'rnn':
            rnn = nn.RNN
        elif rnn_type == 'gru':
            rnn = nn.GRU
        else:
            raise ValueError("rnn_type must be 'rnn' or 'gru'")
        self.rnn = rnn(embedding_size, hidden_size, batch_first=True)
        self.hidden2out = nn.Linear(hidden_size, vocab_size)

    def forward(self, input_, hidden, length):
        # input_: batch_size x max_seq_len (padded token indices)
        sorted_lengths, sorted_idx = torch.sort(length, descending=True)
        input_ = input_[sorted_idx]
        input_embedding = self.embedding(input_)  # batch_size x max_seq_len x embedding_size
        packed_input = rnn_utils.pack_padded_sequence(
            input_embedding, sorted_lengths.tolist(), batch_first=True)
        packed_out, hidden = self.rnn(packed_input, hidden)
        # unpack before projecting onto the vocabulary
        out, _ = rnn_utils.pad_packed_sequence(packed_out, batch_first=True)
        out = self.hidden2out(out)       # batch_size x max_seq_len x vocab_size
        out = F.log_softmax(out, dim=-1)
        # undo the length-based sort so outputs line up with the input batch
        _, reversed_idx = torch.sort(sorted_idx)
        return out[reversed_idx], hidden[:, reversed_idx]
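
For reference, a minimal usage sketch (not part of the original gist): the sizes, the random batch, and rnn_type='gru' below are illustrative assumptions, chosen only to show the expected input and output shapes.

if __name__ == '__main__':
    vocab_size, hidden_size, embedding_size = 100, 64, 32  # hypothetical sizes
    gen = GenNet(vocab_size, hidden_size, embedding_size, rnn_type='gru')
    disc = DiscNet(vocab_size, hidden_size, embedding_size, rnn_type='gru')
    # a batch of 4 padded sequences, max length 10, with their true lengths
    tokens = torch.randint(0, vocab_size, (4, 10))
    lengths = torch.tensor([10, 8, 5, 7])
    log_probs, hidden = gen(tokens, None, lengths)  # 4 x 10 x vocab_size
    scores = disc(tokens, None, lengths)            # 4 x 1, each in (0, 1)
    print(log_probs.shape, scores.shape)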