Skip to content

Instantly share code, notes, and snippets.

@rian-dolphin
Created November 3, 2021 21:45
Show Gist options
  • Save rian-dolphin/8994c853464a1bbac2c331208d4834de to your computer and use it in GitHub Desktop.
Save rian-dolphin/8994c853464a1bbac2c331208d4834de to your computer and use it in GitHub Desktop.
class LSTM(nn.Module):
    """Single-layer LSTM followed by a linear read-out of the final time step.

    Args:
        input_size: Number of features per time step. Will be 1 in this
            example since there is only 1 predictor (a sequence of previous
            values).
        hidden_size: Width of the LSTM hidden/cell state; dictates how much
            hidden "long term memory" the network has.
        output_size: Number of values predicted per sequence. This should be
            equal to the prediction_periods input to get_x_y_pairs.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.linear = nn.Linear(hidden_size, output_size)

    def forward(self, x, hidden=None):
        """Run one sequence through the network.

        Args:
            x: Tensor holding a single sequence; its first dimension is
                treated as time and it is reshaped to
                ``(len(x), 1, input_size)`` (batch size of 1).
            hidden: Optional ``(h_0, c_0)`` tuple, each of shape
                ``(1, 1, hidden_size)``. When omitted, a zero state is used.

        Returns:
            A ``(prediction, hidden)`` tuple: the linear layer's output for
            the last time step, and the ``(h_n, c_n)`` state returned by the
            LSTM. The state is also stored on ``self.hidden``.
        """
        if hidden is None:
            # Allocate the zero state from an LSTM weight so that it follows
            # the module's device and dtype (e.g. after .to("cuda") or
            # .double()); a bare torch.zeros is always CPU / default dtype.
            reference = self.lstm.weight_hh_l0
            hidden = (
                reference.new_zeros(1, 1, self.hidden_size),
                reference.new_zeros(1, 1, self.hidden_size),
            )
        self.hidden = hidden

        seq_len = len(x)
        # Inputs need to be in the shape defined in the documentation:
        # https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
        # i.e. (seq_len, batch, input_size) with batch fixed to 1 here.
        # reshape (rather than view) also accepts non-contiguous inputs.
        #
        # lstm_out    - the hidden states from all times in the sequence
        # self.hidden - the current hidden state and cell state
        lstm_out, self.hidden = self.lstm(x.reshape(seq_len, 1, -1), self.hidden)
        predictions = self.linear(lstm_out.view(seq_len, -1))
        # Only the prediction made after seeing the full sequence is returned.
        return predictions[-1], self.hidden
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment