Skip to content

Instantly share code, notes, and snippets.

@johschmidt42
Last active January 14, 2021 17:20
Show Gist options
  • Save johschmidt42/8c38d657cf616629a7763187ce836029 to your computer and use it in GitHub Desktop.
Save johschmidt42/8c38d657cf616629a7763187ce836029 to your computer and use it in GitHub Desktop.
def plot_training(training_losses,
                  validation_losses,
                  learning_rate,
                  gaussian=True,
                  sigma=2,
                  figsize=(8, 6)
                  ):
    """
    Return a figure with the training/validation loss and the learning rate per epoch.

    The figure has two side-by-side axes:

    * left: training and validation loss. With ``gaussian=True`` the raw values
      are drawn as faint dots and a Gaussian-smoothed curve is drawn on top.
    * right: the learning rate.

    Epochs are numbered from 1. Every series gets its own epoch axis, so the
    three sequences do not need to have the same length.

    Args:
        training_losses: Sequence of training-loss values, one per epoch.
        validation_losses: Sequence of validation-loss values, one per epoch.
        learning_rate: Sequence of learning-rate values, one per epoch.
        gaussian: If True, overlay a Gaussian-smoothed curve on the raw losses.
        sigma: Standard deviation of the Gaussian kernel, in samples (epochs),
            passed to ``scipy.ndimage.gaussian_filter``.
        figsize: Figure size in inches, passed to ``matplotlib.pyplot.figure``.

    Returns:
        matplotlib.figure.Figure: The new figure. It is created via ``plt.figure``
        and is neither shown nor closed here; the caller owns it.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from scipy.ndimage import gaussian_filter

    def _epochs(values):
        # 1-based epoch numbers matching the length of ``values``.
        return list(range(1, len(values) + 1))

    def _smooth(values):
        # gaussian_filter returns the input dtype, so filter in float to avoid
        # truncating integer losses; leave empty input untouched.
        arr = np.asarray(values, dtype=float)
        return gaussian_filter(arr, sigma=sigma) if arr.size else arr

    fig = plt.figure(figsize=figsize)
    ax_loss, ax_lr = fig.subplots(nrows=1, ncols=2)

    for ax in (ax_loss, ax_lr):
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

    # Raw values are de-emphasised (faint dots) when a smoothed curve is overlaid.
    if gaussian:
        raw_style, raw_alpha = '.', 0.25
        raw_train_color, raw_valid_color = 'lightcoral', 'lightgreen'
    else:
        raw_style, raw_alpha = '-', 1.0
        raw_train_color, raw_valid_color = 'red', 'green'

    # Left: training & validation loss
    ax_loss.plot(_epochs(training_losses), training_losses, raw_style,
                 color=raw_train_color, label='Training', alpha=raw_alpha)
    ax_loss.plot(_epochs(validation_losses), validation_losses, raw_style,
                 color=raw_valid_color, label='Validation', alpha=raw_alpha)
    if gaussian:
        ax_loss.plot(_epochs(training_losses), _smooth(training_losses), '-',
                     color='red', label='Training (smoothed)', alpha=0.75)
        ax_loss.plot(_epochs(validation_losses), _smooth(validation_losses), '-',
                     color='green', label='Validation (smoothed)', alpha=0.75)
    ax_loss.set_title('Training & validation loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend(loc='upper right')

    # Right: learning rate
    ax_lr.plot(_epochs(learning_rate), learning_rate, color='black')
    ax_lr.set_title('Learning rate')
    ax_lr.set_xlabel('Epoch')
    ax_lr.set_ylabel('LR')

    return fig
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment