Skip to content

Instantly share code, notes, and snippets.

@shayaf84
Last active August 2, 2020 23:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save shayaf84/4869ba8d0ab0d6a9edb616193ad185b2 to your computer and use it in GitHub Desktop.
Plot accuracy and loss
def plot_acc(history, ax=None, xlabel='Epoch #', *, chance=0.5, ylim=(0.4, 1)):
    """Plot per-epoch training and validation accuracy, then print the extremes.

    Args:
        history: Object whose ``.history`` attribute is a dict of per-epoch
            metric lists containing at least ``'accuracy'`` and
            ``'val_accuracy'`` (presumably a Keras ``History`` -- confirm
            with callers). The dict is only read, never modified.
        ax: Matplotlib axes to draw on. When ``None``, a new 1x1 figure is
            created with ``plt.subplots``.
        xlabel: Label for the x axis.
        chance: y value of the dashed red "Chance" reference line; ``None``
            omits the line. Default 0.5 (the previously hard-coded value,
            which assumes a balanced two-class task -- adjust otherwise).
        ylim: ``(bottom, top)`` y-axis limits; ``None`` keeps matplotlib's
            autoscaling. Default ``(0.4, 1)`` (previously hard-coded).

    Side effects:
        Calls ``plt.show()`` and prints the highest and lowest validation
        accuracy.
    """
    metrics = history.history
    # Build a new dict so the caller's history does not gain an 'epoch' key.
    frame = pd.DataFrame.from_dict(
        {**metrics, 'epoch': list(range(len(metrics['val_accuracy'])))}
    )
    val_acc = frame['val_accuracy']
    # 0-based epoch of the first maximum validation accuracy.
    best_epoch = frame.loc[val_acc.idxmax(), 'epoch']

    if ax is None:
        _, ax = plt.subplots(1, 1)
    sns.lineplot(x='epoch', y='val_accuracy', data=frame,
                 label='Validation', ax=ax)
    sns.lineplot(x='epoch', y='accuracy', data=frame,
                 label='Training', ax=ax)
    if chance is not None:
        ax.axhline(chance, linestyle='--', color='red', label='Chance')
    ax.axvline(x=best_epoch, linestyle='--', color='green',
               label='Best Epoch')
    ax.legend(loc=1)  # matplotlib location code 1 == 'upper right'
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Accuracy (Fraction)')
    plt.show()

    print("The highest validation accuracy was", val_acc.max())
    print("The lowest validation accuracy was", val_acc.min())
def plot_loss(history, ax=None, xlabel='Epoch #', *, ylim=(0.1, 1)):
    """Plot per-epoch training and validation loss, then print the extremes.

    Args:
        history: Object whose ``.history`` attribute is a dict of per-epoch
            metric lists containing at least ``'loss'`` and ``'val_loss'``
            (presumably a Keras ``History`` -- confirm with callers). The
            dict is only read, never modified.
        ax: Matplotlib axes to draw on. When ``None``, a new 1x1 figure is
            created with ``plt.subplots``.
        xlabel: Label for the x axis.
        ylim: ``(bottom, top)`` y-axis limits. Default ``(0.1, 1)`` (the
            previously hard-coded value); note it hides any loss values
            outside that range. Pass ``None`` to keep matplotlib's
            autoscaling.

    Side effects:
        Calls ``plt.show()`` and prints the lowest and highest validation
        loss.
    """
    metrics = history.history
    # Build a new dict so the caller's history does not gain an 'epoch' key.
    frame = pd.DataFrame.from_dict(
        {**metrics, 'epoch': list(range(len(metrics['val_loss'])))}
    )
    val_loss = frame['val_loss']
    # 0-based epoch of the first minimum validation loss.
    best_epoch = frame.loc[val_loss.idxmin(), 'epoch']

    if ax is None:
        _, ax = plt.subplots(1, 1)
    sns.lineplot(x='epoch', y='val_loss', data=frame,
                 label='Validation', ax=ax)
    sns.lineplot(x='epoch', y='loss', data=frame,
                 label='Training', ax=ax)
    ax.axvline(x=best_epoch, linestyle='--', color='green',
               label='Best Epoch')
    ax.legend(loc=1)  # matplotlib location code 1 == 'upper right'
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss (Fraction)')
    plt.show()

    print("The lowest validation loss was", val_loss.min())
    print("The highest validation loss was", val_loss.max())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment