@prateekjoshi565
Created July 18, 2020 11:10
import torch

# `model`, `epochs`, `train()` and `evaluate()` are defined elsewhere;
# train() and evaluate() each return a tuple whose first element is the epoch's loss

# set initial loss to infinity
best_valid_loss = float('inf')

# empty lists to store training and validation loss of each epoch
train_losses = []
valid_losses = []

# for each epoch
for epoch in range(epochs):

    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))

    # train model
    train_loss, _ = train()

    # evaluate model
    valid_loss, _ = evaluate()

    # save the best model (only when validation loss improves)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'saved_weights.pt')

    # append training and validation loss
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')
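The loop above keeps only the weights from the epoch with the lowest validation loss. Below is a minimal, framework-free sketch of that bookkeeping, assuming `train` and `evaluate` callables that return a `(loss, extra)` pair as in the snippet. `run_training`, `fake_train`, `fake_evaluate` and the scripted loss values are hypothetical, and an in-memory `copy.deepcopy` of a plain dict stands in for `torch.save(model.state_dict(), ...)`.

```python
import copy
from typing import Callable, Dict, List, Tuple


def run_training(
    epochs: int,
    train: Callable[[], Tuple[float, object]],
    evaluate: Callable[[], Tuple[float, object]],
    state: Dict[str, float],
) -> Tuple[List[float], List[float], Dict[str, float], float]:
    """Run `epochs` epochs, keeping a copy of `state` from the best-validation epoch."""
    best_valid_loss = float('inf')
    best_state: Dict[str, float] = {}
    train_losses: List[float] = []
    valid_losses: List[float] = []

    for _ in range(epochs):
        train_loss, _extra = train()
        valid_loss, _extra = evaluate()

        # strict '<': a tie with the current best does not overwrite the checkpoint
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            # hypothetical stand-in for torch.save(model.state_dict(), 'saved_weights.pt')
            best_state = copy.deepcopy(state)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

    return train_losses, valid_losses, best_state, best_valid_loss


# hypothetical scripted losses: validation improves for two epochs, then overfits
scripted_train = iter([0.9, 0.6, 0.4, 0.3])
scripted_valid = iter([0.8, 0.5, 0.55, 0.7])
weights = {"epoch": 0.0}  # hypothetical "model state"


def fake_train() -> Tuple[float, object]:
    weights["epoch"] += 1  # pretend a training step changed the weights
    return next(scripted_train), None


def fake_evaluate() -> Tuple[float, object]:
    return next(scripted_valid), None


train_losses, valid_losses, best, best_loss = run_training(
    4, fake_train, fake_evaluate, weights
)
print(best_loss, best["epoch"])  # 0.5 2.0
```

Because the copy is taken only when validation loss strictly improves, the later overfitting epochs (0.55 and 0.7) leave the saved state from epoch 2 untouched, while the live `weights` dict keeps changing. The original loop gets the same effect by writing `saved_weights.pt` only on improvement.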