Skip to content

Instantly share code, notes, and snippets.

@apsdehal
Last active April 24, 2021 09:56
Show Gist options
  • Save apsdehal/1d348bb5c1a179c430f9676d4c2972f0 to your computer and use it in GitHub Desktop.
Save apsdehal/1d348bb5c1a179c430f9676d4c2972f0 to your computer and use it in GitHub Desktop.
This script plots metrics from an MMF `train.log` using plotly. Update `JOBS_BASEPATH` to point to your save folders and the `METRIC` variable to the metric that you want to plot.
import sys
import os
import numpy as np
import json
from collections import defaultdict
import seaborn
import glob
import random
import plotly
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.io as pio
# Base directories (glob patterns are allowed) that each contain an MMF
# "train.log"; every entry is joined with "train.log" and expanded with glob.
JOBS_BASEPATH = ["<path_to_save_folder>"]
# Key of the metric (as parsed from train.log) to plot on the y-axis.
METRIC = "val/vqa2/vqa_accuracy"
def read_logline(line):
    """Parse one MMF ``train.log`` line into a flat ``{key: value}`` dict.

    Two line formats are handled:

    * JSON: everything from the first ``{"`` onwards is decoded as a JSON
      object; any text after the object is ignored. If that fails, the line
      is treated as plain text instead.
    * Plain text: everything from the ``progress`` marker onwards is split on
      commas into ``key: value`` pairs. Numeric values are converted to
      ``int``/``float`` so they match what the JSON branch yields; all other
      values stay as stripped strings.

    If the parsed line contains a ``val/total_loss`` key, every key that does
    not already contain ``val`` is prefixed with ``val/``.

    Args:
        line: One raw line from the log file.

    Returns:
        A dict of parsed key/value pairs. It is empty when the line holds
        neither a decodable JSON object nor a ``progress`` section.
    """
    line = line.strip()
    res = None
    json_start = line.find('{"')
    if json_start >= 0:
        try:
            # raw_decode stops at the end of the first JSON value, so trailing
            # text on the same line does not raise JSONDecodeError.
            res, _ = json.JSONDecoder().raw_decode(line[json_start:])
        except json.JSONDecodeError:
            # '{"' appeared but does not start a valid JSON object; fall back
            # to the plain-text format below.
            res = None
    if res is None:
        res = _parse_progress_text(line)
    if "val/total_loss" in res:
        # Validation summary line: namespace its bookkeeping keys (e.g.
        # num_updates) under "val/" so they do not mix with the training
        # series of the same name.
        res = {(k if "val" in k else "val/" + k): v for k, v in res.items()}
    return res


def _parse_progress_text(line):
    """Parse the comma-separated ``key: value`` pairs that start at ``progress``.

    Returns an empty dict when the line has no ``progress`` marker.
    """
    start = line.find("progress")
    if start < 0:
        # str.find returned -1; slicing with it would keep only the last
        # character of the line, which is never a metrics section.
        return {}
    parsed = {}
    for item in line[start:].split(","):
        # partition splits on the first colon only, so values that contain
        # ':' themselves (e.g. clock times) are kept whole.
        key, sep, value = item.partition(":")
        if not sep:
            # A fragment without a colon is not a key/value pair; skip it
            # rather than raising IndexError.
            continue
        parsed[key.strip()] = _to_number(value.strip())
    return parsed


def _to_number(text):
    """Return *text* as ``int`` or ``float`` when it is numeric, else unchanged."""
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text
def process_log(exp_out, label=None):
    """Collect every metric series from one MMF ``train.log`` file.

    Only lines that contain both ``epoch`` and ``loss`` are parsed (with
    ``read_logline``). Each parsed value is appended to the list for its key,
    so every list holds that key's values in file order.

    Args:
        exp_out: Path to the ``train.log`` file.
        label: Legend label for this run. Defaults to the name of the
            directory that contains the log file.

    Returns:
        ``{"res": defaultdict(list) mapping key -> list of values,
        "label": label}``.
    """
    res = defaultdict(list)
    filename = exp_out
    if label is None:
        # abspath lets a bare "train.log" resolve to its containing directory
        # (split()[-2] raised IndexError there), and dirname/basename handle
        # mixed path separators.
        label = os.path.basename(os.path.dirname(os.path.abspath(filename)))
    with open(filename, encoding="utf-8", errors="replace") as f:
        for line in f:
            if "epoch" in line and "loss" in line:
                for key, value in read_logline(line).items():
                    res[key].append(value)
    return {
        "res": res,
        "label": label,
    }
# Expand every base path into the train.log files it matches (entries in
# JOBS_BASEPATH may contain glob wildcards).
files = [
    log_file
    for base_dir in JOBS_BASEPATH
    for log_file in glob.glob(os.path.join(base_dir, "train.log"), recursive=True)
]
# Parse each log; the key is the file path suffixed with its index in `files`.
results = {}
for idx, log_file in enumerate(files):
    results[log_file + str(idx)] = process_log(log_file)
# Build one line trace per run.
# read_logline re-keys validation lines so their "num_updates" becomes
# "val/num_updates". A validation metric must therefore be paired with that
# series. Otherwise x (training lines) and y (validation lines) have different
# lengths and the points are misaligned.
plt_data = []
for run in results.values():
    series = run["res"]
    x_key = "num_updates"
    if METRIC.startswith("val/") and "val/num_updates" in series:
        x_key = "val/num_updates"
    plt_data.append(
        go.Scatter(
            # .get avoids inserting empty lists into the defaultdict for
            # runs that never logged the key.
            x=series.get(x_key, []),
            y=series.get(METRIC, []),
            name=run["label"],
        )
    )
# Figure layout: axis titles, a small font, a legend offset to
# (x=-0.1, y=1.3), and hover labels with namelength=-1 (per plotly docs:
# show the full trace name).
layout = {
    "xaxis": {"title": "num_updates"},
    "yaxis": {"title": METRIC},
    "font": {"size": 8},
    "legend": {"x": -0.1, "y": 1.3},
    "hoverlabel": {"namelength": -1},
}
fig = {"data": plt_data, "layout": layout}
iplot(fig, show_link=False)
@dinhanhx
Copy link

Sir, I tried to run it with this log file, and I got this error:

JSONDecodeError                           Traceback (most recent call last)
<ipython-input-1-7d360b433507> in <module>
     54     path = os.path.join(path, "train.log")
     55     files += glob.glob(path, recursive=True)
---> 56 results = {e + str(i): process_log(e) for i, e in enumerate(files)}
     57 
     58 

<ipython-input-1-7d360b433507> in <dictcomp>(.0)
     54     path = os.path.join(path, "train.log")
     55     files += glob.glob(path, recursive=True)
---> 56 results = {e + str(i): process_log(e) for i, e in enumerate(files)}
     57 
     58 

<ipython-input-1-7d360b433507> in process_log(exp_out, label)
     40         for line in f:
     41             if "epoch" in line and "loss" in line:
---> 42                 r = read_logline(line)
     43                 for k, v in r.items():
     44                     res[k].append(v)

<ipython-input-1-7d360b433507> in read_logline(line)
     21     info_index = line.find("{\"")
     22     line = line[info_index:]
---> 23     res = json.loads(line)
     24     if "val/total_loss" in res:
     25         new_res = {}

~/miniconda3/envs/mmf/lib/python3.7/json/__init__.py in loads(s, encoding, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
    346             parse_int is None and parse_float is None and
    347             parse_constant is None and object_pairs_hook is None and not kw):
--> 348         return _default_decoder.decode(s)
    349     if cls is None:
    350         cls = JSONDecoder

~/miniconda3/envs/mmf/lib/python3.7/json/decoder.py in decode(self, s, _w)
    335 
    336         """
--> 337         obj, end = self.raw_decode(s, idx=_w(s, 0).end())
    338         end = _w(s, end).end()
    339         if end != len(s):

~/miniconda3/envs/mmf/lib/python3.7/json/decoder.py in raw_decode(self, s, idx)
    353             obj, end = self.scan_once(s, idx)
    354         except StopIteration as err:
--> 355             raise JSONDecodeError("Expecting value", s, err.value) from None
    356         return obj, end

JSONDecodeError: Expecting value: line 1 column 1 (char 0)

@apsdehal
Copy link
Author

@dinhanhx I will edit this in the morning (PST) to support normal format as well.

@apsdehal
Copy link
Author

@dinhanhx Can you try now?

@dinhanhx
Copy link

@dinhanhx Can you try now?

It works perfectly. Thank you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment