Skip to content

Instantly share code, notes, and snippets.

@JamesOwers
Last active October 18, 2019 14:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save JamesOwers/04a877c4f0ab474ca1d4fd8f32e0110b to your computer and use it in GitHub Desktop.
Save JamesOwers/04a877c4f0ab474ca1d4fd8f32e0110b to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"from baselines.get_results import (construct_parser, main, \n",
" plot_confusion, round_to_n, plot_log_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# CONFIG - EDIT THESE\n",
"\n",
"# location of logs and model checkpoints\n",
"output_dir = 'output'\n",
"\n",
"# location of the pianoroll and command corpus datasets\n",
"in_dir = 'acme'\n",
"\n",
"# Names of tasks (folders in output_dir)\n",
"task_names = ['task1', 'task1weighted', 'task2', 'task3',\n",
" 'task3weighted', 'task4']\n",
"\n",
"# The names of variables gridsearched over for each task\n",
"sn = ['lr', 'wd', 'hid']\n",
"# setting_names = [sn, sn, sn, sn, sn, sn]\n",
"setting_names = [sn, sn, sn, sn, sn, ['lr', 'wd', 'hid', 'lay']]\n",
"\n",
"\n",
"formats = list({\n",
" 'task1': 'command',\n",
" 'task1weighted': 'command',\n",
" 'task2': 'command',\n",
" 'task3': 'pianoroll',\n",
" 'task3weighted': 'pianoroll',\n",
" 'task4': 'pianoroll',\n",
"}.values())\n",
"seq_len = list({\n",
" 'task1': '1000',\n",
" 'task1weighted': '1000',\n",
" 'task2': '1000',\n",
" 'task3': '250',\n",
" 'task3weighted': '250',\n",
" 'task4': '250',\n",
"}.values())\n",
"metrics = list({\n",
" 'task1': 'rev_f',\n",
" 'task1weighted': 'rev_f',\n",
" 'task2': 'avg_acc',\n",
" 'task3': 'f',\n",
" 'task3weighted': 'f',\n",
" 'task4': 'helpfulness',\n",
"}.values())\n",
"task_desc = list({\n",
" 'task1': 'Error Detection',\n",
" 'task1weighted': 'Error Detection',\n",
" 'task2': 'Error Classification',\n",
" 'task3': 'Error Location',\n",
" 'task3weighted': 'Error Location',\n",
" 'task4': 'Error Correction',\n",
"}.values())\n",
"task_desc = [tt.replace(' ', '') for tt in task_desc]\n",
"args_str = (\n",
" f\"--output_dir {output_dir} \"\n",
" f\"--save_plots {output_dir} \"\n",
" f\"--in_dir {in_dir} \"\n",
" f\"--task_names {' '.join(task_names)} \"\n",
" f\"--setting_names {' '.join([str(sn).replace(' ', '') for sn in setting_names])} \"\n",
" f\"--formats {' '.join(formats)} \"\n",
" f\"--seq_len {' '.join(seq_len)} \"\n",
" f\"--metrics {' '.join(metrics)} \"\n",
" f\"--task_desc {' '.join(task_desc)} \"\n",
" f\"--splits train valid test\"\n",
"# f\"--splits test\"\n",
")\n",
"args_str"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"parser = construct_parser()\n",
"args = parser.parse_args(args_str.split())\n",
"results, min_idx, task_eval_log, res_df, summary_tab_, confusion = main(args)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"task_name = 'task1'\n",
"display(results[task_name])\n",
"display(min_idx[task_name])\n",
"display(task_eval_log[task_name])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"res_df.loc[task_name].dropna(axis=1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for ii, task_name in enumerate(task_names):\n",
" print(task_name)\n",
" df = (\n",
" res_df.loc[task_name]\n",
" .dropna(axis=1) # removes cols with na in them (not a metric for this task)\n",
" )\n",
" if 'confusion_mat' in df.columns:\n",
" confusion_mat = df['confusion_mat'][0]\n",
" plot_confusion(confusion_mat)\n",
" df.drop('confusion_mat', axis=1, inplace=True)\n",
" df = (\n",
" df\n",
" .apply(pd.to_numeric) # they are strings, convert to int or float\n",
" .applymap(round_to_n) # round to 3sf\n",
" )\n",
" display(df)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"summary_tab_"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"confusion "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Run this if errors to:\n",
"# * stop tqdm multiline printing\n",
"# * reload packages if you need to change them\n",
"# from tqdm import tqdm\n",
"# tqdm._instances.clear()\n",
"\n",
"import importlib\n",
"import baselines\n",
"import baselines.get_results\n",
"import mdtk\n",
"importlib.reload(baselines.get_results)\n",
"# importlib.reload(baselines.eval_task)\n",
"# importlib.reload(mdtk.formatters)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"best_models = {task: f'{output_dir}/{task}/{val[1]}.checkpoint.best'\n",
" for task, val in min_idx.items()}\n",
"best_logs = {task: f'{output_dir}/{task}/{val[1]}.log'\n",
" for task, val in min_idx.items()}\n",
"save_plots = False\n",
"print(f\"best models: {best_models}\")\n",
"for task_name, log_file in best_logs.items():\n",
" plot_log_file(log_file)\n",
" plt.title(f'{task_name} best model training curve')\n",
" if save_plots:\n",
" plt.savefig(f'{save_plots}/{task_name}__best_model_loss.png',\n",
" dpi=300)\n",
" plt.savefig(f'{save_plots}/{task_name}__best_model_loss.pdf',\n",
" dpi=300)\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment