Last active: August 2, 2020 23:39
-
-
Save shayaf84/4869ba8d0ab0d6a9edb616193ad185b2 to your computer and use it in GitHub Desktop.
Plot accuracy and loss
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot_acc(history, ax=None, xlabel='Epoch #', ylim=(0.4, 1), chance=0.5):
    """Plot per-epoch training and validation accuracy from a training history.

    Draws:
    - the validation and training accuracy curves;
    - an optional dashed red horizontal "Chance" reference line;
    - a dashed green vertical line at the epoch with the highest
      validation accuracy.

    It then calls ``plt.show()`` and prints the highest and lowest
    validation accuracy.

    Relies on ``pd`` (pandas), ``plt`` (matplotlib.pyplot) and ``sns``
    (seaborn) being available in the enclosing module/notebook namespace.

    Args:
        history: Object whose ``.history`` attribute is a dict mapping metric
            names to equal-length per-epoch lists. It must contain
            ``'accuracy'`` and ``'val_accuracy'``; presumably this is the
            object returned by Keras ``Model.fit`` (TODO confirm at call
            sites). The object is not modified.
        ax: Matplotlib Axes to draw on. A new figure and axes are created
            when None.
        xlabel: Label for the x axis.
        ylim: ``(bottom, top)`` y-axis limits, or None to keep matplotlib's
            autoscaling.
        chance: y value of the "Chance" reference line, or None to omit it.
            The 0.5 default suits a balanced binary task.
    """
    # Build a fresh frame rather than updating history.history in place, so
    # the caller's history object does not gain an 'epoch' key.
    df = pd.DataFrame.from_dict(history.history)
    df['epoch'] = list(range(len(df)))

    # idxmax returns the first epoch attaining the maximum. The previous
    # sort_values approach used an unstable sort, so ties were arbitrary.
    best_epoch = df.loc[df['val_accuracy'].idxmax(), 'epoch']

    if ax is None:
        _, ax = plt.subplots(1, 1)
    sns.lineplot(x='epoch', y='val_accuracy', data=df, label='Validation', ax=ax)
    sns.lineplot(x='epoch', y='accuracy', data=df, label='Training', ax=ax)
    if chance is not None:
        ax.axhline(chance, linestyle='--', color='red', label='Chance')
    ax.axvline(x=best_epoch, linestyle='--', color='green', label='Best Epoch')
    ax.legend(loc=1)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Accuracy (Fraction)')
    plt.show()

    print("The highest validation accuracy was", df['val_accuracy'].max())
    print("The lowest validation accuracy was", df['val_accuracy'].min())
def plot_loss(history, ax=None, xlabel='Epoch #', ylim=(0.1, 1)):
    """Plot per-epoch training and validation loss from a training history.

    Draws the validation and training loss curves and a dashed green
    vertical line at the epoch with the lowest validation loss. It then
    calls ``plt.show()`` and prints the lowest and highest validation loss.

    Relies on ``pd`` (pandas), ``plt`` (matplotlib.pyplot) and ``sns``
    (seaborn) being available in the enclosing module/notebook namespace.

    Args:
        history: Object whose ``.history`` attribute is a dict mapping metric
            names to equal-length per-epoch lists. It must contain ``'loss'``
            and ``'val_loss'``; presumably this is the object returned by
            Keras ``Model.fit`` (TODO confirm at call sites). The object is
            not modified.
        ax: Matplotlib Axes to draw on. A new figure and axes are created
            when None.
        xlabel: Label for the x axis.
        ylim: ``(bottom, top)`` y-axis limits. The default ``(0.1, 1)``
            hides any loss values outside that range; pass None to keep
            matplotlib's autoscaling.
    """
    # Build a fresh frame rather than updating history.history in place, so
    # the caller's history object does not gain an 'epoch' key.
    df = pd.DataFrame.from_dict(history.history)
    df['epoch'] = list(range(len(df)))

    # idxmin returns the first epoch attaining the minimum. The previous
    # sort_values approach used an unstable sort, so ties were arbitrary.
    best_epoch = df.loc[df['val_loss'].idxmin(), 'epoch']

    if ax is None:
        _, ax = plt.subplots(1, 1)
    sns.lineplot(x='epoch', y='val_loss', data=df, label='Validation', ax=ax)
    sns.lineplot(x='epoch', y='loss', data=df, label='Training', ax=ax)
    ax.axvline(x=best_epoch, linestyle='--', color='green', label='Best Epoch')
    ax.legend(loc=1)
    if ylim is not None:
        ax.set_ylim(ylim)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Loss (Fraction)')
    plt.show()

    print("The lowest validation loss was", df['val_loss'].min())
    print("The highest validation loss was", df['val_loss'].max())
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.