@qdwang · Created March 9, 2023
pytorch_training_problem
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Project the 3-feature encoder input and the 1-feature decoder
        # input up to the Transformer's default d_model of 512.
        self.encoder_input_layer = nn.Linear(3, 512)
        self.decoder_input_layer = nn.Linear(1, 512)
        self.output_layer = nn.Linear(512, 1)
        # Defaults: d_model=512, nhead=8, dropout=0.1, batch_first=False.
        self.transformer = nn.Transformer()

    def forward(self, src, tgt):
        src = self.encoder_input_layer(src)
        tgt = self.decoder_input_layer(tgt)
        output = self.transformer(src=src, tgt=tgt)
        output = self.output_layer(output)
        return output
model = MyModel()
loss_fn = torch.nn.HuberLoss()  # delta=1.0 by default
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()

# Toy data: given three consecutive values, predict the next one.
# Note: these are 2-D tensors, so with batch_first=False (the default)
# the Transformer treats each as one unbatched sequence of length 2,
# not as a batch of two independent examples.
enc_seq = torch.tensor([[1., 2., 3.], [55., 56., 57.]])
dec_seq = torch.tensor([[3.], [57.]])
goal = torch.tensor([[4.], [58.]])
for i in range(100):
    optimizer.zero_grad()
    outputs = model(enc_seq, dec_seq)
    loss = loss_fn(outputs, goal)
    loss.backward()
    optimizer.step()
    print(f'step {i} loss {loss.item()}')

# loss is always around 26
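A plausible diagnosis, not confirmed in the gist: the model collapses to predicting roughly the mean of the two targets (~31). HuberLoss with its default delta=1.0 is linear for large errors, so the loss would settle near (|4 - 31| + |58 - 31|) / 2 - 0.5 = 26.5, matching the observed plateau. Two likely contributors are the unbatched 2-D inputs (the two examples attend to each other as a single sequence) and the raw, unnormalized values, whose scale information the LayerNorm layers inside nn.Transformer largely discard. Below is a minimal sketch of one common remedy, standardizing inputs and targets before training; the mean/std handling is an assumption added here, not part of the original gist.

# Hypothetical fix sketch: standardize the data so the scale difference
# between the two sequences survives the Transformer's internal LayerNorm.
mean, std = enc_seq.mean(), enc_seq.std()
enc_seq_n = (enc_seq - mean) / std
dec_seq_n = (dec_seq - mean) / std
goal_n = (goal - mean) / std

for i in range(100):
    optimizer.zero_grad()
    outputs = model(enc_seq_n, dec_seq_n)
    loss = loss_fn(outputs, goal_n)
    loss.backward()
    optimizer.step()
    print(f'step {i} loss {loss.item()}')

# Map predictions back to the original scale for inspection
# (eval() also disables the Transformer's dropout).
model.eval()
with torch.no_grad():
    preds = model(enc_seq_n, dec_seq_n) * std + mean
print(preds)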