Skip to content

Instantly share code, notes, and snippets.

@rueian
Created March 19, 2023 12:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save rueian/839b6fb4a7cfa8a8b963e74d73ab33e6 to your computer and use it in GitHub Desktop.
Save rueian/839b6fb4a7cfa8a8b963e74d73ab33e6 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.optim as optim
class LSTMModel(nn.Module):
    """LSTM followed by a linear projection applied at every timestep.

    The wrapped ``nn.LSTM`` uses its default ``batch_first=False`` layout, so
    inputs are ``(seq_len, batch, input_dim)`` and outputs are
    ``(seq_len, batch, output_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=1):
        """Build the model.

        Args:
            input_dim: number of features per timestep in the input.
            hidden_dim: size of the LSTM hidden/cell state.
            output_dim: number of features produced per timestep.
            num_layers: number of stacked LSTM layers. Defaults to 1, which
                matches the original single-layer architecture and its
                parameter names.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers=num_layers)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the LSTM from a zero initial state and project every timestep.

        Args:
            x: batched input of shape ``(seq_len, batch, input_dim)``.

        Returns:
            Tensor of shape ``(seq_len, batch, output_dim)``.
        """
        batch_size = x.size(1)
        state_shape = (self.lstm.num_layers, batch_size, self.hidden_dim)
        # Allocate the initial state on the same device and with the same
        # dtype as the input. Plain torch.zeros() would always give float32
        # on the default device, which breaks after .to("cuda") or .double().
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out)
def target_transform(m: torch.Tensor) -> torch.Tensor:
    """Build the regression target from an input tensor.

    For every position, sum three overlapping width-3 windows of the last
    axis (features 0-2, 1-3 and 2-4) into a 3-feature vector. Then take a
    running total of those vectors along dim 0. In the ``(seq_len, batch,
    features)`` layout the LSTM model consumes, dim 0 is presumably the
    sequence axis; confirm against callers.

    Args:
        m: 3-D tensor whose last axis is indexed by the windows above. The
            input is not modified.

    Returns:
        Tensor with the same leading two dims as ``m`` and a last dim of 3.
    """
    windows = [m[:, :, start:start + 3].sum(dim=-1, keepdim=True) for start in range(3)]
    per_step = torch.cat(windows, dim=-1)
    # Native prefix sum along dim 0. It replaces an in-place Python loop of
    # t[i] += t[i - 1].
    return torch.cumsum(per_step, dim=0)
def main(num_epochs: int = 10000, log_every: int = 100) -> None:
    """Train LSTMModel to reproduce target_transform, then print a small evaluation.

    Args:
        num_epochs: number of optimisation steps. Each step draws a fresh
            random batch.
        log_every: print the training loss on every step whose index is a
            multiple of this value. Must be positive.
    """
    model = LSTMModel(input_dim=5, hidden_dim=10, output_dim=3)
    optimizer = optim.AdamW(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()

    # Train. Each step samples a new (seq_len=10, batch=1000, features=5)
    # batch, matching nn.LSTM's default batch_first=False layout.
    model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        x = torch.randn(10, 1000, 5)
        loss = criterion(model(x), target_transform(x))
        loss.backward()
        optimizer.step()
        if epoch % log_every == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

    # Evaluate on a small fresh batch and print input, prediction, target and loss.
    model.eval()
    with torch.no_grad():
        x = torch.randn(10, 3, 5)
        output = model(x)
        target = target_transform(x)  # deterministic, so compute it once
        print('====')
        print(x)
        print('====')
        print(output)
        print('====')
        print(target)
        print('====')
        print(criterion(output, target))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment