@cschell
Last active August 26, 2019 13:59
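import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# LeModel, vocab and data are assumed to be defined elsewhere. LeModel must return
# log-probabilities (e.g. via log_softmax), because NLLLoss expects them as input.
losses = []
loss_function = nn.NLLLoss()
model = LeModel(len(vocab))
optimizer = optim.SGD(model.parameters(), lr=0.001)

for epoch in tqdm(range(1), leave=False):
    total_loss = 0
    for context, target in tqdm(data, leave=False):
        # Clear gradients accumulated by the previous step
        model.zero_grad()
        # Predict from the current example, not from an unrelated x_train
        predictions = model(context)
        # NLLLoss takes no dtype argument; the target itself must hold
        # integer class indices of dtype torch.long
        loss = loss_function(predictions, torch.as_tensor(target, dtype=torch.long))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the summed loss over the whole epoch
    print("%s: %s" % (epoch, total_loss))
    losses.append(total_loss)

torch.save(model.state_dict(), "model.torch")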
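The loss used above is nn.NLLLoss, which takes log-probabilities of shape (N, C) plus one integer class index per example. As a hedged illustration of what it computes with its default settings (mean reduction, no class weights), here is a pure-Python sketch; `nll_loss` and `log_softmax` are hypothetical helpers written for this example, not PyTorch APIs.

```python
import math
from typing import List, Sequence


def log_softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax: the kind of output NLLLoss expects as input."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def nll_loss(log_probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    """Mean negative log-likelihood, mirroring nn.NLLLoss() defaults
    (reduction="mean", no class weights).

    log_probs: one row of log-probabilities per example (N x C).
    targets:   one integer class index per example (N).
    """
    if len(log_probs) != len(targets):
        raise ValueError("need exactly one target per row of log-probabilities")
    # Take the log-probability assigned to the correct class, negate it, then average.
    picked = [-row[t] for row, t in zip(log_probs, targets)]
    return sum(picked) / len(picked)


# Two examples over a 3-word vocabulary; targets are class indices, not one-hot vectors.
rows = [log_softmax([2.0, 0.5, -1.0]), log_softmax([0.0, 0.0, 0.0])]
print(round(nll_loss(rows, [0, 2]), 4))
```

This is why the model has to end in a log-softmax (or the loss be swapped for nn.CrossEntropyLoss on raw scores), and why the targets are plain integer indices rather than something passed through a `dtype` argument.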