Created
May 22, 2022 07:44
-
-
Save Tony363/de201c2899db7f5522b353b725beebf0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
import torch.nn as nn | |
class LSTMModel(nn.Module):
    """Stacked LSTM followed by ReLU, flatten and sigmoid.

    The LSTM is built with ``batch_first=True``, so ``forward`` expects a
    batched input laid out as ``(batch, seq, input_dim)``.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Size of the LSTM hidden state.
        layer_dim: Number of stacked LSTM layers.
        output_dim: Accepted for interface compatibility; no layer in this
            module currently uses it.
    """

    def __init__(self, input_dim: int, hidden_dim: int, layer_dim: int,
                 output_dim: int) -> None:
        super().__init__()
        # Kept on the instance because forward() needs them to size the
        # initial hidden/cell states.
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim

        # batch_first=True makes input/output tensors
        # (batch_dim, seq_dim, feature_dim).
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.relu = nn.ReLU()
        # Default nn.Flatten collapses every dimension after the batch one.
        self.flatten = nn.Flatten()
        self.sig = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the sequence through the LSTM and squash the result.

        Args:
            x: Input tensor of shape ``(batch, seq, input_dim)``.

        Returns:
            Tensor of shape ``(batch, seq * hidden_dim)``. Because the
            sigmoid is applied to ReLU output (which is non-negative),
            every value is >= 0.5.
        """
        batch_size = x.size(0)
        state_shape = (self.layer_dim, batch_size, self.hidden_dim)

        # Create the zero initial states on the same device and with the
        # same dtype as the input. Selecting the device from
        # torch.cuda.is_available() breaks a CPU-resident model on a
        # CUDA-capable machine, and a fixed float32 state breaks models
        # converted with .double()/.half().
        h0 = torch.zeros(*state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(*state_shape, device=x.device, dtype=x.dtype)

        # Fresh zero tensors carry no autograd history, so no detach() is
        # needed before handing them to the LSTM.
        out, _ = self.lstm(x, (h0, c0))
        out = self.relu(out)
        out = self.flatten(out)
        return self.sig(out)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment