Last active
January 9, 2023 18:56
-
-
Save almeidaraul/f377bae9323c2ffe957679b200f8eb48 to your computer and use it in GitHub Desktop.
Early stopping auxiliary tool; plots val vs train loss with optional results with early stopping
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Reads CSVs exported from TensorBoard (i.e., with 'Step' and 'Value' columns)
to serve as a tool for early stopping analysis (best point to stop, train/val
loss comparison, results if stopping at point X, etc.).
""" | |
import argparse | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
def get_args(argv=None):
    """Parse the command-line options of the early stopping assistant.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is the
            behaviour callers had before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``train_csv``, ``val_csv``,
        ``result_csv`` (``None`` when omitted), ``plot_diff``, ``early_stop``,
        ``tolerance`` (int, default 3) and ``min_delta`` (float, default 1.0).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be converted to its declared type.
    """
    parser = argparse.ArgumentParser(
        prog="early_stopping_assistant",
        description="plots and info to choose early stopping manually")

    # Input files: train and validation losses are mandatory, the result
    # metric is optional.
    parser.add_argument("--train_csv", metavar="train_loss_csv", type=str,
                        required=True, help="Train loss CSV file")
    parser.add_argument("--val_csv", metavar="val_loss_csv", type=str,
                        required=True, help="Val loss CSV file")
    parser.add_argument("--result_csv", metavar="result_csv", type=str,
                        help="Result CSV file")

    # Plot / simulation switches.
    parser.add_argument("--plot_diff", action="store_true",
                        help="Plot loss difference as well")
    parser.add_argument("--early_stop", action="store_true",
                        help="Simulate early stopping")

    # Early stopping hyper-parameters.
    parser.add_argument("--tolerance", type=int, default=3,
                        help="Early stopping tolerance")
    parser.add_argument("--min_delta", type=float, default=1.,
                        help="Early stopping min delta")

    return parser.parse_args(argv)
def get_early_stop(loss_diff_series, tolerance, min_delta):
    """Simulate early stopping over a sequence of loss differences.

    Monitoring only starts at the first position whose value lies in
    ``[0, min_delta]``; everything before that is ignored. Once monitoring,
    each value greater than ``min_delta`` increments a counter and each value
    less than or equal to ``min_delta`` resets it to zero. When the counter
    reaches ``tolerance`` the position where that streak began is returned.

    Args:
        loss_diff_series: Sized iterable of numeric loss differences, one per
            epoch (the script passes validation loss minus train loss).
        tolerance: Number of consecutive over-threshold epochs that triggers
            the stop.
        min_delta: Largest difference that is still considered acceptable.

    Returns:
        The zero-based position of the first epoch of the streak that
        triggered the stop. If no stop is triggered, returns
        ``len(loss_diff_series) - tolerance`` clamped to a minimum of 0, so
        the result is never a negative position even for series shorter than
        ``tolerance``.
    """
    tracking = False
    counter = 0
    for epoch, delta in enumerate(loss_diff_series):
        # Start monitoring the first time the gap is small and non-negative.
        if not tracking and 0. <= delta <= min_delta:
            tracking = True
        if not tracking:
            continue
        counter = 0 if delta <= min_delta else counter + 1
        if counter == tolerance:
            # First epoch of the `tolerance`-long streak ending at `epoch`.
            return epoch - tolerance + 1
    # Never triggered: clamp so short (or empty) series do not yield a
    # negative position, which is invalid as an epoch index downstream.
    return max(0, len(loss_diff_series) - tolerance)
# ---------------------------------------------------------------------------
# Script body: load the CSVs, optionally simulate early stopping, and save
# the train/validation loss plot to "plot.png".
# ---------------------------------------------------------------------------
args = get_args()
df_train = pd.read_csv(args.train_csv)
df_val = pd.read_csv(args.val_csv)
df_result = pd.read_csv(args.result_csv) if args.result_csv else None

# One row per logged step; "diff" is validation loss minus train loss.
df_diff = pd.DataFrame({
    "step": df_train.Step, "train": df_train.Value, "val": df_val.Value,
    "diff": df_val.Value - df_train.Value})

early_stop = 0
if args.early_stop:
    early_stop = get_early_stop(df_diff["diff"], args.tolerance,
                                args.min_delta)

plot_axes = ["train", "val"]
plot_legend = ["Train loss", "Validation loss"]
if args.plot_diff:
    plot_axes.append("diff")
    plot_legend.append("(Validation - Train) loss")

title = "Loss per epoch"
ax = df_diff.plot(x="step", y=plot_axes, title=title, xlabel="Epoch")
# Attach the human-readable names to the line artists themselves, so the
# single legend built below keeps them even when the early-stopping markers
# add their own labelled entries. (Passing the names to one legend call and
# then calling legend() again with no arguments would fall back to the raw
# column names.)
for line, label in zip(ax.get_lines(), plot_legend):
    line.set_label(label)

if args.early_stop:
    # NOTE(review): the markers are drawn at row positions; this matches the
    # x axis only if the CSV "Step" values equal the row index - confirm.
    ax.axvline(x=early_stop, color="red", linestyle="dotted", linewidth=1,
               label=f"Early stop point ({early_stop})")
    ax.axvline(x=early_stop + args.tolerance, color="gray",
               linestyle="dashdot", linewidth=1,
               label=f"Decision point ({early_stop + args.tolerance})")
ax.legend()

if args.early_stop and args.result_csv:
    result_values = df_result['Value']
    # Positional access: early_stop is a row position, not an index label.
    es = result_values.iloc[early_stop]
    mn = result_values.min()
    ls = result_values.iloc[-1]
    plt.suptitle(f"best: {mn:.5f}, last: {ls:.5f}, es: {es:.5f}")
plt.savefig("plot.png")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment