Skip to content

Instantly share code, notes, and snippets.

@noachr
Last active March 18, 2019 23:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noachr/60647063c0daaaaf41bcfe4f1ae48335 to your computer and use it in GitHub Desktop.
Save noachr/60647063c0daaaaf41bcfe4f1ae48335 to your computer and use it in GitHub Desktop.
Frozen output issue
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from fastai.vision import *"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"path = untar_data(URLs.IMAGENETTE_160) #Just picked a small dataset as an example, but the problem persists with all datasets I've tried"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"data = (ImageList.from_folder(path)\n",
" .split_by_folder(valid=\"val\")\n",
" .label_from_folder()\n",
" .transform(size=128)\n",
" .databunch(bs=64))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"class WTTest(nn.Module):\n",
" def __init__(self,num_classes):\n",
" super().__init__()\n",
" self.base = create_body(models.alexnet)\n",
" for p in self.base.parameters(): p.requires_grad = False\n",
" self.head = create_head(256*2,num_classes)\n",
" \n",
" def forward(self, x):\n",
" return self.head(self.base(x))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"#Create learner. Disable update of BN stats just in case, but alexnet has no BN layers\n",
"learn = Learner(data,WTTest(data.c),metrics=accuracy,train_bn=False,callback_fns=[BnFreeze])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"#Get a batch from the validation set\n",
"x,y = next(iter(data.valid_dl))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"#Save base state dict and output activations from x before training\n",
"prev_sd = learn.model.base.state_dict()\n",
"prev_output = learn.model.base(x)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"Total time: 00:05 <p><table style='width:375px; margin-bottom:10px'>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th>train_loss</th>\n",
" <th>valid_loss</th>\n",
" <th>accuracy</th>\n",
" <th>time</th>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <th>0.693029</th>\n",
" <th>0.505424</th>\n",
" <th>0.854000</th>\n",
" <th>00:05</th>\n",
" </tr>\n",
"</table>\n"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#Train model for one epoch\n",
"learn.fit_one_cycle(1,1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"#Save base state dict and output activations from x after training\n",
"post_sd = learn.model.base.state_dict()\n",
"post_output = learn.model.base(x)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0, device='cuda:0', dtype=torch.uint8)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Compare prev and post activations -- 0 indicates they are not equal\n",
"torch.all(prev_output == post_output)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
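"#Compare weights; no output would indicate they are equal.\n",
"#NOTE: state_dict() returns tensors that share storage with the live parameters, so unless prev_sd was cloned beforehand (e.g. with copy.deepcopy) this check compares the current weights with themselves and cannot detect a change.\n",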
"for key in prev_sd.keys():\n",
" if not torch.equal(prev_sd[key],post_sd[key]):\n",
" print(key)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How can the input and weights stay constant while the output changes?"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"======================================================================\n",
"Layer (type) Output Shape Param # Trainable \n",
"======================================================================\n",
"Conv2d [1, 64, 31, 31] 23,296 False \n",
"______________________________________________________________________\n",
"ReLU [1, 64, 31, 31] 0 False \n",
"______________________________________________________________________\n",
"MaxPool2d [1, 64, 15, 15] 0 False \n",
"______________________________________________________________________\n",
"Conv2d [1, 192, 15, 15] 307,392 False \n",
"______________________________________________________________________\n",
"ReLU [1, 192, 15, 15] 0 False \n",
"______________________________________________________________________\n",
"MaxPool2d [1, 192, 7, 7] 0 False \n",
"______________________________________________________________________\n",
"Conv2d [1, 384, 7, 7] 663,936 False \n",
"______________________________________________________________________\n",
"ReLU [1, 384, 7, 7] 0 False \n",
"______________________________________________________________________\n",
"Conv2d [1, 256, 7, 7] 884,992 False \n",
"______________________________________________________________________\n",
"ReLU [1, 256, 7, 7] 0 False \n",
"______________________________________________________________________\n",
"Conv2d [1, 256, 7, 7] 590,080 False \n",
"______________________________________________________________________\n",
"ReLU [1, 256, 7, 7] 0 False \n",
"______________________________________________________________________\n",
"MaxPool2d [1, 256, 3, 3] 0 False \n",
"______________________________________________________________________\n",
"AdaptiveAvgPool2d [1, 256, 1, 1] 0 False \n",
"______________________________________________________________________\n",
"AdaptiveMaxPool2d [1, 256, 1, 1] 0 False \n",
"______________________________________________________________________\n",
"Flatten [1, 512] 0 False \n",
"______________________________________________________________________\n",
"BatchNorm1d [1, 512] 1,024 True \n",
"______________________________________________________________________\n",
"Dropout [1, 512] 0 False \n",
"______________________________________________________________________\n",
"Linear [1, 512] 262,656 True \n",
"______________________________________________________________________\n",
"ReLU [1, 512] 0 False \n",
"______________________________________________________________________\n",
"BatchNorm1d [1, 512] 1,024 True \n",
"______________________________________________________________________\n",
"Dropout [1, 512] 0 False \n",
"______________________________________________________________________\n",
"Linear [1, 10] 5,130 True \n",
"______________________________________________________________________\n",
"\n",
"Total params: 2,739,530\n",
"Total trainable params: 269,834\n",
"Total non-trainable params: 2,469,696"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"#Summary shows the alexnet is indeed frozen\n",
"learn.summary()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#The output does change, but not by much. The more training iterations that run, the more it changes."
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3.4483, device='cuda:0')"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prev_output[0,1,0,0]"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(3.4488, device='cuda:0')"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"post_output[0,1,0,0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Try without fastai learner"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"model = WTTest(data.c).cuda()\n",
"loss = nn.CrossEntropyLoss()\n",
"opt = optim.Adam(model.parameters(),lr=1e-3)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"x,y = next(iter(data.valid_dl))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"prev_sd = model.base.state_dict()\n",
"prev_output = model.base(x)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"#Train one epoch\n",
"for input,labels in iter(data.train_dl):\n",
" opt.zero_grad()\n",
" output = model(input)\n",
" l = loss(output,labels)\n",
" l.backward()\n",
" opt.step()"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"post_sd = model.base.state_dict()\n",
"post_output = model.base(x)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(1, device='cuda:0', dtype=torch.uint8)"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.all(prev_output == post_output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment