Skip to content

Instantly share code, notes, and snippets.

@sheikmohdimran
Created November 17, 2019 17:10
Show Gist options
  • Save sheikmohdimran/7d551820d33c10af3aa2493e579b81f2 to your computer and use it in GitHub Desktop.
Save sheikmohdimran/7d551820d33c10af3aa2493e579b81f2 to your computer and use it in GitHub Desktop.
knowledge_distillation_test.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "knowledge_distillation_test.ipynb",
"provenance": [],
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/sheikmohdimran/7d551820d33c10af3aa2493e579b81f2/knowledge_distillation_test.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "RvrHXXKNeDBd",
"colab_type": "code",
"colab": {}
},
"source": [
"import pandas as pd\n",
"from fastai.vision import *"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "arZyAhqoeDuY",
"colab_type": "code",
"colab": {}
},
"source": [
"# Local root of the MNIST PNG dataset (its training/ folder is used below)\n",
"path = untar_data(URLs.MNIST)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "-6_WADnQDwO5",
"colab_type": "code",
"colab": {}
},
"source": [
"# Augmentations without flips: a mirrored digit is a different glyph\n",
"tfms = get_transforms(do_flip=False)\n",
"# Hold out a random 20% of training/ for validation. The fixed seed makes the\n",
"# split (and therefore every downstream result) reproducible across runs.\n",
"data = (ImageList.from_folder(path/'training')\n",
"        .split_by_rand_pct(valid_pct=0.2, seed=42)\n",
"        .label_from_folder()\n",
"        .transform(tfms, size=26)\n",
"        .databunch())"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "xUfJ4VqyIrEM",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 326
},
"outputId": "d1ddf05d-10c6-4e02-849e-50bad722f800"
},
"source": [
"data  # inspect split sizes, item shapes and sample labels"
],
"execution_count": 48,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"ImageDataBunch;\n",
"\n",
"Train: LabelList (48000 items)\n",
"x: ImageList\n",
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n",
"y: CategoryList\n",
"8,8,8,8,8\n",
"Path: /root/.fastai/data/mnist_png/training;\n",
"\n",
"Valid: LabelList (12000 items)\n",
"x: ImageList\n",
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n",
"y: CategoryList\n",
"2,7,5,4,4\n",
"Path: /root/.fastai/data/mnist_png/training;\n",
"\n",
"Test: None"
]
},
"metadata": {
"tags": []
},
"execution_count": 48
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "QAwXVKM8CG7K",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 143
},
"outputId": "e19acb12-9b7f-4ca3-cffa-696aceb973ca"
},
"source": [
"# Teacher: ResNet-18 trained for 3 epochs with the one-cycle schedule\n",
"teach_learn = cnn_learner(data, models.resnet18, metrics=[accuracy])\n",
"teach_learn.fit_one_cycle(cyc_len=3)"
],
"execution_count": 5,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>0</td>\n",
" <td>0.760325</td>\n",
" <td>0.453046</td>\n",
" <td>0.853667</td>\n",
" <td>01:12</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1</td>\n",
" <td>0.414666</td>\n",
" <td>0.200453</td>\n",
" <td>0.934167</td>\n",
" <td>01:13</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2</td>\n",
" <td>0.313441</td>\n",
" <td>0.169848</td>\n",
" <td>0.946417</td>\n",
" <td>01:13</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "LWcQNXHP2xW_",
"colab_type": "code",
"colab": {}
},
"source": [
"### Training set ###\n",
"# Teacher predictions on the training images.\n",
"# DatasetType.Fix is the training set WITHOUT shuffling and without dropping the\n",
"# last partial batch. DatasetType.Train is shuffled, so its predictions/targets come\n",
"# back in random order and no longer line up with `train_ds.items` (in the original\n",
"# run, files under training/8/ were paired with labels 6, 0, 1, ...).\n",
"preds, y = teach_learn.get_preds(ds_type=DatasetType.Fix)\n",
"# Predicted class = index of the highest score\n",
"preds_class = preds.argmax(dim=1)\n",
"# Building the frame in one call raises if the lengths ever disagree\n",
"df = pd.DataFrame({'file': teach_learn.data.train_ds.items,\n",
"                   'y': y.numpy(),\n",
"                   'y_teach': preds_class.numpy()})\n",
"# Sanity check: each true label must equal the digit folder its file lives in\n",
"assert (df['file'].map(lambda f: int(Path(f).parent.name)) == df['y']).all()\n",
"# Combine true label and teacher prediction as a tuple; keep only file + tuple\n",
"df['y_tuple'] = df[['y', 'y_teach']].apply(tuple, axis=1)\n",
"df = df[['file', 'y_tuple']]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "4nfGvSKJB50e",
"colab_type": "code",
"colab": {}
},
"source": [
"### Validation set ###\n",
"# Teacher predictions on the validation images. The validation loader is expected\n",
"# to be unshuffled; the assert below verifies predictions line up with the items.\n",
"preds, y = teach_learn.get_preds(ds_type=DatasetType.Valid)\n",
"# Predicted class = index of the highest score\n",
"preds_class = preds.argmax(dim=1)\n",
"# Building the frame in one call raises if the lengths ever disagree\n",
"df1 = pd.DataFrame({'file': teach_learn.data.valid_ds.items,\n",
"                    'y': y.numpy(),\n",
"                    'y_teach': preds_class.numpy()})\n",
"# Sanity check: each true label must equal the digit folder its file lives in\n",
"assert (df1['file'].map(lambda f: int(Path(f).parent.name)) == df1['y']).all()\n",
"# Combine true label and teacher prediction as a tuple; keep only file + tuple\n",
"df1['y_tuple'] = df1[['y', 'y_teach']].apply(tuple, axis=1)\n",
"df1 = df1[['file', 'y_tuple']]"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "NIeJeRP6CLny",
"colab_type": "code",
"colab": {}
},
"source": [
"# Stack train + valid rows. ignore_index gives a clean 0..N-1 index instead of\n",
"# two overlapping 0-based ranges (duplicate labels break any later .loc lookup).\n",
"final_df = pd.concat([df, df1], ignore_index=True)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "T3d2z67lF0ba",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 206
},
"outputId": "966e5402-bda3-42f9-9fa3-ca802ce8112a"
},
"source": [
"final_df.head(5)  # first rows: file path and (true label, teacher prediction)"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>file</th>\n",
" <th>y_tuple</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>/root/.fastai/data/mnist_png/training/8/54198.png</td>\n",
" <td>(6, 6)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>/root/.fastai/data/mnist_png/training/8/15021.png</td>\n",
" <td>(0, 0)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>/root/.fastai/data/mnist_png/training/8/23092.png</td>\n",
" <td>(1, 1)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>/root/.fastai/data/mnist_png/training/8/30401.png</td>\n",
" <td>(6, 6)</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>/root/.fastai/data/mnist_png/training/8/57745.png</td>\n",
" <td>(1, 1)</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" file y_tuple\n",
"0 /root/.fastai/data/mnist_png/training/8/54198.png (6, 6)\n",
"1 /root/.fastai/data/mnist_png/training/8/15021.png (0, 0)\n",
"2 /root/.fastai/data/mnist_png/training/8/23092.png (1, 1)\n",
"3 /root/.fastai/data/mnist_png/training/8/30401.png (6, 6)\n",
"4 /root/.fastai/data/mnist_png/training/8/57745.png (1, 1)"
]
},
"metadata": {
"tags": []
},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "9wm9eY2ZmXBj",
"colab_type": "code",
"colab": {}
},
"source": [
"tfms = get_transforms(do_flip=False)\n",
"# Student data: same images, labelled with the (true label, teacher prediction) tuple.\n",
"# NOTE(review): label_from_df turns the tuple into a MultiCategoryList (see the repr\n",
"# below), i.e. a multi-hot target: (6, 6) collapses to the single class 6 and the\n",
"# order of the pair is lost, so a loss cannot tell the true label from the teacher's.\n",
"# It also carries only the teacher's hard prediction, not the logits a KD loss needs.\n",
"# TODO: label with 'y' alone and obtain teacher outputs another way.\n",
"# The fixed seed makes the random 80/20 re-split reproducible.\n",
"data = (ImageList.from_df(final_df, '/')\n",
"        .split_by_rand_pct(valid_pct=0.2, seed=42)\n",
"        .label_from_df(cols='y_tuple')\n",
"        .transform(tfms, size=26)\n",
"        .databunch())"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "rwzgE5jvJh43",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 326
},
"outputId": "68096f91-714d-48bb-e631-531777191f13"
},
"source": [
"data  # y is now a MultiCategoryList rather than single class labels"
],
"execution_count": 50,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"ImageDataBunch;\n",
"\n",
"Train: LabelList (48000 items)\n",
"x: ImageList\n",
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n",
"y: MultiCategoryList\n",
"6;6,0;0,1;1,6;6,1;1\n",
"Path: /;\n",
"\n",
"Valid: LabelList (12000 items)\n",
"x: ImageList\n",
"Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26),Image (3, 26, 26)\n",
"y: MultiCategoryList\n",
"1;1,5;5,7;7,6;6,0;0\n",
"Path: /;\n",
"\n",
"Test: None"
]
},
"metadata": {
"tags": []
},
"execution_count": 50
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "efV5heWPLJ4C",
"colab_type": "code",
"colab": {}
},
"source": [
"def loss_fn_kd(outputs, labels, teacher_outputs, alpha=0.9, T=20):\n",
"    \"\"\"Knowledge-distillation loss: softened teacher KL term + hard-label cross-entropy.\n",
"\n",
"    outputs:         student logits; softmax is taken over dim 1\n",
"    labels:          hard targets for F.cross_entropy (assumed class indices)\n",
"    teacher_outputs: teacher logits, assumed same shape as `outputs`\n",
"    alpha:           weight of the distillation term; (1 - alpha) weights cross-entropy\n",
"    T:               softmax temperature; the KL term is multiplied by T*T\n",
"\n",
"    NOTE: fastai v1 calls `loss_func(out, *yb)`, so `teacher_outputs` is only filled\n",
"    when the batch carries a second target; with a single target this raises the\n",
"    TypeError seen when training the student.\n",
"    \"\"\"\n",
"    # reduction='batchmean' = sum over classes, mean over the batch (the true KL).\n",
"    # nn.KLDivLoss()'s default 'mean' also divides by the number of classes.\n",
"    soft_loss = F.kl_div(F.log_softmax(outputs / T, dim=1),\n",
"                         F.softmax(teacher_outputs / T, dim=1),\n",
"                         reduction='batchmean')\n",
"    hard_loss = F.cross_entropy(outputs, labels)\n",
"    return soft_loss * (alpha * T * T) + hard_loss * (1. - alpha)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "FHOdkMqTJ1VY",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 418
},
"outputId": "725a2f34-c485-41ca-b7f5-6ea208bf24fc"
},
"source": [
"# Student: same architecture, trained with the distillation loss.\n",
"# NOTE(review): currently fails with TypeError -- fastai passes only (out, target)\n",
"# to loss_fn_kd, so its `teacher_outputs` argument is never supplied.\n",
"student_learn = cnn_learner(data, models.resnet18, metrics=[accuracy], loss_func=loss_fn_kd)\n",
"student_learn.fit_one_cycle(cyc_len=3)"
],
"execution_count": 98,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='0' class='' max='3', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 0.00% [0/3 00:00<00:00]\n",
" </div>\n",
" \n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" </tbody>\n",
"</table><p>\n",
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='0' class='progress-bar-interrupted' max='750', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" Interrupted\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {
"tags": []
}
},
{
"output_type": "error",
"ename": "TypeError",
"evalue": "ignored",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-98-467782828119>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mstudent_learn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcnn_learner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodels\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresnet18\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maccuracy\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mloss_func\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mloss_fn_kd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mstudent_learn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_one_cycle\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/train.py\u001b[0m in \u001b[0;36mfit_one_cycle\u001b[0;34m(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)\u001b[0m\n\u001b[1;32m 21\u001b[0m callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,\n\u001b[1;32m 22\u001b[0m final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))\n\u001b[0;32m---> 23\u001b[0;31m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcyc_len\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_lr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 25\u001b[0m def fit_fc(learn:Learner, tot_epochs:int=1, lr:float=defaults.lr, moms:Tuple[float,float]=(0.95,0.85), start_pct:float=0.72,\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, epochs, lr, wd, callbacks)\u001b[0m\n\u001b[1;32m 198\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwd\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mwd\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 199\u001b[0m \u001b[0mcallbacks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mcb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallback_fns\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlistify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdefaults\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextra_callback_fns\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mlistify\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 200\u001b[0;31m \u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepochs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetrics\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallbacks\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m+\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 201\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 202\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mcreate_opt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mwd\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mFloats\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m->\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(epochs, learn, callbacks, metrics)\u001b[0m\n\u001b[1;32m 99\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0myb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprogress_bar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_dl\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mpbar\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 100\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_begin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 101\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 102\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mcb_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mon_batch_end\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloss\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mbreak\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 103\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.6/dist-packages/fastai/basic_train.py\u001b[0m in \u001b[0;36mloss_batch\u001b[0;34m(model, xb, yb, loss_func, opt, cb_handler)\u001b[0m\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mto_detach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mto_detach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 30\u001b[0;31m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 31\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mopt\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mTypeError\u001b[0m: loss_fn_kd() missing 1 required positional argument: 'teacher_outputs'"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "O_6jj3tBOmfh",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment