Skip to content

Instantly share code, notes, and snippets.

@aravindpai
Created January 27, 2020 13:53
Show Gist options
  • Save aravindpai/80946d83f3e8c5e2b650f8862e6a239e to your computer and use it in GitHub Desktop.
Save aravindpai/80946d83f3e8c5e2b650f8862e6a239e to your computer and use it in GitHub Desktop.
# Train for a fixed number of epochs, checkpointing the weights whenever the
# validation loss reaches a new minimum.

# Number of passes over the training iterator.
N_EPOCHS = 5
# File the best-so-far model weights (state_dict) are written to.
SAVE_PATH = 'saved_weights.pt'

# Lowest validation loss seen so far; starting at +inf guarantees the first
# epoch's weights are always saved.
best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):
    # Epoch header so the tab-indented metric lines below can be attributed
    # to a specific epoch in the log.
    print(f'Epoch: {epoch + 1:02}/{N_EPOCHS}')

    # Train the model for one epoch.
    # NOTE(review): train/evaluate are defined elsewhere; both are assumed to
    # return (loss, accuracy) with accuracy as a fraction in [0, 1] (it is
    # scaled by 100 when printed below) — confirm against their definitions.
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)

    # Evaluate the model on valid_iterator (no optimizer is passed here).
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    # Checkpoint only on improvement, so SAVE_PATH always holds the weights
    # with the lowest validation loss observed so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), SAVE_PATH)

    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment