Skip to content

Instantly share code, notes, and snippets.

@whyboris
Last active May 12, 2021 02:47
Show Gist options
  • Star 7 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save whyboris/91ee793ddc92cf1e824978cf31bb790c to your computer and use it in GitHub Desktop.
Save whyboris/91ee793ddc92cf1e824978cf31bb790c to your computer and use it in GitHub Desktop.
Keras Loss & Accuracy Plot Helper Function
import matplotlib.pyplot as plt
# Plot model history more easily
# when plotting, smooth out the points by some factor (0.5 = rough, 0.99 = smooth)
# method taken from `Deep Learning with Python` by François Chollet
def smooth_curve(points, factor=0.75):
    """Return an exponentially smoothed copy of ``points``.

    The first value is kept unchanged; each later value is blended with the
    previously smoothed value as ``previous * factor + point * (1 - factor)``,
    so a larger ``factor`` gives a smoother curve.

    Args:
        points: Iterable of numeric values, in plotting order.
        factor: Weight given to the previous smoothed value.

    Returns:
        A new list with one smoothed value per input value.
    """
    result = []
    for value in points:
        if not result:
            # Nothing to blend with yet: the curve starts at the raw value.
            result.append(value)
            continue
        result.append(result[-1] * factor + value * (1 - factor))
    return result
def set_plot_history_data(ax, history, which_graph, trim=5):
    """Draw smoothed training and validation curves for one metric on ``ax``.

    Args:
        ax: Matplotlib axes to draw on.
        history: Object whose ``history`` attribute maps metric names to
            per-epoch values (for example the value returned by Keras
            ``model.fit``).
        which_graph: ``'acc'`` to plot accuracy or ``'loss'`` to plot loss.
        trim: Number of leading epochs to leave out of the plot; the first
            few epochs can skew the (loss) graph. Ignored when the run is
            too short for anything to remain after trimming.

    Raises:
        ValueError: If ``which_graph`` is neither ``'acc'`` nor ``'loss'``.
        KeyError: If the history lacks the training/validation metric pair.
    """
    # Validate first so a bad selector fails with a clear message instead of
    # an UnboundLocalError further down.
    metric_keys = {'acc': ('acc', 'accuracy'), 'loss': ('loss',)}
    if which_graph not in metric_keys:
        raise ValueError(
            "which_graph must be 'acc' or 'loss', got " + repr(which_graph)
        )

    records = history.history
    # Newer Keras releases log accuracy under 'accuracy' instead of 'acc',
    # so accept whichever spelling this history actually contains.
    for key in metric_keys[which_graph]:
        if key in records and 'val_' + key in records:
            break
    else:
        raise KeyError(
            'history has no training/validation data for ' + repr(which_graph)
        )

    train = smooth_curve(records[key])
    valid = smooth_curve(records['val_' + key])

    plt.xkcd()  # make plots look like xkcd

    epochs = range(1, len(train) + 1)
    # Trimming a run that is not longer than `trim` would leave an empty
    # graph, so show every epoch in that case.
    start = trim if 0 <= trim < len(train) else 0

    # Each curve is drawn twice: a thin labelled line plus a wide, faint
    # line underneath it.
    ax.plot(epochs[start:], train[start:], 'dodgerblue', label='Training')
    ax.plot(epochs[start:], train[start:], 'dodgerblue', linewidth=15, alpha=0.1)
    ax.plot(epochs[start:], valid[start:], 'g', label='Validation')
    ax.plot(epochs[start:], valid[start:], 'g', linewidth=15, alpha=0.1)
def get_max_validation_accuracy(history):
    """Return a label with the peak smoothed validation accuracy in percent.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch values (for example the value returned by Keras
            ``model.fit``).

    Returns:
        A string such as ``'Max validation accuracy ≈ 95.7%'``.

    Raises:
        KeyError: If the history has neither ``'val_acc'`` nor
            ``'val_accuracy'``.
    """
    records = history.history
    # Newer Keras releases log the metric as 'val_accuracy' instead of
    # 'val_acc'; use whichever one is present.
    key = 'val_acc' if 'val_acc' in records else 'val_accuracy'
    validation = smooth_curve(records[key])
    ymax = max(validation)
    # Scale to percent before rounding: rounding first and multiplying
    # afterwards can reintroduce float noise into the label
    # (0.57 * 100 == 56.99999999999999).
    return 'Max validation accuracy ≈ ' + str(round(ymax * 100, 1)) + '%'
