//! A small utility that constructs an HTML document with Plotly charts of the training metrics
//! by reading the log entries in burn's artifact directory.
//! The file can be compiled and executed with `rustc` (add `-O` for an optimised build):
//! ```sh
//! rustc chart-metrics.rs && ./chart-metrics <ARTIFACT-DIR> [<OUTPUT-FILE>] [--avg-epoch]
//! ```
//! Source code:
use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
/// Entry point: builds an HTML report of training metrics from a burn artifact directory.
///
/// Command line (see `print_usage`): `<ARTIFACT-DIR> [<OUTPUT-FILE>] [--avg-epoch]`.
/// - `--avg-epoch` makes every plot show one averaged point per epoch.
/// - The artifact directory is required.
/// - When an output file is given, the report is written there. Otherwise the HTML goes to
///   stdout. Progress messages always go to stderr.
///
/// # Errors
/// Returns an error when no artifact directory is given, when reading the metric logs fails,
/// or when writing the output file fails.
fn main() -> Result<()> {
// Collect the arguments after the program name so the `--avg-epoch` flag can be located
// anywhere on the command line.
let mut args = std::env::args().skip(1).collect::<Vec<_>>();
// Position of the `--avg-epoch` flag, if present.
let agg_epoch = args
.find_map(|(i, x)| (x == "--avg-epoch").then_some(i));
if let Some(x) = agg_epoch {
let agg_epoch = agg_epoch.is_some();
let mut args = args.into_iter();
// The remaining positional arguments are the artifact directory and the optional output file.
let artifacts_dir =|| {
"expecting a burn artifacts directory"
let output_file =;
// Load every metric found under the `train` and `valid` sub-directories, keyed by name.
let metrics = read_dir(&artifacts_dir)?;
match &output_file {
Some(x) => eprintln!("Constructing HTML report at `{}`", x.display()),
None => eprintln!("Constructing HTML report"),
// One plot fragment per metric, separated by blank lines.
let injection = metrics.into_iter().fold(String::new(), |x, (k, v)| {
x + "\n\n" + &plot_html(k, v, agg_epoch)
// Splice the generated fragments into the page template.
let html = HTML.replace("{{inject}}", &injection);
// Without an output path the report is printed, so it can be redirected by the caller.
match output_file {
None => println!("{html}"),
Some(x) => fs::write(x, html)?,
/// Prints a one-line description of the tool and its command-line usage to stderr.
fn print_usage() {
eprintln!("Read training logs in a burn artifact directory and output a HTML file with plots");
eprintln!("usage: <ARTIFACT-DIR> [<OUTPUT-FILE>] [--avg-epoch]");
eprintln!("--avg-epoch: optionally average the metric per epoch");
/// A single recorded metric value read from a log file.
struct Metric {
/// Epoch number parsed from the `epoch-<N>` directory that held the log file.
epoch: u32,
/// Value parsed from one line of the metric's log file.
value: f64,
/// Whether the value was read from the training or the validation logs.
stg: Stage,
/// The split a metric value was logged under. `read_dir` reads the `train` directory as
/// `Stage::Train` and the `valid` directory as `Stage::Valid`.
///
/// `Ord` is derived so a stage can be part of the `(epoch, stage)` `BTreeMap` key used when
/// averaging values per epoch.
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum Stage {
type Map = BTreeMap<String, Vec<Metric>>;
/// Loads all metric logs under a burn artifact directory.
///
/// Reads the `train` sub-directory first and then `valid`, merging both into one map keyed by
/// metric name. It then orders each metric's values by epoch.
///
/// # Errors
/// Propagates I/O errors from `read_stage`. In particular, it fails if either the `train` or
/// the `valid` sub-directory cannot be listed.
///
/// NOTE(review): a run that has no `valid` directory makes the whole report fail. Consider
/// treating a missing stage directory as empty.
fn read_dir(path: &Path) -> Result<Map> {
let map = read_stage(BTreeMap::new(), path.join("train").as_ref(), Stage::Train)?;
let mut map = read_stage(map, path.join("valid").as_ref(), Stage::Valid)?;
// sort by epoch, cycling between train and valid
// `sort_by` is stable, so within one epoch the train values (read first) stay ahead of the
// valid values, and the original push order is otherwise kept.
.for_each(|vs| vs.sort_by(|a, b| a.epoch.cmp(&b.epoch)));
/// Appends the metric values of one stage to `map`.
///
/// `path` should contain `epoch-<N>` sub-directories.
/// - Entries whose name is not `epoch-` followed by a `u32` are skipped.
/// - Every entry inside an epoch directory is treated as one metric log, named after its file
///   stem. Stems that are not valid UTF-8 are skipped.
/// - Each value parsed from a log is appended to that metric's list, tagged with the epoch and
///   `stg`.
///
/// # Errors
/// Returns an error if a directory cannot be listed or a metric log cannot be read.
fn read_stage(mut map: Map, path: &Path, stg: Stage) -> Result<Map> {
// Directory iteration order is platform dependent, so epochs arrive here in no particular
// order; `read_dir` sorts each metric by epoch afterwards.
for d in path.read_dir()? {
let d = d?;
let Some(epoch) = d.file_name().to_str().and_then(|x| x.strip_prefix("epoch-")).and_then(|x| x.parse::<u32>().ok()) else { continue; };
eprintln!("Reading metrics for {stg:?} epoch-{epoch}");
for d in d.path().read_dir()? {
let p = d?.path();
let Some(name) = p.file_stem().and_then(|x| x.to_str()).map(ToString::to_string) else { continue; };
// Values for the same metric from every epoch accumulate in one vector.
let es = map.entry(name).or_default();
parse_values(&p, |value| es.push(Metric { epoch, stg, value }))?;
/// Reads `file` and calls `cb` once for every line that parses as an `f64`, in file order.
///
/// Lines that do not parse are skipped silently. This includes blank lines and lines with
/// leading or trailing spaces. Note that `inf` and `NaN` do parse as `f64`.
///
/// # Errors
/// Returns an error if the file cannot be read or is not valid UTF-8.
fn parse_values<F: FnMut(f64)>(file: &Path, cb: F) -> Result<()> {
let x = fs::read_to_string(file)?;
x.lines().filter_map(|x| x.parse::<f64>().ok()).for_each(cb);
/// Renders the HTML/JavaScript fragment that plots one metric.
///
/// The fragment has a heading and a `<div>` whose id is the metric name. It is followed by
/// Plotly code that draws a "Train" and a "Valid" scatter trace into that div.
///
/// - With `avg_epoch`, values are grouped by `(epoch, stage)` and averaged, and the x axis is
///   the epoch number (titled "Epoch").
/// - Otherwise every value is plotted individually, and the x axis is titled "Iteration".
///
/// The traces read the page-level JavaScript `mode` variable, which the page template defines.
///
/// NOTE(review): the metric name (a file stem) is interpolated unescaped into HTML and into a
/// JavaScript string literal. A name containing `'`, `"` or `<` would break the page.
///
/// NOTE(review): the series are written with `{:?}`, which prints infinite values as `inf` or
/// `-inf`. These are not JavaScript literals (JS uses `Infinity`). A log line such as `inf`, or
/// a number too large for `f64`, would therefore make this plot's script fail. `NaN` prints as
/// `NaN`, which JavaScript accepts.
fn plot_html(metric: String, values: Vec<Metric>, avg_epoch: bool) -> String {
use std::fmt::Write;
let mut s = format!(
<h2>Plot of {metric} metric</h2>
<div id="{metric}" style="height: 600px;"></div>
// Build the (x, y) points of the train and valid series.
let (ts, vs): (Vec<(usize, f64)>, Vec<(usize, f64)>) = if avg_epoch {
// Group the values by (epoch, stage) so each group can be averaged.
.fold(BTreeMap::new(), |mut map, m| {
let e: &mut Vec<f64> = map.entry((m.epoch, m.stg)).or_default();
.map(|((epoch, stg), vs)| {
let value = if vs.is_empty() {
} else {
vs.iter().sum::<f64>() / vs.len() as f64
(epoch as usize, Metric { epoch, stg, value })
) as Box<dyn Iterator<Item = (usize, Metric)>>
} else {
// we accumulate the iteration as we go
// Route each point to the train or valid series according to its stage.
.fold((Vec::new(), Vec::new()), |(mut ts, mut vs), (x, m)| {
match m.stg {
Stage::Train => ts.push((x, m.value)),
Stage::Valid => vs.push((x, m.value)),
(ts, vs)
// Plotly takes separate x and y arrays per trace.
let (xst, yst): (Vec<_>, Vec<_>) = ts.into_iter().unzip();
let (xsv, ysv): (Vec<_>, Vec<_>) = vs.into_iter().unzip();
&mut s,
r#"var train = {{
x: {xst:?},
y: {yst:?},
mode: mode,
type: 'scatter',
name: 'Train'
var valid = {{
x: {xsv:?},
y: {ysv:?},
mode: mode,
type: 'scatter',
name: 'Valid'
var id = document.getElementById('{metric}');
Plotly.newPlot(id, [train,valid], {{ xaxis: {{ title: '{title}' }} }});
title = if avg_epoch { "Epoch" } else { "Iteration" }
/// Page template for the report. `main` replaces the `{{inject}}` marker in this string with
/// the generated plot fragments.
///
/// The embedded script holds the current Plotly trace `mode` (`'markers'` by default) and a
/// `plots` list of chart elements. The button toggles between markers and lines and restyles
/// every chart in `plots`.
///
/// NOTE(review): the Plotly `<script src>` is empty here. `Plotly.newPlot` and
/// `Plotly.restyle` only exist if the Plotly library is actually loaded, so confirm the URL.
const HTML: &str = r#"
<script src="" charset="utf-8"></script>
var mode = 'markers';
var plots = [];
function switchMode(el) {
if (mode == 'markers') {
mode = 'lines';
el.innerText = 'Markers';
} else {
mode = 'markers';
el.innerText = 'Lines';
plots.forEach(p => Plotly.restyle(p, { mode: mode }));
<button onclick="switchMode(this)">Lines</button>
