Skip to content

Instantly share code, notes, and snippets.

@apsdehal
Last active April 24, 2021 09:56
Show Gist options
  • Save apsdehal/1d348bb5c1a179c430f9676d4c2972f0 to your computer and use it in GitHub Desktop.
Save apsdehal/1d348bb5c1a179c430f9676d4c2972f0 to your computer and use it in GitHub Desktop.
This script plots a metric from MMF `train.log` files using Plotly. Set `JOBS_BASEPATH` to your save folders and `METRIC` to the metric you want to plot.
import sys
import os
import numpy as np
import json
from collections import defaultdict
import seaborn
import glob
import random
import plotly
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.io as pio
# Save folders of the MMF runs to plot. "train.log" is appended to each entry
# and the result is passed to glob.glob, so wildcard patterns may be used.
JOBS_BASEPATH = ["<path_to_save_folder>"]
# Logged key whose values are plotted on the y-axis.
METRIC = "val/vqa2/vqa_accuracy"
def read_logline(line):
    """Parse one MMF train.log line into a dict of logged values.

    Two line formats are handled:

    * JSON: everything from the first ``{"`` onward is decoded with
      ``json.loads``, so values keep their JSON types.
    * Plain text: everything from the first ``progress`` onward is split on
      commas into ``key: value`` items; keys and values are kept as stripped
      strings. Only the first colon of an item separates key from value, so
      values such as ``01:02:03`` stay intact.

    If the parsed entry contains ``val/total_loss`` it is treated as a
    validation line: every key not already containing ``val`` gets a ``val/``
    prefix (e.g. ``num_updates`` -> ``val/num_updates``), which keeps
    validation values apart from training values of the same name.

    Args:
        line: A raw log line.

    Returns:
        dict mapping field names to values. A plain-text line without a
        ``progress`` marker yields an empty dict.
    """
    line = line.strip()
    json_start = line.find("{\"")
    if json_start != -1:
        res = json.loads(line[json_start:])
    else:
        progress_start = line.find("progress")
        if progress_start == -1:
            # No recognizable payload. Slicing with -1 would parse only the
            # last character of the line and crash.
            return {}
        res = {}
        for item in line[progress_start:].split(","):
            key, sep, value = item.partition(":")
            if not sep:
                # Skip fragments that are not "key: value" pairs, such as the
                # empty item produced by a trailing comma.
                continue
            res[key.strip()] = value.strip()
    if "val/total_loss" in res:
        res = {(k if "val" in k else "val/" + k): v for k, v in res.items()}
    return res
def process_log(exp_out, label=None):
    """Collect the values logged in an MMF train.log into per-key lists.

    Only lines containing both "epoch" and "loss" are parsed, using
    ``read_logline``. For every parsed line, each value is appended to the
    list stored under its key, in log order.

    Args:
        exp_out: Path to a train.log file.
        label: Legend label for this run. Defaults to the name of the
            directory that contains the log file.

    Returns:
        dict with "res" (a ``defaultdict(list)`` of key -> values) and
        "label".
    """
    res = defaultdict(list)
    filename = exp_out
    if label is None:
        # Parent directory name. Normalizing to an absolute path handles bare
        # filenames ("train.log"), repeated separators and "/" on Windows.
        label = os.path.basename(os.path.dirname(os.path.abspath(filename)))
    with open(filename, encoding="utf-8", errors="replace") as f:
        for line in f:
            if "epoch" in line and "loss" in line:
                for k, v in read_logline(line).items():
                    res[k].append(v)
    return {
        "res": res,
        "label": label,
    }
files = []
for path in JOBS_BASEPATH:
    # Each entry may contain glob wildcards; collect every matching train.log.
    path = os.path.join(path, "train.log")
    files += glob.glob(path, recursive=True)

# Key each run by path + index so that repeated paths stay distinct.
results = {e + str(i): process_log(e) for i, e in enumerate(files)}

plt_data = []
for r in results.values():
    # read_logline prefixes every key on a validation line with "val/", so the
    # step count belonging to validation metrics is stored under
    # "val/num_updates". Pairing a validation metric with the training-line
    # "num_updates" gives x and y lists taken from different log lines.
    x_key = "num_updates"
    if "val" in METRIC and "val/num_updates" in r["res"]:
        x_key = "val/num_updates"
    g = go.Scatter(
        x=r["res"][x_key],
        y=r["res"][METRIC],
        name=r["label"],
    )
    plt_data.append(g)

layout = dict(
    xaxis=dict(title="num_updates"),
    yaxis=dict(title=METRIC),
    font=dict(size=8),
    legend=dict(x=-.1, y=1.3),
    hoverlabel=dict(namelength=-1),
)
fig = dict(data=plt_data, layout=layout)
iplot(fig, show_link=False)
@dinhanhx
Copy link

@dinhanhx Can you try now?

It works perfectly. Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment