Skip to content

Instantly share code, notes, and snippets.

@gurvindersingh
Last active December 30, 2018 20:10
Show Gist options
  • Save gurvindersingh/9b6b2da01ac6319d2a3ae9bd30f4b758 to your computer and use it in GitHub Desktop.
Save gurvindersingh/9b6b2da01ac6319d2a3ae9bd30f4b758 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"from fastai import *\n",
"from fastai.tabular import *\n",
"from sklearn.metrics import confusion_matrix"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"path = Path('data/ham10000')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load the pickled dataframe; the context manager closes the file handle promptly.\n",
"# NOTE(review): pickle.load can execute arbitrary code -- only open a df.pkl you trust.\n",
"with open(path/'df.pkl', 'rb') as f:\n",
"    df = pickle.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>lesion_id</th>\n",
" <th>image_id</th>\n",
" <th>dx</th>\n",
" <th>dx_type</th>\n",
" <th>age</th>\n",
" <th>sex</th>\n",
" <th>localization</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>9687</th>\n",
" <td>HAM_0002644</td>\n",
" <td>ISIC_0029417</td>\n",
" <td>akiec</td>\n",
" <td>histo</td>\n",
" <td>80.0</td>\n",
" <td>female</td>\n",
" <td>neck</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9688</th>\n",
" <td>HAM_0006002</td>\n",
" <td>ISIC_0029915</td>\n",
" <td>akiec</td>\n",
" <td>histo</td>\n",
" <td>50.0</td>\n",
" <td>female</td>\n",
" <td>face</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9689</th>\n",
" <td>HAM_0000549</td>\n",
" <td>ISIC_0029360</td>\n",
" <td>akiec</td>\n",
" <td>histo</td>\n",
" <td>70.0</td>\n",
" <td>male</td>\n",
" <td>upper extremity</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9690</th>\n",
" <td>HAM_0000549</td>\n",
" <td>ISIC_0026152</td>\n",
" <td>akiec</td>\n",
" <td>histo</td>\n",
" <td>70.0</td>\n",
" <td>male</td>\n",
" <td>upper extremity</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9691</th>\n",
" <td>HAM_0000673</td>\n",
" <td>ISIC_0029659</td>\n",
" <td>akiec</td>\n",
" <td>histo</td>\n",
" <td>70.0</td>\n",
" <td>female</td>\n",
" <td>face</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" lesion_id image_id dx dx_type age sex localization\n",
"9687 HAM_0002644 ISIC_0029417 akiec histo 80.0 female neck\n",
"9688 HAM_0006002 ISIC_0029915 akiec histo 50.0 female face\n",
"9689 HAM_0000549 ISIC_0029360 akiec histo 70.0 male upper extremity\n",
"9690 HAM_0000549 ISIC_0026152 akiec histo 70.0 male upper extremity\n",
"9691 HAM_0000673 ISIC_0029659 akiec histo 70.0 female face"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"cat_names = ['sex', 'localization']\n",
"cont_names = ['age']\n",
"dep_var = 'dx'\n",
"procs = [FillMissing, Categorify, Normalize]"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"test = TabularList.from_df(df, path=path, cat_names=cat_names, cont_names=cont_names, procs=procs)\n",
"data = LabelLists.load_empty(path).add_test(test).databunch()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"learn = tabular_learner(data, layers=[128,64], metrics=accuracy)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"learn.load('tabular');"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data.valid_ds.y.classes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict using learn.get_preds()\n",
"\n",
"Prediction is happening on the GPU, over the test dataframe."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"probs, _ = learn.get_preds(DatasetType.Test)\n",
"preds = probs.argmax(dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([2, 2, 5, 5, 2, 2, 2, 5, 5, 2, 2, 5, 4, 2, 2, 5, 2, 2, 5, 2, 5, 2, 5, 5,\n",
" 5, 5, 2, 2, 2, 2, 2, 5, 5, 5, 5, 5, 5, 2, 2, 5])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# Translate every predicted class index into its label, keeping row order so that\n",
"# p[i] pairs with df.dx.iloc[i]. Building the list class-by-class would group the\n",
"# labels and scramble that pairing, and hard-coded names can drift from the real\n",
"# index -> label mapping (index 2 is 'bkl' in the class list, not 'akiec').\n",
"classes = data.valid_ds.y.classes\n",
"p = [classes[i] for i in preds.tolist()]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[20, 0, 0, 0],\n",
" [ 0, 0, 1, 19],\n",
" [ 0, 0, 0, 0],\n",
" [ 0, 0, 0, 0]])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"confusion_matrix(df.dx, p)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict using learn.predict()\n",
"\n",
"Prediction is happening on the CPU."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"preds = df.apply(learn.predict, axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"9687 (nv, tensor(5), [tensor(0.0001), tensor(0.0029...\n",
"9688 (nv, tensor(5), [tensor(0.0001), tensor(0.0026...\n",
"9689 (nv, tensor(5), [tensor(9.1095e-05), tensor(0....\n",
"9690 (nv, tensor(5), [tensor(9.1095e-05), tensor(0....\n",
"9691 (nv, tensor(5), [tensor(0.0001), tensor(0.0028...\n",
"9692 (nv, tensor(5), [tensor(0.0003), tensor(0.0039...\n",
"9693 (nv, tensor(5), [tensor(0.0003), tensor(0.0039...\n",
"9694 (nv, tensor(5), [tensor(5.6556e-05), tensor(0....\n",
"9695 (nv, tensor(5), [tensor(5.6556e-05), tensor(0....\n",
"9696 (nv, tensor(5), [tensor(0.0002), tensor(0.0025...\n",
"9697 (nv, tensor(5), [tensor(0.0002), tensor(0.0025...\n",
"9698 (nv, tensor(5), [tensor(5.6556e-05), tensor(0....\n",
"9699 (nv, tensor(5), [tensor(6.7162e-05), tensor(0....\n",
"9700 (nv, tensor(5), [tensor(0.0002), tensor(0.0025...\n",
"9701 (nv, tensor(5), [tensor(0.0002), tensor(0.0025...\n",
"9702 (nv, tensor(5), [tensor(9.5781e-05), tensor(0....\n",
"9703 (nv, tensor(5), [tensor(0.0002), tensor(0.0051...\n",
"9704 (nv, tensor(5), [tensor(0.0001), tensor(0.0046...\n",
"9705 (nv, tensor(5), [tensor(0.0001), tensor(0.0027...\n",
"9706 (nv, tensor(5), [tensor(6.3124e-05), tensor(0....\n",
"2462 (nv, tensor(5), [tensor(0.0002), tensor(0.0045...\n",
"2463 (nv, tensor(5), [tensor(0.0001), tensor(0.0028...\n",
"2464 (nv, tensor(5), [tensor(5.6529e-05), tensor(0....\n",
"2465 (nv, tensor(5), [tensor(5.6529e-05), tensor(0....\n",
"2466 (nv, tensor(5), [tensor(0.0002), tensor(0.0045...\n",
"2467 (nv, tensor(5), [tensor(0.0002), tensor(0.0045...\n",
"2468 (nv, tensor(5), [tensor(0.0003), tensor(0.0040...\n",
"2469 (nv, tensor(5), [tensor(0.0003), tensor(0.0040...\n",
"2470 (nv, tensor(5), [tensor(0.0003), tensor(0.0040...\n",
"2471 (nv, tensor(5), [tensor(0.0002), tensor(0.0035...\n",
"2472 (nv, tensor(5), [tensor(0.0002), tensor(0.0035...\n",
"2473 (nv, tensor(5), [tensor(0.0002), tensor(0.0049...\n",
"2474 (nv, tensor(5), [tensor(0.0002), tensor(0.0049...\n",
"2475 (nv, tensor(5), [tensor(0.0002), tensor(0.0047...\n",
"2476 (nv, tensor(5), [tensor(0.0002), tensor(0.0050...\n",
"2477 (nv, tensor(5), [tensor(0.0002), tensor(0.0050...\n",
"2478 (nv, tensor(5), [tensor(0.0002), tensor(0.0050...\n",
"2479 (nv, tensor(5), [tensor(0.0001), tensor(0.0022...\n",
"2480 (nv, tensor(5), [tensor(0.0001), tensor(0.0022...\n",
"2481 (nv, tensor(5), [tensor(0.0002), tensor(0.0049...\n",
"dtype: object"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preds"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# Each entry of preds is the tuple returned by learn.predict; its first element is\n",
"# the predicted category (presumably .obj is the raw label string -- confirm).\n",
"p = [prediction[0].obj for prediction in preds.values]"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array([[ 0, 0, 20],\n",
" [ 0, 0, 20],\n",
" [ 0, 0, 0]])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"confusion_matrix(df.dx, p)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment