Skip to content

Instantly share code, notes, and snippets.

@kurtlawrence
Last active May 17, 2023 03:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kurtlawrence/7921fca8751ce2ea4826f8ab0a90ee32 to your computer and use it in GitHub Desktop.
Save kurtlawrence/7921fca8751ce2ea4826f8ab0a90ee32 to your computer and use it in GitHub Desktop.
Construct HTML document with Plotly charts of burn-rs training progression
//! A small utility that constructs an HTML document with Plotly charts of the training
//! metrics by reading the log entries in burn's artifact directory.
//!
//! The file can be compiled and executed with `rustc` (`-O` for optimised)
//! ```sh
//! rustc chart-metrics.rs && ./chart-metrics <ARTIFACT-DIR>
//! ```
//!
//! Source code: https://gist.github.com/kurtlawrence/7921fca8751ce2ea4826f8ab0a90ee32
use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
/// Result type used throughout this file: any error is boxed as a `dyn std::error::Error`,
/// which lets `?` propagate I/O errors and plain string messages alike.
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
/// Entry point: parses the command line, reads the metrics and emits the HTML report.
///
/// Command line (after the program name):
/// * the first positional argument is the burn artifacts directory (required),
/// * the second positional argument is the output file (optional; the report is printed to
///   stdout when it is absent),
/// * the first `--avg-epoch` argument, wherever it appears, enables per-epoch averaging.
///
/// # Errors
/// Fails when the artifacts directory argument is missing (usage is printed first), when the
/// metrics cannot be read, or when the output file cannot be written.
fn main() -> Result<()> {
    let mut cli: Vec<String> = std::env::args().skip(1).collect();

    // Pull the first `--avg-epoch` out so that only positional arguments remain.
    let flag_index = cli.iter().position(|arg| arg == "--avg-epoch");
    if let Some(index) = flag_index {
        cli.remove(index);
    }
    let agg_epoch = flag_index.is_some();

    let mut positional = cli.into_iter().map(PathBuf::from);
    let Some(artifacts_dir) = positional.next() else {
        print_usage();
        return Err("expecting a burn artifacts directory".into());
    };
    let output_file = positional.next();

    let metrics = read_dir(&artifacts_dir)?;

    if let Some(target) = &output_file {
        eprintln!("Constructing HTML report at `{}`", target.display());
    } else {
        eprintln!("Constructing HTML report");
    }

    // One plot fragment per metric, each preceded by a blank line.
    let mut injection = String::new();
    for (name, values) in metrics {
        injection.push_str("\n\n");
        injection.push_str(&plot_html(name, values, agg_epoch));
    }

    let html = HTML.replace("{{inject}}", &injection);
    if let Some(target) = output_file {
        fs::write(target, html)?;
    } else {
        println!("{html}");
    }
    Ok(())
}
/// Writes the command-line usage summary to stderr, one line per entry.
fn print_usage() {
    const USAGE: [&str; 3] = [
        "Read training logs in a burn artifact directory and output a HTML file with plots",
        "usage: <ARTIFACT-DIR> [<OUTPUT-FILE>] [--avg-epoch]",
        "--avg-epoch: optionally average the metric per epoch",
    ];
    for line in &USAGE {
        eprintln!("{line}");
    }
}
/// A single recorded metric value.
struct Metric {
    /// Epoch number the value was recorded in.
    epoch: u32,
    /// Whether the value comes from the training or the validation logs.
    stg: Stage,
    /// The recorded value itself.
    value: f64,
}
/// The part of the training loop a metric was recorded in.
///
/// The variant order matters: the derived `Ord` places `Train` before `Valid`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Stage {
    /// Values read from the `train` sub-directory.
    Train,
    /// Values read from the `valid` sub-directory.
    Valid,
}
/// Recorded metric values grouped by metric name (the map key), iterated in name order.
type Map = BTreeMap<String, Vec<Metric>>;
fn read_dir(path: &Path) -> Result<Map> {
let map = read_stage(BTreeMap::new(), path.join("train").as_ref(), Stage::Train)?;
let mut map = read_stage(map, path.join("valid").as_ref(), Stage::Valid)?;
// sort by epoch, cycling between train and valid
map.values_mut()
.for_each(|vs| vs.sort_by(|a, b| a.epoch.cmp(&b.epoch)));
Ok(map)
}
/// Reads all metric logs of one stage and appends their values to `map`.
///
/// `path` is scanned for entries named `epoch-<number>`; every other entry is skipped. Each
/// file inside such an entry contributes its values to the metric named after the file stem,
/// tagged with the epoch number and `stg`.
///
/// # Errors
/// Fails when `path` or an epoch entry cannot be listed, or when a log file cannot be read.
fn read_stage(mut map: Map, path: &Path, stg: Stage) -> Result<Map> {
    for epoch_entry in path.read_dir()? {
        let epoch_entry = epoch_entry?;

        // Only `epoch-<number>` entries are of interest.
        let dir_name = epoch_entry.file_name();
        let epoch = match dir_name
            .to_str()
            .and_then(|name| name.strip_prefix("epoch-"))
            .and_then(|number| number.parse::<u32>().ok())
        {
            Some(epoch) => epoch,
            None => continue,
        };
        eprintln!("Reading metrics for {stg:?} epoch-{epoch}");

        for log_entry in epoch_entry.path().read_dir()? {
            let log_path = log_entry?.path();
            // The metric name is the file name without its extension.
            let name = match log_path.file_stem().and_then(|stem| stem.to_str()) {
                Some(stem) => stem.to_string(),
                None => continue,
            };
            let series = map.entry(name).or_default();
            parse_values(&log_path, |value| series.push(Metric { epoch, stg, value }))?;
        }
    }
    Ok(map)
}
/// Reads `file` and calls `cb` once for every line that holds a floating point number.
///
/// Lines are trimmed before parsing, so a value padded with spaces or tabs is still picked
/// up (`f64::from_str` rejects surrounding whitespace). Lines that do not parse as a number
/// are skipped on purpose.
///
/// # Errors
/// Fails when `file` cannot be read as UTF-8 text.
fn parse_values<F: FnMut(f64)>(file: &Path, mut cb: F) -> Result<()> {
    let contents = fs::read_to_string(file)?;
    for line in contents.lines() {
        if let Ok(value) = line.trim().parse::<f64>() {
            cb(value);
        }
    }
    Ok(())
}
fn plot_html(metric: String, values: Vec<Metric>, avg_epoch: bool) -> String {
use std::fmt::Write;
let mut s = format!(
r#"
<h2>Plot of {metric} metric</h2>
<div id="{metric}" style="height: 600px;"></div>
<script>
"#
);
let (ts, vs): (Vec<(usize, f64)>, Vec<(usize, f64)>) = if avg_epoch {
Box::new(
values
.into_iter()
.fold(BTreeMap::new(), |mut map, m| {
let e: &mut Vec<f64> = map.entry((m.epoch, m.stg)).or_default();
e.push(m.value);
map
})
.into_iter()
.map(|((epoch, stg), vs)| {
let value = if vs.is_empty() {
0.
} else {
vs.iter().sum::<f64>() / vs.len() as f64
};
(epoch as usize, Metric { epoch, stg, value })
}),
) as Box<dyn Iterator<Item = (usize, Metric)>>
} else {
// we accumulate the iteration as we go
Box::new(values.into_iter().enumerate())
}
.fold((Vec::new(), Vec::new()), |(mut ts, mut vs), (x, m)| {
match m.stg {
Stage::Train => ts.push((x, m.value)),
Stage::Valid => vs.push((x, m.value)),
}
(ts, vs)
});
let (xst, yst): (Vec<_>, Vec<_>) = ts.into_iter().unzip();
let (xsv, ysv): (Vec<_>, Vec<_>) = vs.into_iter().unzip();
writeln!(
&mut s,
r#"var train = {{
x: {xst:?},
y: {yst:?},
mode: mode,
type: 'scatter',
name: 'Train'
}};
var valid = {{
x: {xsv:?},
y: {ysv:?},
mode: mode,
type: 'scatter',
name: 'Valid'
}};
var id = document.getElementById('{metric}');
plots.push(id);
Plotly.newPlot(id, [train,valid], {{ xaxis: {{ title: '{title}' }} }});
</script>
"#,
title = if avg_epoch { "Epoch" } else { "Iteration" }
)
.ok();
s
}
/// Page template of the report.
///
/// `{{inject}}` is the placeholder that gets replaced with the generated plot fragments.
/// The page defines the globals `mode` (current trace mode) and `plots` (the chart elements
/// registered by the fragments), plus the button that toggles every chart between markers
/// and lines.
///
/// The document starts with the standard `<!DOCTYPE html>` so browsers render it in
/// standards mode, and declares UTF-8 because the file is written as UTF-8.
const HTML: &str = r#"
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<script src="https://cdn.plot.ly/plotly-2.20.0.min.js" charset="utf-8"></script>
<script>
var mode = 'markers';
var plots = [];
function switchMode(el) {
if (mode == 'markers') {
mode = 'lines';
el.innerText = 'Markers';
} else {
mode = 'markers';
el.innerText = 'Lines';
}
plots.forEach(p => Plotly.restyle(p, { mode: mode }));
}
</script>
</head>
<body>
<button onclick="switchMode(this)">Lines</button>
{{inject}}
</body>
</html>
"#;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment