@gautham20
Created June 6, 2020 20:59
import torch
import torch.nn as nn


class RNNEncoder(nn.Module):
    def __init__(self, rnn_num_layers=1, input_feature_len=1, sequence_len=168,
                 hidden_size=100, bidirectional=False, device='cpu', rnn_dropout=0.2):
        super().__init__()
        self.sequence_len = sequence_len
        self.hidden_size = hidden_size
        self.input_feature_len = input_feature_len
        self.num_layers = rnn_num_layers
        self.rnn_directions = 2 if bidirectional else 1
        self.gru = nn.GRU(
            num_layers=rnn_num_layers,
            input_size=input_feature_len,
            hidden_size=hidden_size,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=rnn_dropout
        )
        self.device = device

    def forward(self, input_seq):
        # Initial hidden state: (num_layers * num_directions, batch, hidden_size)
        ht = torch.zeros(self.num_layers * self.rnn_directions, input_seq.size(0),
                         self.hidden_size, device=self.device)
        # Add a trailing feature dimension for univariate inputs of shape (batch, seq_len)
        if input_seq.ndim < 3:
            input_seq = input_seq.unsqueeze(2)
        gru_out, hidden = self.gru(input_seq, ht)
        # gru_out: (batch, seq_len, num_directions * hidden_size)
        # hidden:  (num_layers * num_directions, batch, hidden_size)
        if self.rnn_directions * self.num_layers > 1:
            if self.rnn_directions > 1:
                # Sum forward and backward outputs so the encoder output
                # stays (batch, seq_len, hidden_size); the output only carries
                # the last layer, so split by direction, not layer * direction
                gru_out = gru_out.view(input_seq.size(0), self.sequence_len,
                                       self.rnn_directions, self.hidden_size)
                gru_out = torch.sum(gru_out, dim=2)
            # Flatten the per-layer / per-direction hidden states into one vector per batch element
            hidden = hidden.permute(1, 0, 2).reshape(input_seq.size(0), -1)
        else:
            # Single layer, unidirectional: drop the leading dimension -> (batch, hidden_size)
            hidden = hidden.squeeze(0)
        return gru_out, hidden
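
A minimal usage sketch, assuming a batch of 32 univariate series of length 168 (the default sequence_len); the batch size, the random input, and the bidirectional setting are illustrative choices, not part of the original gist:

encoder = RNNEncoder(rnn_num_layers=1, input_feature_len=1, sequence_len=168,
                     hidden_size=100, bidirectional=True, device='cpu',
                     rnn_dropout=0.0)  # dropout only applies between stacked layers
batch = torch.randn(32, 168)           # (batch, seq_len); unsqueezed to (32, 168, 1) in forward
gru_out, hidden = encoder(batch)
print(gru_out.shape)                   # torch.Size([32, 168, 100]) -- directions summed
print(hidden.shape)                    # torch.Size([32, 200])      -- hidden states flattened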