Skip to content

Instantly share code, notes, and snippets.

@KDercksen
Created September 7, 2023 12:09
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save KDercksen/858c6d2f08f41eef062340b26a841a69 to your computer and use it in GitHub Desktop.
Plot the Hugging Face (HF) Trainer log history stored in `trainer_state.json`.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from pathlib import Path
import matplotlib.pyplot as plt
from collections import defaultdict
from utils import read_json
import argparse
# Keys in each log_history entry that are skipped when plotting.
# The first group looks like progress bookkeeping attached to entries rather
# than a metric of its own.
_PROGRESS_KEYS = ["epoch", "step"]
# The second group presumably comes from the one-off end-of-training summary
# entry (runtime, throughput, FLOPs, final averaged loss) — NOTE(review):
# confirm against a real trainer_state.json.
_SUMMARY_KEYS = [
    "total_flos",
    "train_loss",
    "train_runtime",
    "train_samples_per_second",
    "train_steps_per_second",
]
# Concatenation keeps the exact original order and the list type.
exclude_keys = _PROGRESS_KEYS + _SUMMARY_KEYS
def collect_log_values(log_history, exclude=None):
    """Group logged values by key, paired with the step they were logged at.

    Args:
        log_history: Iterable of dicts, one per logging event (the
            ``"log_history"`` list from ``trainer_state.json``).
        exclude: Keys to skip. Defaults to the module-level ``exclude_keys``
            (looked up at call time).

    Returns:
        A dict mapping each remaining key to a ``(steps, values)`` tuple of
        equal-length lists, in log order. ``steps`` holds the entry's
        ``"step"`` value, or the entry's position in ``log_history`` when an
        entry has no ``"step"``.
    """
    if exclude is None:
        exclude = exclude_keys
    # Build the lookup set once instead of scanning a list for every key.
    excluded = set(exclude)

    series = defaultdict(lambda: ([], []))
    for index, entry in enumerate(log_history):
        # Different metrics are logged at different cadences (e.g. training
        # loss vs. evaluation metrics). Recording the step keeps their x-axes
        # comparable, unlike plotting against list position.
        x = entry.get("step", index)
        for key, value in entry.items():
            if key in excluded:
                continue
            steps, values = series[key]
            steps.append(x)
            values.append(value)
    return dict(series)


def main(argv=None):
    """Plot every non-excluded metric from a trainer_state.json file.

    Opens one figure per metric and then displays them all with
    ``plt.show()``.

    Args:
        argv: Optional argument list for the parser; ``None`` reads
            ``sys.argv`` (the original CLI behavior).
    """
    p = argparse.ArgumentParser()
    p.add_argument("trainer_state_path", type=Path)
    args = p.parse_args(argv)

    log_history = read_json(args.trainer_state_path)["log_history"]

    for key, (steps, values) in collect_log_values(log_history).items():
        plt.figure()
        plt.title(key)
        plt.plot(steps, values)
        plt.xlabel("step")
    plt.show()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment