Last active
March 18, 2019 23:14
-
-
Save noachr/60647063c0daaaaf41bcfe4f1ae48335 to your computer and use it in GitHub Desktop.
Frozen output issue
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"%reload_ext autoreload\n", | |
"%autoreload 2\n", | |
"%matplotlib inline" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from fastai.vision import *" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"path = untar_data(URLs.IMAGENETTE_160) #Just picked a small dataset as an example, but the problem persists with all datasets I've tried" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Build the DataBunch: images from folders, split on the \"val\" folder, labelled by folder name\n", | |
"data = (\n", | |
"    ImageList.from_folder(path)\n", | |
"    .split_by_folder(valid=\"val\")\n", | |
"    .label_from_folder()\n", | |
"    .transform(size=128)\n", | |
"    .databunch(bs=64)\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class WTTest(nn.Module):\n", | |
"    \"\"\"Frozen AlexNet body followed by a trainable fastai head.\n", | |
"\n", | |
"    Args:\n", | |
"        num_classes: number of output classes produced by the head.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(self,num_classes):\n", | |
"        super().__init__()\n", | |
"        self.base = create_body(models.alexnet)\n", | |
"        # Freeze the body so that only the head receives gradient updates\n", | |
"        for p in self.base.parameters(): p.requires_grad = False\n", | |
"        # 256*2: the head concatenates avg- and max-pooled features of the 256-channel body output\n", | |
"        self.head = create_head(256*2,num_classes)\n", | |
"\n", | |
"    def forward(self, x):\n", | |
"        \"\"\"Run x through the frozen body, then through the head.\"\"\"\n", | |
"        return self.head(self.base(x))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Create learner. Disable update of BN stats just in case, but alexnet has no BN layers\n", | |
"learn = Learner(data,WTTest(data.c),metrics=accuracy,train_bn=False,callback_fns=[BnFreeze])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Get a batch from the validation set\n", | |
"x,y = next(iter(data.valid_dl))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Save a copy of the base state dict and the output activations from x before training.\n", | |
"#state_dict() returns references to the live parameter tensors, so each tensor is cloned;\n", | |
"#an un-cloned snapshot would silently follow any later in-place change to the weights.\n", | |
"prev_sd = {k: v.clone() for k,v in learn.model.base.state_dict().items()}\n", | |
"prev_output = learn.model.base(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"Total time: 00:05 <p><table style='width:375px; margin-bottom:10px'>\n", | |
" <tr>\n", | |
" <th>epoch</th>\n", | |
" <th>train_loss</th>\n", | |
" <th>valid_loss</th>\n", | |
" <th>accuracy</th>\n", | |
" <th>time</th>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>0</th>\n", | |
" <th>0.693029</th>\n", | |
" <th>0.505424</th>\n", | |
" <th>0.854000</th>\n", | |
" <th>00:05</th>\n", | |
" </tr>\n", | |
"</table>\n" | |
], | |
"text/plain": [ | |
"<IPython.core.display.HTML object>" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"#Train model for one epoch\n", | |
"learn.fit_one_cycle(1,1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Save a copy of the base state dict and the output activations from x after training.\n", | |
"#Tensors are cloned because state_dict() returns references to the live parameters,\n", | |
"#so this snapshot stays valid even if the model is trained further.\n", | |
"post_sd = {k: v.clone() for k,v in learn.model.base.state_dict().items()}\n", | |
"post_output = learn.model.base(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(0, device='cuda:0', dtype=torch.uint8)" | |
] | |
}, | |
"execution_count": 11, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"#Compare prev and post activations -- 0 indicates they are not equal\n", | |
"torch.all(prev_output == post_output)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Compare weights; no output means every tensor is identical.\n", | |
"#NOTE: this is only meaningful if prev_sd holds cloned tensors -- state_dict() returns\n", | |
"#references to the live parameters, so an un-cloned snapshot always compares equal.\n", | |
"for key in prev_sd.keys():\n", | |
"    if not torch.equal(prev_sd[key],post_sd[key]):\n", | |
"        print(key)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## How can the input and weights stay constant while the output changes?" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"======================================================================\n", | |
"Layer (type) Output Shape Param # Trainable \n", | |
"======================================================================\n", | |
"Conv2d [1, 64, 31, 31] 23,296 False \n", | |
"______________________________________________________________________\n", | |
"ReLU [1, 64, 31, 31] 0 False \n", | |
"______________________________________________________________________\n", | |
"MaxPool2d [1, 64, 15, 15] 0 False \n", | |
"______________________________________________________________________\n", | |
"Conv2d [1, 192, 15, 15] 307,392 False \n", | |
"______________________________________________________________________\n", | |
"ReLU [1, 192, 15, 15] 0 False \n", | |
"______________________________________________________________________\n", | |
"MaxPool2d [1, 192, 7, 7] 0 False \n", | |
"______________________________________________________________________\n", | |
"Conv2d [1, 384, 7, 7] 663,936 False \n", | |
"______________________________________________________________________\n", | |
"ReLU [1, 384, 7, 7] 0 False \n", | |
"______________________________________________________________________\n", | |
"Conv2d [1, 256, 7, 7] 884,992 False \n", | |
"______________________________________________________________________\n", | |
"ReLU [1, 256, 7, 7] 0 False \n", | |
"______________________________________________________________________\n", | |
"Conv2d [1, 256, 7, 7] 590,080 False \n", | |
"______________________________________________________________________\n", | |
"ReLU [1, 256, 7, 7] 0 False \n", | |
"______________________________________________________________________\n", | |
"MaxPool2d [1, 256, 3, 3] 0 False \n", | |
"______________________________________________________________________\n", | |
"AdaptiveAvgPool2d [1, 256, 1, 1] 0 False \n", | |
"______________________________________________________________________\n", | |
"AdaptiveMaxPool2d [1, 256, 1, 1] 0 False \n", | |
"______________________________________________________________________\n", | |
"Flatten [1, 512] 0 False \n", | |
"______________________________________________________________________\n", | |
"BatchNorm1d [1, 512] 1,024 True \n", | |
"______________________________________________________________________\n", | |
"Dropout [1, 512] 0 False \n", | |
"______________________________________________________________________\n", | |
"Linear [1, 512] 262,656 True \n", | |
"______________________________________________________________________\n", | |
"ReLU [1, 512] 0 False \n", | |
"______________________________________________________________________\n", | |
"BatchNorm1d [1, 512] 1,024 True \n", | |
"______________________________________________________________________\n", | |
"Dropout [1, 512] 0 False \n", | |
"______________________________________________________________________\n", | |
"Linear [1, 10] 5,130 True \n", | |
"______________________________________________________________________\n", | |
"\n", | |
"Total params: 2,739,530\n", | |
"Total trainable params: 269,834\n", | |
"Total non-trainable params: 2,469,696" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"#Summary shows the alexnet is indeed frozen\n", | |
"learn.summary()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Output does change, but not by much. The more training iterations that happen, the more it changes. " | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 38, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(3.4483, device='cuda:0')" | |
] | |
}, | |
"execution_count": 38, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"prev_output[0,1,0,0]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 37, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(3.4488, device='cuda:0')" | |
] | |
}, | |
"execution_count": 37, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"post_output[0,1,0,0]" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"### Try without fastai learner" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 22, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model = WTTest(data.c).cuda()\n", | |
"loss = nn.CrossEntropyLoss()\n", | |
"opt = optim.Adam(model.parameters(),lr=1e-3)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 23, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"x,y = next(iter(data.valid_dl))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Clone the tensors: state_dict() returns references to the live parameters,\n", | |
"#so an un-cloned snapshot would change along with the model.\n", | |
"prev_sd = {k: v.clone() for k,v in model.base.state_dict().items()}\n", | |
"prev_output = model.base(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 28, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Train one epoch with a plain PyTorch loop (no fastai Learner involved)\n", | |
"for inputs,labels in data.train_dl:\n", | |
"    opt.zero_grad()\n", | |
"    output = model(inputs)\n", | |
"    batch_loss = loss(output,labels)\n", | |
"    batch_loss.backward()\n", | |
"    opt.step()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 29, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"#Clone the tensors so the snapshot does not alias the live parameters\n", | |
"post_sd = {k: v.clone() for k,v in model.base.state_dict().items()}\n", | |
"post_output = model.base(x)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 30, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"tensor(1, device='cuda:0', dtype=torch.uint8)" | |
] | |
}, | |
"execution_count": 30, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"torch.all(prev_output == post_output)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.2" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment