Skip to content

Instantly share code, notes, and snippets.

@almeidaraul
Last active January 9, 2023 18:56
Show Gist options
  • Save almeidaraul/f377bae9323c2ffe957679b200f8eb48 to your computer and use it in GitHub Desktop.
Save almeidaraul/f377bae9323c2ffe957679b200f8eb48 to your computer and use it in GitHub Desktop.
Early stopping auxiliary tool: plots validation vs. train loss, optionally showing the results obtained with early stopping
"""
Reads CSVs exported from TensorBoard (i.e., with 'Step' and 'Value' columns)
to support manual early-stopping analysis: finding the best point to stop,
comparing train and validation loss, checking the result obtained when
stopping at a given epoch, etc.
"""
import argparse
import pandas as pd
import matplotlib.pyplot as plt
def get_args(argv=None):
    """Parse the command-line options of the early stopping assistant.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads the arguments from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``train_csv``,
        ``val_csv``, ``result_csv``, ``plot_diff``, ``early_stop``,
        ``tolerance`` and ``min_delta``.
    """
    parser = argparse.ArgumentParser(
        prog="early_stopping_assistant",
        description="plots and info to choose early stopping manually")

    # Input files: both loss curves are mandatory, the result curve is not.
    parser.add_argument("--train_csv", metavar="train_loss_csv", type=str,
                        required=True, help="Train loss CSV file")
    parser.add_argument("--val_csv", metavar="val_loss_csv", type=str,
                        required=True, help="Val loss CSV file")
    parser.add_argument("--result_csv", metavar="result_csv", type=str,
                        help="Result CSV file")

    # Plot / simulation switches.
    parser.add_argument("--plot_diff", action="store_true",
                        help="Plot loss difference as well")
    parser.add_argument("--early_stop", action="store_true",
                        help="Simulate early stopping")

    # Early stopping hyper-parameters.
    parser.add_argument("--tolerance", type=int, default=3,
                        help="Early stopping tolerance")
    parser.add_argument("--min_delta", type=float, default=1.,
                        help="Early stopping min delta")

    return parser.parse_args(argv)
def get_early_stop(loss_diff_series, tolerance, min_delta):
    """Simulate early stopping over per-epoch (validation - train) loss gaps.

    Monitoring only starts once a gap inside ``[0, min_delta]`` has been
    observed. From then on, each epoch whose gap is greater than
    ``min_delta`` increments a counter, and each epoch whose gap is
    ``<= min_delta`` resets it to zero.

    Args:
        loss_diff_series: Sized iterable with one loss gap per epoch.
        tolerance: Number of consecutive epochs above ``min_delta`` that
            triggers the stop (expected to be a positive integer).
        min_delta: Largest gap that is still considered acceptable.

    Returns:
        The index of the first epoch of the run of ``tolerance`` consecutive
        epochs that triggered the stop. If the stop is never triggered,
        ``len(loss_diff_series) - tolerance`` is returned, clamped at 0 so
        that a series shorter than ``tolerance`` never yields a negative
        epoch.
    """
    tracking = False
    counter = 0
    for epoch, delta in enumerate(loss_diff_series):
        # Once monitoring has started it stays on for the rest of the series.
        if not tracking:
            tracking = 0. <= delta <= min_delta
        if not tracking:
            continue
        counter = 0 if delta <= min_delta else counter + 1
        if counter == tolerance:
            # Step back to the first epoch of the offending run.
            return epoch - tolerance + 1
    # Never triggered: callers use the result as an epoch index, so it must
    # not go below zero when the series is shorter than the tolerance.
    return max(len(loss_diff_series) - tolerance, 0)
# Script body: load the exported curves, optionally simulate early stopping,
# and save the resulting figure to plot.png.
args = get_args()
df_train = pd.read_csv(args.train_csv)
df_val = pd.read_csv(args.val_csv)
df_result = pd.read_csv(args.result_csv) if args.result_csv else None

# One row per logged point; "diff" is the validation loss minus the train loss.
df_diff = pd.DataFrame({
    "step": df_train.Step,
    "train": df_train.Value,
    "val": df_val.Value,
    "diff": df_val.Value - df_train.Value,
})

early_stop = 0
if args.early_stop:
    early_stop = get_early_stop(df_diff["diff"], args.tolerance,
                                args.min_delta)

plot_axes = ["train", "val"]
plot_legend = ["Train loss", "Validation loss"]
if args.plot_diff:
    plot_axes.append("diff")
    plot_legend.append("(Validation - Train) loss")

title = "Loss per epoch"
# Plot a renamed copy so each curve carries its display name as its own
# label. A later argument-less plt.legend() rebuilds the legend from artist
# labels; without the rename it would show the raw column names instead.
df_plot = df_diff.rename(columns=dict(zip(plot_axes, plot_legend)))
df_plot.plot(x="step", y=plot_legend, title=title, xlabel="Epoch")

if args.early_stop:
    # NOTE(review): early_stop is a row index while the x axis shows the
    # "step" column; this assumes steps are numbered like row indices.
    plt.axvline(x=early_stop, color="red", linestyle="dotted", linewidth=1,
                label=f"Early stop point ({early_stop})")
    plt.axvline(x=early_stop + args.tolerance, color="gray",
                linestyle="dashdot", linewidth=1,
                label=f"Decision point ({early_stop + args.tolerance})")
    # Rebuild the legend so it lists the curves and both marker lines.
    plt.legend()

if args.early_stop and args.result_csv:
    result_values = df_result['Value']
    es = result_values.loc[early_stop]  # result at the early stop epoch
    mn = result_values.min()            # best (lowest) result
    ls = result_values.iloc[-1]         # result at the last epoch
    plt.suptitle(f"best: {mn:.5f}, last: {ls:.5f}, es: {es:.5f}")

plt.savefig("plot.png")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment