import torch
import torch.nn as nn


class RNNEncoder(nn.Module):
    def __init__(self, rnn_num_layers=1, input_feature_len=1, sequence_len=168,
                 hidden_size=100, bidirectional=False, device='cpu', rnn_dropout=0.2):
        super().__init__()
        self.sequence_len = sequence_len
        self.hidden_size = hidden_size
        self.input_feature_len = input_feature_len
        self.num_layers = rnn_num_layers
        self.rnn_directions = 2 if bidirectional else 1
        self.gru = nn.GRU(
            num_layers=rnn_num_layers,
            input_size=input_feature_len,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=bidirectional,
            # dropout only acts between stacked GRU layers, so disable it for a
            # single layer to avoid PyTorch's warning
            dropout=rnn_dropout if rnn_num_layers > 1 else 0
        )
        self.device = device

    def forward(self, input_seq):
        # initial hidden state: (num_layers * num_directions, batch, hidden_size)
        ht = torch.zeros(self.num_layers * self.rnn_directions, input_seq.size(0),
                         self.hidden_size, device=self.device)
        # promote (batch, seq_len) inputs to (batch, seq_len, 1)
        if input_seq.ndim < 3:
            input_seq = input_seq.unsqueeze(2)
        gru_out, hidden = self.gru(input_seq, ht)
        if self.rnn_directions * self.num_layers > 1:
            if self.rnn_directions > 1:
                # sum the forward and backward outputs so the encoder output
                # stays (batch, seq_len, hidden_size)
                gru_out = gru_out.view(input_seq.size(0), self.sequence_len,
                                       self.rnn_directions, self.hidden_size)
                gru_out = torch.sum(gru_out, dim=2)
            # separate layers and directions: (num_layers, num_directions, batch, hidden_size)
            hidden = hidden.view(self.num_layers, self.rnn_directions,
                                 input_seq.size(0), self.hidden_size)
            if self.num_layers > 1:
                hidden = hidden[-1]  # keep only the last layer's hidden state
            else:
                hidden = hidden.squeeze(0)
            # sum over directions -> (batch, hidden_size)
            hidden = hidden.sum(dim=0)
        else:
            # single unidirectional layer: drop the layer dimension
            hidden = hidden.squeeze(0)
        return gru_out, hidden
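

# Minimal usage sketch (not part of the original gist): encode a batch of
# univariate sequences. The batch size of 32 and the layer/direction settings
# below are illustrative assumptions, not values prescribed by the gist.
if __name__ == '__main__':
    encoder = RNNEncoder(rnn_num_layers=2, input_feature_len=1, sequence_len=168,
                         hidden_size=100, bidirectional=True)
    batch = torch.randn(32, 168)   # (batch, seq_len); promoted to 3-D inside forward
    gru_out, hidden = encoder(batch)
    print(gru_out.shape)           # torch.Size([32, 168, 100])
    print(hidden.shape)            # torch.Size([32, 100])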