Last active
June 28, 2022 11:11
-
-
Save charlieoneill11/c18fda905a03508fd0626f0cf2a2775b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class LSTM(nn.Module):
    """Two-layer stacked-LSTMCell model predicting one scalar per time step.

    The input sequence is consumed one step at a time through two stacked
    LSTMCells; a final linear layer maps the second cell's hidden state to a
    single output value. Optionally the model continues past the input by
    feeding its own last prediction back in (autoregressive generation).
    """

    def __init__(self, hidden_layers=64):
        super(LSTM, self).__init__()
        self.hidden_layers = hidden_layers
        # lstm1, lstm2, linear are all layers in the network
        self.lstm1 = nn.LSTMCell(1, self.hidden_layers)
        self.lstm2 = nn.LSTMCell(self.hidden_layers, self.hidden_layers)
        self.linear = nn.Linear(self.hidden_layers, 1)

    def forward(self, y, future_preds=0):
        """Run the model over ``y`` of shape (N, T).

        Args:
            y: batch of input sequences, one scalar feature per time step.
            future_preds: number of extra autoregressive steps to generate.

        Returns:
            Tensor of shape (N, T + future_preds), one prediction per step.
        """
        outputs, num_samples = [], y.size(0)
        # Fresh zero hidden/cell states per call. Created with y's dtype and
        # device so the model also works on GPU / non-float32 inputs
        # (BUG FIX: original referenced the undefined name `n_samples`).
        h_t = torch.zeros(num_samples, self.hidden_layers, dtype=y.dtype, device=y.device)
        c_t = torch.zeros(num_samples, self.hidden_layers, dtype=y.dtype, device=y.device)
        h_t2 = torch.zeros(num_samples, self.hidden_layers, dtype=y.dtype, device=y.device)
        c_t2 = torch.zeros(num_samples, self.hidden_layers, dtype=y.dtype, device=y.device)
        for input_t in y.split(1, dim=1):
            # input_t has shape (N, 1): one scalar per sample
            # (BUG FIX: original bound the loop variable as `time_step` but
            # fed the undefined name `input_t` to lstm1 — NameError).
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))  # initial hidden and cell states
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))  # new hidden and cell states
            output = self.linear(h_t2)  # output from the last FC layer
            outputs.append(output)
        for i in range(future_preds):
            # this only generates future predictions if we pass in future_preds>0
            # mirrors the code above, using last output/prediction as input
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs.append(output)
        # transform list to tensor: (N, T + future_preds)
        outputs = torch.cat(outputs, dim=1)
        return outputs
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The variable
n_samples
(lines 12–15) should be num_samples
, right? And shouldn't time_step
be input_t
instead?