def plot_history(history):
    """Plot accuracy (top) and loss (bottom) graphs for a training history.

    Builds a two-row figure sharing the x axis, draws the smoothed training
    and validation curves into each row and annotates the accuracy graph
    with the maximum validation accuracy.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch values (for example the value returned by Keras
            ``model.fit``).
    """
    _fig, (ax1, ax2) = plt.subplots(
        nrows=2,
        ncols=1,
        figsize=(10, 6),
        sharex=True,
        gridspec_kw={'height_ratios': [5, 2]},
    )

    set_plot_history_data(ax1, history, 'acc')
    set_plot_history_data(ax2, history, 'loss')

    # Accuracy graph
    ax1.set_ylabel('Accuracy')
    ax1.set_ylim(bottom=0.5, top=1)
    ax1.legend(loc="lower right")
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.xaxis.set_ticks_position('none')
    ax1.spines['bottom'].set_visible(False)

    # Max accuracy text. Added through ax1 itself so the label belongs to
    # the axes whose coordinate system (transAxes) positions it, rather
    # than to whichever axes pyplot currently treats as active.
    ax1.text(
        0.97,
        0.97,
        get_max_validation_accuracy(history),
        horizontalalignment='right',
        verticalalignment='top',
        transform=ax1.transAxes,
        fontsize=12,
    )

    # Loss graph
    ax2.set_ylabel('Loss')
    ax2.set_yticks([])
    ax2.set_xlabel('Epochs')
    ax2.spines['top'].set_visible(False)
    ax2.spines['right'].set_visible(False)

    plt.tight_layout()
# How to use (illustrative snippet, not runnable as-is):
# `model`, `x` and `y` are placeholders that are not defined in this file --
# presumably a compiled Keras model and its training data (confirm against
# your own code); the `...` stands in for the remaining `fit` arguments.
history = model.fit(x, y, ...)
# Draw the accuracy and loss graphs from the returned training history.
plot_history(history)
@whyboris
Copy link
Author

whyboris commented Jul 24, 2018

Results in a graph like this:

image

@whyboris
Copy link
Author

whyboris commented Aug 7, 2018

This is now a pip package: https://pypi.org/project/keras-hist-graph/
Install with `pip install keras-hist-graph`
Use thus:

from keras_hist_graph import plot_history

history = model.fit(x, y, ...) # standard Keras training code

plot_history(history)

@nilspeder
Copy link

This is now a pip package: https://pypi.org/project/keras-hist-graph/
Install with pip install keras-hist-graph
Use thus:

from keras_hist_graph import plot_history

history = model.fit(x, y, ...) # standard Keras training code

plot_history(history)

Getting this error:

KeyError Traceback (most recent call last)
in
1 from keras_hist_graph import plot_history
2
----> 3 plot_history(history)

~\anaconda3\envs\tf_gpu\lib\site-packages\keras_hist_graph\keras_hist_graph.py in plot_history(history, start_epoch, smooth_factor, xkcd, fig_size, min_accuracy)
68 )
69
---> 70 set_plot_history_data(ax1, history, "acc", start_epoch, smooth_factor, xkcd)
71
72 set_plot_history_data(ax2, history, "loss", start_epoch, smooth_factor, xkcd)

~\anaconda3\envs\tf_gpu\lib\site-packages\keras_hist_graph\keras_hist_graph.py in set_plot_history_data(ax, history, which_graph, start_epoch, smooth_factor, xkcd)
21
22 if which_graph == "acc":
---> 23 train = smooth_curve(history.history["acc"], smooth_factor)
24 valid = smooth_curve(history.history["val_acc"], smooth_factor)
25

KeyError: 'acc'

@whyboris
Copy link
Author

Thank you for the bug report. At some point in the last year something changed -- I believe Keras no longer has the `acc` key ❓ 🤔

I have this code as a PyPI package: https://pypi.org/project/keras-hist-graph/
And there's a repository for it too: https://github.com/whyboris/keras-hist-graph

I've not been doing much Keras lately, so I might not fix it for a while 🤷 -- if you happen to figure something out - please open a PR 🙇

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment