Skip to content

Instantly share code, notes, and snippets.

@wdsrocha
Created December 22, 2018 09:14
Show Gist options
  • Save wdsrocha/1dfe7501e20d37225491257bd54c768d to your computer and use it in GitHub Desktop.
Save wdsrocha/1dfe7501e20d37225491257bd54c768d to your computer and use it in GitHub Desktop.
Plot loss and accuracy from Keras model history.
# Extracted from https://www.kaggle.com/danbrice/keras-plot-history-full-report-and-grid-search
import matplotlib.pyplot as plt
def _split_metric_keys(metrics, fragment):
    """Return ``(training_keys, validation_keys)`` whose names contain ``fragment``.

    A key counts as a validation metric only when it starts with ``'val_'``
    (the prefix Keras gives validation entries). A plain substring test for
    ``'val'`` would misfile training metrics such as ``'value_loss'``.
    """
    training_keys, validation_keys = [], []
    for key in metrics:
        if fragment not in key:
            continue
        if key.startswith('val_'):
            validation_keys.append(key)
        else:
            training_keys.append(key)
    return training_keys, validation_keys


def _plot_series(epochs, metrics, keys, color, description):
    """Plot one line per key on the current figure, labelled with its final value.

    With a single key the line uses ``color`` and the label is
    ``'<description> (<final value>)'``. With several keys the key name is
    added to each label and matplotlib's default colour cycle is used, so the
    lines remain distinguishable.
    """
    several = len(keys) > 1
    for key in keys:
        values = metrics[key]
        final = format(values[-1], '.5f')
        if several:
            label = '{} [{}] ({})'.format(description, key, final)
            plt.plot(epochs, values, label=label)
        else:
            label = '{} ({})'.format(description, final)
            plt.plot(epochs, values, color, label=label)


def plot_history(history):
    """Plot loss and accuracy curves from a Keras-style training history.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value sequences (e.g. the return value of
            ``model.fit``). Keys containing ``'loss'`` go on the loss figure
            and keys containing ``'acc'`` on the accuracy figure.

    Returns:
        None. Prints ``'Loss is missing in history'`` and returns without
        plotting when there is no training-loss series or it is empty.
    """
    metrics = history.history
    loss_keys, val_loss_keys = _split_metric_keys(metrics, 'loss')
    acc_keys, val_acc_keys = _split_metric_keys(metrics, 'acc')

    # The training loss defines the x axis; without any recorded value there
    # is nothing to plot (and indexing the last value would fail).
    if not loss_keys or len(metrics[loss_keys[0]]) == 0:
        print('Loss is missing in history')
        return

    # Epochs are numbered from 1 on the x axis.
    epochs = range(1, len(metrics[loss_keys[0]]) + 1)

    ## Loss
    plt.figure(1)
    _plot_series(epochs, metrics, loss_keys, 'b', 'Training loss')
    _plot_series(epochs, metrics, val_loss_keys, 'g', 'Validation loss')
    plt.title('Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    ## Accuracy
    # Only drawn when an accuracy metric was recorded; otherwise the figure
    # would be empty and the legend would have no entries.
    if acc_keys or val_acc_keys:
        plt.figure(2)
        _plot_series(epochs, metrics, acc_keys, 'b', 'Training accuracy')
        _plot_series(epochs, metrics, val_acc_keys, 'g', 'Validation accuracy')
        plt.title('Accuracy')
        plt.xlabel('Epochs')
        plt.ylabel('Accuracy')
        plt.legend()

    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment