Last active
July 18, 2023 00:30
-
-
Save almeidaraul/b0fc84520146fc224fa8523dfe56c371 to your computer and use it in GitHub Desktop.
Plotting losses and errors with Plotly
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"id": "b522a753", | |
"metadata": {}, | |
"source": [ | |
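"# Plot TensorBoard losses and errors\n",
"\n",
"Load the TensorBoard event files of every run under `RUNS_DIR`, plot each loss group in `LOSS_NAMES` (one colour per run), and tabulate the test-set metrics listed in `ERROR_NAMES`."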
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "130b4b75", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import numpy as np\n", | |
"import plotly.express as px\n", | |
"import plotly.graph_objects as go\n", | |
"import hashlib\n", | |
"import os\n", | |
"\n", | |
"from tensorboard.backend.event_processing import event_accumulator as eva\n", | |
"from tensorboard.backend.event_processing import event_multiplexer as evm\n", | |
"\n", | |
"TITLE_PREFIX = \"[DCCDN]\"  # NOTE(review): RUNS_DIR/LOSS_NAMES below target SSAN runs - confirm this prefix\n",
"RUNS_DIR = \"runs_ssan\"\n",
"ERROR_NAMES = [\"HTER\"] # [\"ACER\", \"APCER\", \"BPCER\", \"HTER\", \"EER\"]\n",
"TO_HTML = False  # True: write figures to HTML files instead of showing them inline\n",
"# Plot title -> list of scalar tags drawn together on that plot (one curve per run and tag).\n",
"LOSS_NAMES = {\n",
"    # dccdn\n",
"    # \"CDL\": [\"validation/validation_cdl\", \"train_epoch/loss_cdl\"], \"MSE\": [\"validation/validation_mse\", \"train_epoch/loss_mse\"],\n",
"    # ssan\n",
"    \"all\": [\"name/loss\", \"train_epoch/loss\"]\n",
"}\n",
"MAXN = int(1e18)\n",
"# TensorBoard size guidance = max events kept per tag, where 0 means *keep every event*\n",
"# (it does NOT mean skip). Only scalars are read in this notebook, so keep all of them\n",
"# and retain a single event per tag for the other plugins to avoid loading unused\n",
"# images/histograms/audio into memory.\n",
"size_guidance = {\n",
"    eva.COMPRESSED_HISTOGRAMS: 1,\n",
"    eva.IMAGES: 1,\n",
"    eva.AUDIO: 1,\n",
"    eva.SCALARS: MAXN,\n",
"    eva.HISTOGRAMS: 1,\n",
"}"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "5408c908", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def eventlist_to_array(events):\n",
"    \"\"\"Collect the `.value` of each event into a NumPy array, preserving order.\"\"\"\n",
"    return np.array([event.value for event in events])"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "4c0d25df", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def plot_many(arrays, title=\"Losses\", xaxis_title=\"Epoch\", yaxis_title=\"Loss\", to_html=False):\n",
"    \"\"\"Draw several 1-D series as lines on one Plotly figure.\n",
"\n",
"    arrays: dict mapping \"run/tag\" names to 1-D sequences of y-values; the\n",
"        x-axis is the position in the sequence (0, 1, 2, ...).\n",
"    Series whose names share the part before the first \"/\" (the run) get the\n",
"    same colour, assigned from Plotly's Dark24 palette in sorted-name order and\n",
"    wrapping around when the palette is exhausted.\n",
"    If to_html is True the figure is written to \"<title>.html\"; otherwise it is shown.\n",
"    \"\"\"\n",
"    palette = px.colors.qualitative.Dark24\n",
"    colour_of_run = {}\n",
"    next_colour = 0\n",
"\n",
"    fig = go.Figure()\n",
"    for series_name in sorted(arrays):\n",
"        values = arrays[series_name]\n",
"        run = series_name.split('/')[0]\n",
"        if run not in colour_of_run:\n",
"            colour_of_run[run] = next_colour\n",
"            next_colour = (next_colour + 1) % len(palette)\n",
"        trace = go.Scatter(\n",
"            x=list(range(len(values))),\n",
"            y=values,\n",
"            mode='lines',\n",
"            name=series_name,\n",
"            line=dict(color=palette[colour_of_run[run]]),\n",
"        )\n",
"        fig.add_trace(trace)\n",
"\n",
"    fig.update_layout(title=title, xaxis_title=xaxis_title, yaxis_title=yaxis_title)\n",
"    if not to_html:\n",
"        fig.show()\n",
"        return\n",
"    fig.write_html(f\"{title}.html\")"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "7e94723e", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def table_many(values, title=\"Errors\", to_html=False, error_names=None):\n",
"    \"\"\"Show per-run error metrics as a Plotly table.\n",
"\n",
"    values: column-major table data; values[0] is the run-name column and\n",
"        values[i] (i >= 1) is the column for error_names[i - 1].\n",
"    error_names: metric column headers; defaults to the global ERROR_NAMES.\n",
"    If to_html is True the table is written to \"<title>.html\" (the same naming\n",
"    as plot_many, so differently titled tables do not overwrite each other);\n",
"    otherwise it is shown.\n",
"    \"\"\"\n",
"    if error_names is None:\n",
"        error_names = ERROR_NAMES\n",
"    header = [\"Run\"] + list(error_names)\n",
"    fig = go.Figure(data=[go.Table(header=dict(values=header),\n",
"                                   cells=dict(values=values))])\n",
"    fig.update_layout(title=title)\n",
"    if to_html:\n",
"        fig.write_html(f\"{title}.html\")\n",
"    else:\n",
"        fig.show()"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "a638539d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Only sub-directories of RUNS_DIR are treated as runs; stray files are ignored.\n",
"# Sorting makes the run order deterministic across machines and re-runs.\n",
"run_dirs = sorted(\n",
"    d for d in os.listdir(RUNS_DIR) if os.path.isdir(os.path.join(RUNS_DIR, d))\n",
")\n",
"\n",
"em = evm.EventMultiplexer({d: f\"{RUNS_DIR}/{d}\" for d in run_dirs}, size_guidance=size_guidance)\n",
"em.Reload()"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "21df9797", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"runs = sorted(run_dirs)\n",
"error_source = \"test\"  # or \"validation\"\n",
"\n",
"# losses[loss_title][\"run/tag\"] -> 1-D array of that tag's scalar values.\n",
"losses = {loss_title: {} for loss_title in LOSS_NAMES}\n",
"# Column-major table for table_many: errors[0] is the run column, errors[i + 1] the\n",
"# column for ERROR_NAMES[i]. The run column must use the same (sorted) order as the\n",
"# loop below, otherwise the table rows are mislabelled.\n",
"errors = [runs] + [[] for _ in ERROR_NAMES]\n",
"for run in runs:\n",
"    scalar_tags = em.Runs()[run][\"scalars\"]\n",
"    for loss_title, loss_name_list in LOSS_NAMES.items():\n",
"        for loss in loss_name_list:\n",
"            # Skip tags this run never logged instead of raising KeyError\n",
"            # (missing error tags below are likewise tolerated).\n",
"            if loss in scalar_tags:\n",
"                losses[loss_title][f\"{run}/{loss}\"] = eventlist_to_array(em.Scalars(run, loss))\n",
"    for i, error in enumerate(ERROR_NAMES):\n",
"        name = f\"{error_source}/{error}\"\n",
"        if name in scalar_tags:\n",
"            # NOTE(review): uses the first logged value; confirm the metric is logged once per run.\n",
"            v = em.Scalars(run, name)[0].value * 100\n",
"            errors[i + 1].append(f\"{v:.1f}%\")\n",
"        else:\n",
"            errors[i + 1].append(float(\"nan\"))\n",
"\n",
"errors = np.array(errors)"
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"id": "0124601d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# One line chart per loss group, followed by the error table.\n",
"for loss_title in losses:\n",
"    plot_title = f\"{TITLE_PREFIX} {loss_title} Loss\"\n",
"    plot_many(losses[loss_title], plot_title, to_html=TO_HTML)\n",
"table_many(errors, f\"{TITLE_PREFIX} Errors\", to_html=TO_HTML)"
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.8.10" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment