Skip to content

Instantly share code, notes, and snippets.

@whyboris
Last active May 12, 2021 02:47
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save whyboris/91ee793ddc92cf1e824978cf31bb790c to your computer and use it in GitHub Desktop.
Save whyboris/91ee793ddc92cf1e824978cf31bb790c to your computer and use it in GitHub Desktop.
Keras Loss & Accuracy Plot Helper Function
import matplotlib.pyplot as plt
# Plot model history more easily
# when plotting, smooth out the points by some factor (0.5 = rough, 0.99 = smooth)
# method taken from `Deep Learning with Python` by François Chollet
def smooth_curve(points, factor=0.75):
    """Smooth a series of values with an exponential moving average.

    The first point is passed through unchanged. Every later output is
    ``previous * factor + point * (1 - factor)``, where ``previous`` is the
    preceding *smoothed* value, so a larger ``factor`` gives a smoother curve.

    Args:
        points: Iterable of numbers, e.g. per-epoch metric values.
        factor: Weight given to the previous smoothed value.

    Returns:
        A new list holding one smoothed value per input point; an empty
        list when ``points`` is empty.
    """
    smoothed_points = []
    for point in points:
        if smoothed_points:
            previous = smoothed_points[-1]
            smoothed_points.append(previous * factor + point * (1 - factor))
        else:
            # Nothing to blend with yet: seed the curve with the raw value.
            smoothed_points.append(point)
    return smoothed_points
def set_plot_history_data(ax, history, which_graph, trim=5):
    """Draw smoothed training and validation curves for one metric on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the result of Keras ``model.fit``).
        which_graph: ``'acc'`` to plot accuracy or ``'loss'`` to plot loss.
        trim: Number of leading epochs to leave out of the plot, because the
            first few epochs may skew the (loss) graph. Ignored when the run
            is so short that nothing would be left to draw.

    Raises:
        ValueError: If ``which_graph`` is neither ``'acc'`` nor ``'loss'``.
        KeyError: If the expected metric is missing from ``history.history``.
    """
    metrics = history.history

    if which_graph == 'acc':
        # Accuracy may be logged as 'acc' or as 'accuracy' depending on the
        # Keras version / metric name; accept either spelling.
        train_key = 'acc' if 'acc' in metrics else 'accuracy'
        valid_key = 'val_acc' if 'val_acc' in metrics else 'val_accuracy'
    elif which_graph == 'loss':
        train_key, valid_key = 'loss', 'val_loss'
    else:
        raise ValueError(
            "which_graph must be 'acc' or 'loss', got {!r}".format(which_graph)
        )

    train = smooth_curve(metrics[train_key])
    valid = smooth_curve(metrics[valid_key])

    plt.xkcd()  # make plots look like xkcd

    epochs = range(1, len(train) + 1)

    # Trimming a run of `trim` epochs or fewer would leave an empty plot, so
    # show everything in that case. Negative values are treated as "no trim".
    trim = max(trim, 0)
    if trim >= len(train):
        trim = 0

    # Each curve is drawn twice: a thin labelled line plus a wide, nearly
    # transparent copy of the same data that acts as a soft glow.
    ax.plot(epochs[trim:], train[trim:], 'dodgerblue', label='Training')
    ax.plot(epochs[trim:], train[trim:], 'dodgerblue', linewidth=15, alpha=0.1)
    ax.plot(epochs[trim:], valid[trim:], 'g', label='Validation')
    ax.plot(epochs[trim:], valid[trim:], 'g', linewidth=15, alpha=0.1)
def get_max_validation_accuracy(history):
    """Return a label describing the best smoothed validation accuracy.

    The maximum is taken over the *smoothed* validation-accuracy curve, so it
    can be slightly lower than the best raw per-epoch value.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the result of Keras ``model.fit``).

    Returns:
        A string such as ``'Max validation accuracy ≈ 94.3%'``.

    Raises:
        KeyError: If no validation accuracy is present in ``history.history``.
        ValueError: If the validation accuracy series is empty.
    """
    metrics = history.history
    # Accept both the 'val_acc' and the 'val_accuracy' spelling of the key.
    key = 'val_acc' if 'val_acc' in metrics else 'val_accuracy'
    validation = smooth_curve(metrics[key])
    ymax = max(validation)
    # Format as a percentage directly: multiplying a rounded float by 100
    # can produce artifacts such as 94.30000000000001.
    return 'Max validation accuracy ≈ ' + '{:.1%}'.format(ymax)
def plot_history(history):
    """Plot smoothed accuracy and loss curves for a training run.

    Builds a two-row figure sharing the x axis: a tall accuracy panel on top
    (annotated with the best smoothed validation accuracy) and a shorter loss
    panel below. The figure is created through ``pyplot``; nothing is
    returned.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch value lists (e.g. the result of Keras ``model.fit``).
    """
    _, (ax1, ax2) = plt.subplots(
        nrows=2,
        ncols=1,
        figsize=(10, 6),
        sharex=True,
        gridspec_kw={'height_ratios': [5, 2]},
    )

    set_plot_history_data(ax1, history, 'acc')
    set_plot_history_data(ax2, history, 'loss')

    # Accuracy graph
    ax1.set_ylabel('Accuracy')
    # Only the 0.5-1.0 band is shown; lower accuracies fall outside the view.
    ax1.set_ylim(bottom=0.5, top=1)
    ax1.legend(loc="lower right")
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    # The x axis is shared with the loss panel, so hide it on this one.
    ax1.xaxis.set_ticks_position('none')
    ax1.spines['bottom'].set_visible(False)

    # Max accuracy text. It is added to ax1 explicitly: plt.text would attach
    # it to whichever axes happens to be current.
    ax1.text(
        0.97,
        0.97,
        get_max_validation_accuracy(history),
        horizontalalignment='right',
        verticalalignment='top',
        transform=ax1.transAxes,
        fontsize=12,
    )

    # Loss graph
    ax2.set_ylabel('Loss')
    ax2.set_yticks([])  # only the shape of the loss curve matters here
    ax2.set_xlabel('Epochs')
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    plt.tight_layout()
# How to use (assuming you are using Keras):
# `model`, `x` and `y` are placeholders for your own model and training data,
# and `...` stands for the remaining arguments you pass to `fit`.
# The plots read the validation entries ('val_acc', 'val_loss') of
# `history.history`, so those must be present in the returned history.
history = model.fit(x, y, ...)
plot_history(history)
@whyboris
Copy link
Author

Thank you for the bug report. At some point in the last year something changed -- I believe newer Keras versions no longer record the `acc` key in `history.history` (it appears to be named `accuracy` now) ❓ 🤔

I have this code as a PyPI package: https://pypi.org/project/keras-hist-graph/
And there's a repository for it too: https://github.com/whyboris/keras-hist-graph

I've not been doing much Keras lately, so I might not fix it for a while 🤷 -- if you happen to figure something out - please open a PR 🙇

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment