Skip to content

Instantly share code, notes, and snippets.

@ubless607
Last active March 14, 2022 10:08
Show Gist options
  • Save ubless607/b436e980dfeb7b959b9d490f3f13db35 to your computer and use it in GitHub Desktop.
Save ubless607/b436e980dfeb7b959b9d490f3f13db35 to your computer and use it in GitHub Desktop.
A simple function for plotting a learning curve of the model
def _draw_learning_curve(epochs, train_values, val_values, title,
                         metric_name, ylim, fig_size):
    """Draw and show one figure with a training and a validation curve."""
    plt.figure(figsize=fig_size)
    plt.title(title)
    plt.xlabel('Epoch')
    plt.ylabel(f'{metric_name}')
    # Only override the limits the caller supplied; a None entry keeps
    # matplotlib's automatic limit on that side.
    if ylim[0] is not None:
        plt.ylim(bottom=ylim[0])
    if ylim[1] is not None:
        plt.ylim(top=ylim[1])
    plt.plot(epochs, train_values, '-', label='Training')
    plt.plot(epochs, val_values, '-', label='Validation')
    plt.legend()
    plt.tight_layout()
    plt.show()


def plot_learning_curve(log_df,
                        metric_name='loss',
                        rolling=False,
                        window_size=5,
                        ylim=(None, None), **kwargs):
    """
    Plot the learning curve of a model from its training log.

    One figure is always drawn with the raw training and validation curves.
    When ``rolling`` is True, a second figure with the moving average of
    both curves is drawn as well.

    Args:
        log_df: input pandas DataFrame. Must contain the columns ``epoch``,
            ``<metric_name>`` and ``val_<metric_name>``.
        metric_name: name of the metric to plot.
        rolling: Defaults to False. If set to True, plot a moving averaged
            graph in the second figure.
        window_size: size of the moving window.
        ylim: y-axis limit, tuple of (bottom, top). A ``None`` entry leaves
            that side to matplotlib.
        **kwargs: ``fig_size`` -- figure size passed to ``plt.figure``
            (defaults to ``(8, 4)``).

    Raises:
        KeyError: if one of the required columns is missing from ``log_df``.

    Reference:
        https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rolling.html
    Author:
        @ubless607, @yubin8773
    """
    # Look the columns up first so a missing column fails before any
    # figure has been created.
    train_values = log_df[f'{metric_name}']
    val_values = log_df[f'val_{metric_name}']

    # Take the x-axis straight from the logged epochs (shifted by one for
    # display, presumably because the log is zero-based) instead of
    # rebuilding a range from the first/last value: the range only has the
    # right length when the logged epochs are contiguous.
    epochs = log_df['epoch'].to_numpy() + 1

    # The seaborn style sheets were renamed to 'seaborn-v0_8-*' in
    # Matplotlib 3.6 and the old names were removed in 3.8, so use whichever
    # name this installation provides. If neither exists the current style
    # is left as is rather than aborting the plot.
    for style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    fig_size = kwargs.get('fig_size', (8, 4))

    _draw_learning_curve(epochs, train_values, val_values,
                         f'Learning Curves ({metric_name})',
                         metric_name, ylim, fig_size)

    if rolling:
        # The first (window_size - 1) points of each rolling mean are NaN
        # and are therefore simply not drawn.
        train_mavg = train_values.rolling(window=window_size).mean()
        val_mavg = val_values.rolling(window=window_size).mean()
        _draw_learning_curve(epochs, train_mavg, val_mavg,
                             f'Learning Curves ({metric_name}) - moving average',
                             metric_name, ylim, fig_size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment