Created
November 17, 2019 17:10
-
-
Save sheikmohdimran/7d551820d33c10af3aa2493e579b81f2 to your computer and use it in GitHub Desktop.
knowledge_distillation_test.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"name": "knowledge_distillation_test.ipynb", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/sheikmohdimran/7d551820d33c10af3aa2493e579b81f2/knowledge_distillation_test.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "RvrHXXKNeDBd", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import pandas as pd\n", | |
"from fastai.vision import *" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "arZyAhqoeDuY", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"path = untar_data(URLs.MNIST )\n", | |
"#path.ls()" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "-6_WADnQDwO5", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"tfms = get_transforms(do_flip=False)\n", | |
"data = (ImageList.from_folder(path/'training')\n", | |
" .split_by_rand_pct()\n", | |
" .label_from_folder()#.add_test_folder(path/'testing')\n", | |
" .transform(tfms, size=26).databunch())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "xUfJ4VqyIrEM", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 326 | |
}, | |
"outputId": "d1ddf05d-10c6-4e02-849e-50bad722f800" | |
}, | |
"source": [ | |
"data" | |
], | |
"execution_count": 48, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"ImageDataBunch;\n", | |
"\n", | |
"Train: LabelList (48000 items)\n", | |
"x: ImageList\n", | |
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n", | |
"y: CategoryList\n", | |
"8,8,8,8,8\n", | |
"Path: /root/.fastai/data/mnist_png/training;\n", | |
"\n", | |
"Valid: LabelList (12000 items)\n", | |
"x: ImageList\n", | |
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n", | |
"y: CategoryList\n", | |
"2,7,5,4,4\n", | |
"Path: /root/.fastai/data/mnist_png/training;\n", | |
"\n", | |
"Test: None" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 48 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "QAwXVKM8CG7K", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 143 | |
}, | |
"outputId": "e19acb12-9b7f-4ca3-cffa-696aceb973ca" | |
}, | |
"source": [ | |
"teach_learn = cnn_learner(data, models.resnet18, metrics=accuracy)\n", | |
"teach_learn.fit_one_cycle(3)" | |
], | |
"execution_count": 5, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <td>0</td>\n", | |
" <td>0.760325</td>\n", | |
" <td>0.453046</td>\n", | |
" <td>0.853667</td>\n", | |
" <td>01:12</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>1</td>\n", | |
" <td>0.414666</td>\n", | |
" <td>0.200453</td>\n", | |
" <td>0.934167</td>\n", | |
" <td>01:13</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <td>2</td>\n", | |
" <td>0.313441</td>\n", | |
" <td>0.169848</td>\n", | |
" <td>0.946417</td>\n", | |
" <td>01:13</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LWcQNXHP2xW_", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"### Training set ###\n", | |
"# Run inference \n", | |
"preds,y = teach_learn.get_preds(ds_type=DatasetType.Train)\n", | |
"# Get predicted classes\n", | |
"_, preds_class = preds.max(1)\n", | |
"# Get Filenames\n", | |
"df=pd.DataFrame({ 'file':teach_learn.data.train_ds.items})\n", | |
"# Add true labels\n", | |
"df['y'] = pd.DataFrame(y.numpy())\n", | |
"# Add predictions\n", | |
"df['y_teach'] = pd.DataFrame(preds_class.numpy())\n", | |
"# Combine true label and prediction as tuple\n", | |
"df['y_tuple'] = df[['y','y_teach']].apply(tuple, axis=1)\n", | |
"# Drop True labels and prediction column\n", | |
"df = df.drop(df.columns[[1, 2]], axis=1) " | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "4nfGvSKJB50e", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"### Validation set ###\n", | |
"# Run inference \n", | |
"preds,y = teach_learn.get_preds(ds_type=DatasetType.Valid)\n", | |
"# Get predicted classes\n", | |
"_, preds_class = preds.max(1)\n", | |
"# Get Filenames\n", | |
"df1=pd.DataFrame({ 'file':teach_learn.data.valid_ds.items})\n", | |
"# Add true labels\n", | |
"df1['y'] = pd.DataFrame(y.numpy())\n", | |
"# Add predictions\n", | |
"df1['y_teach'] = pd.DataFrame(preds_class.numpy())\n", | |
"# Combine true label and prediction as tuple\n", | |
"df1['y_tuple'] = df1[['y','y_teach']].apply(tuple, axis=1)\n", | |
"# Drop True labels and prediction column\n", | |
"df1 = df1.drop(df1.columns[[1, 2]], axis=1) " | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "NIeJeRP6CLny", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"final_df=pd.concat([df, df1])" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "T3d2z67lF0ba", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 206 | |
}, | |
"outputId": "966e5402-bda3-42f9-9fa3-ca802ce8112a" | |
}, | |
"source": [ | |
"final_df.head()" | |
], | |
"execution_count": 14, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>file</th>\n", | |
" <th>y_tuple</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <td>/root/.fastai/data/mnist_png/training/8/54198.png</td>\n", | |
" <td>(6, 6)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>1</th>\n", | |
" <td>/root/.fastai/data/mnist_png/training/8/15021.png</td>\n", | |
" <td>(0, 0)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>2</th>\n", | |
" <td>/root/.fastai/data/mnist_png/training/8/23092.png</td>\n", | |
" <td>(1, 1)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>3</th>\n", | |
" <td>/root/.fastai/data/mnist_png/training/8/30401.png</td>\n", | |
" <td>(6, 6)</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>4</th>\n", | |
" <td>/root/.fastai/data/mnist_png/training/8/57745.png</td>\n", | |
" <td>(1, 1)</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" file y_tuple\n", | |
"0 /root/.fastai/data/mnist_png/training/8/54198.png (6, 6)\n", | |
"1 /root/.fastai/data/mnist_png/training/8/15021.png (0, 0)\n", | |
"2 /root/.fastai/data/mnist_png/training/8/23092.png (1, 1)\n", | |
"3 /root/.fastai/data/mnist_png/training/8/30401.png (6, 6)\n", | |
"4 /root/.fastai/data/mnist_png/training/8/57745.png (1, 1)" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 14 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "9wm9eY2ZmXBj", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"tfms = get_transforms(do_flip=False)\n", | |
"data = (ImageList.from_df(final_df,'/')\n", | |
" .split_by_rand_pct()\n", | |
" .label_from_df(cols='y_tuple')\n", | |
" .transform(tfms, size=26).databunch())" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rwzgE5jvJh43", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 326 | |
}, | |
"outputId": "68096f91-714d-48bb-e631-531777191f13" | |
}, | |
"source": [ | |
"data" | |
], | |
"execution_count": 50, | |
"outputs": [ | |
{ | |
"output_type": "execute_result", | |
"data": { | |
"text/plain": [ | |
"ImageDataBunch;\n", | |
"\n", | |
"Train: LabelList (48000 items)\n", | |
"x: ImageList\n", | |
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n", | |
"y: MultiCategoryList\n", | |
"6;6,0;0,1;1,6;6,1;1\n", | |
"Path: /;\n", | |
"\n", | |
"Valid: LabelList (12000 items)\n", | |
"x: ImageList\n", | |
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n", | |
"y: MultiCategoryList\n", | |
"1;1,5;5,7;7,6;6,0;0\n", | |
"Path: /;\n", | |
"\n", | |
"Test: None" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
}, | |
"execution_count": 50 | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "efV5heWPLJ4C", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def loss_fn_kd(outputs, labels, teacher_outputs):\n", | |
" alpha = 0.9\n", | |
" T = 20\n", | |
" KD_loss = nn.KLDivLoss()(F.log_softmax(outputs/T, dim=1),\n", | |
" F.softmax(teacher_outputs/T, dim=1)) * (alpha * T * T) + \\\n", | |
" F.cross_entropy(outputs, labels) * (1. - alpha)\n", | |
"\n", | |
" return KD_loss" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "FHOdkMqTJ1VY", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 418 | |
}, | |
"outputId": "725a2f34-c485-41ca-b7f5-6ea208bf24fc" | |
}, | |
"source": [ | |
"student_learn = cnn_learner(data, models.resnet18, metrics=accuracy,loss_func=loss_fn_kd)\n", | |
"student_learn.fit_one_cycle(3)" | |
], | |
"execution_count": 98, | |
"outputs": [ | |
{ | |
"output_type": "display_data", | |
"data": { | |
"text/html": [ | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='0' class='' max='3', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" 0.00% [0/3 00:00<00:00]\n", | |
" </div>\n", | |
" \n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: left;\">\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" </tbody>\n", | |
"</table><p>\n", | |
"\n", | |
" <div>\n", | |
" <style>\n", | |
" /* Turns off some styling */\n", | |
" progress {\n", | |
" /* gets rid of default border in Firefox and Opera. */\n", | |
" border: none;\n", | |
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n", | |
" background-size: auto;\n", | |
" }\n", | |
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n", | |
" background: #F44336;\n", | |
" }\n", | |
" </style>\n", | |
" <progress value='0' class='progress-bar-interrupted' max='750', style='width:300px; height:20px; vertical-align: middle;'></progress>\n", | |
" Interrupted\n", | |
" </div>\n", | |
" " | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": { | |
"tags": [] | |
} | |
}, | |
{ | |
"output_type": "error", | |
"ename": "TypeError", | |
"evalue": "ignored", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-98-467782828119>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mstudent_learn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcnn_learner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresnet18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mloss_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mloss_fn_kd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mstudent_learn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_one_cycle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/train.py\u001b[0m in \u001b[0;36mfit_one_cycle\u001b[0;34m(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)\u001b[0m\n\u001b[1;32m 21\u001b[0m callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,\n\u001b[1;32m 22\u001b[0m final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcyc_len\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_lr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m def fit_fc(learn:Learner, tot_epochs:int=1, lr:float=defaults.lr, moms:Tuple[float,float]=(0.95,0.85), start_pct:float=0.72,\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, epochs, lr, wd, callbacks)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0mcallbacks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_fns\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlistify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdefaults\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextra_callback_fns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlistify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcreate_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, learn, callbacks, metrics)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0myb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprogress_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpbar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 101\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 102\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py\u001b[0m in \u001b[0;36mloss_batch\u001b[0;34m(model, xb, yb, loss_func, opt, cb_handler)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mto_detach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mto_detach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mTypeError\u001b[0m: loss_fn_kd() missing 1 required positional argument: 'teacher_outputs'" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "O_6jj3tBOmfh", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment