@C-EB
Created February 6, 2025 16:14
Plot training and validation loss
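```python
import matplotlib.pyplot as plt

# `trainer` is an existing Hugging Face transformers.Trainer on which trainer.train() has run.
# trainer.state.log_history is a list of dicts: training-step logs carry a 'loss' key,
# evaluation logs carry an 'eval_loss' key (only if periodic evaluation was enabled),
# and both record the (possibly fractional) 'epoch' at which they were written.
history = trainer.state.log_history

# Training loss and the epoch of each training log entry.
train_loss = [entry['loss'] for entry in history if 'loss' in entry]
epochs = [entry['epoch'] for entry in history if 'loss' in entry]

# Validation loss and the epoch of each evaluation log entry.
eval_loss = [entry['eval_loss'] for entry in history if 'eval_loss' in entry]
eval_epochs = [entry['epoch'] for entry in history if 'eval_loss' in entry]

plt.plot(epochs, train_loss, label="Training Loss")
plt.plot(eval_epochs, eval_loss, label="Validation Loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.show()
```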
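The filtering in the snippet above can also be pulled into a small helper that walks `log_history` once and can be checked without a trained model. This is a minimal sketch under stated assumptions: `split_loss_history` and `sample_history` are hypothetical names, and the sample values are invented; only the key names (`'loss'`, `'eval_loss'`, `'epoch'`) come from the snippet. The last sample entry imitates an end-of-training summary keyed `'train_loss'`, which the exact-key check `'loss' in entry` does not match.

```python
from typing import Any


def split_loss_history(log_history: list[dict[str, Any]]):
    """Split a Trainer-style log_history into training and validation (epochs, losses)."""
    train_epochs, train_loss = [], []
    eval_epochs, eval_loss = [], []
    for entry in log_history:
        # Training-step logs carry 'loss'; a 'train_loss' key is a different key and is skipped.
        if 'loss' in entry:
            train_epochs.append(entry['epoch'])
            train_loss.append(entry['loss'])
        # Evaluation logs carry 'eval_loss'.
        if 'eval_loss' in entry:
            eval_epochs.append(entry['epoch'])
            eval_loss.append(entry['eval_loss'])
    return (train_epochs, train_loss), (eval_epochs, eval_loss)


# Hypothetical log_history with made-up values, shaped like the entries the snippet reads.
sample_history = [
    {'loss': 2.1, 'epoch': 0.5, 'step': 50},
    {'loss': 1.7, 'epoch': 1.0, 'step': 100},
    {'eval_loss': 1.9, 'epoch': 1.0, 'step': 100},
    {'loss': 1.4, 'epoch': 1.5, 'step': 150},
    {'loss': 1.2, 'epoch': 2.0, 'step': 200},
    {'eval_loss': 1.6, 'epoch': 2.0, 'step': 200},
    {'train_loss': 1.6, 'epoch': 2.0, 'step': 200},  # summary-style entry, ignored
]

(train_epochs, train_loss), (eval_epochs, eval_loss) = split_loss_history(sample_history)
print(train_epochs, train_loss)  # [0.5, 1.0, 1.5, 2.0] [2.1, 1.7, 1.4, 1.2]
print(eval_epochs, eval_loss)    # [1.0, 2.0] [1.9, 1.6]
```

With a real trainer, `split_loss_history(trainer.state.log_history)` returns lists that can be passed straight to `plt.plot` as in the original snippet.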