Last active
November 14, 2019 21:03
-
-
Save thomasbrandon/f03aa9e1d7856ab5086f135e83a12b42 to your computer and use it in GitHub Desktop.
Match layers between fastai and torchvision models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [ | |
"from fastai.vision import *\n", | |
"import torchvision" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [ | |
"data = (ImageList.from_folder(untar_data(URLs.IMAGEWOOF_320))\n", | |
" .filter_by_rand(0.1).split_by_folder(valid='val')\n", | |
" .label_from_folder().transform(size=224).databunch(num_workers=0))" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [ | |
"def iter_model_tree(model:nn.Module, pre_order:bool=True, visit_root:bool=True):\n", | |
"    \"\"\"Iterate a model tree depth-first, yielding (name,module) pairs.\n", | |
"\n", | |
"    `name` is the dotted path of child names leading from `model` to the module\n", | |
"    (the same prefix used in `state_dict` keys); the root itself is named ''.\n", | |
"    With `pre_order` a module is yielded before its children, otherwise after them.\n", | |
"    `visit_root` controls whether `model` itself is yielded (exactly once, first in\n", | |
"    pre-order and last in post-order).\n", | |
"    \"\"\"\n", | |
"    if visit_root and pre_order: yield '',model\n", | |
"    # Explicit stacks instead of recursion. `iter_path` always holds one more entry\n", | |
"    # than `name_path`/`mod_path`: its first entry iterates the root's children.\n", | |
"    iter_path = [iter(model.named_children())]\n", | |
"    name_path,mod_path = [],[]\n", | |
"    while iter_path:\n", | |
"        try:\n", | |
"            name,mod = next(iter_path[-1])\n", | |
"        except StopIteration:\n", | |
"            # All children of the current module have been visited.\n", | |
"            iter_path.pop()\n", | |
"            if not mod_path: break  # that was the root; it is handled outside the loop\n", | |
"            full_name,mod = '.'.join(name_path),mod_path.pop()\n", | |
"            name_path.pop()\n", | |
"            if not pre_order: yield full_name,mod\n", | |
"            continue\n", | |
"        name_path.append(name)\n", | |
"        mod_path.append(mod)\n", | |
"        iter_path.append(iter(mod.named_children()))\n", | |
"        if pre_order: yield '.'.join(name_path),mod\n", | |
"    # Yield the root only here in post-order, so it is never emitted twice.\n", | |
"    if visit_root and not pre_order: yield '',model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [ | |
"orig_mdl = torchvision.models.resnet34()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['conv1.weight',\n", | |
" 'bn1.weight',\n", | |
" 'bn1.bias',\n", | |
" 'bn1.running_mean',\n", | |
" 'bn1.running_var',\n", | |
" 'bn1.num_batches_tracked',\n", | |
" 'layer1.0.conv1.weight',\n", | |
" 'layer1.0.bn1.weight',\n", | |
" 'layer1.0.bn1.bias',\n", | |
" 'layer1.0.bn1.running_mean']" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"list(orig_mdl.state_dict().keys())[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['0.0.weight',\n", | |
" '0.1.weight',\n", | |
" '0.1.bias',\n", | |
" '0.1.running_mean',\n", | |
" '0.1.running_var',\n", | |
" '0.1.num_batches_tracked',\n", | |
" '0.4.0.conv1.weight',\n", | |
" '0.4.0.bn1.weight',\n", | |
" '0.4.0.bn1.bias',\n", | |
" '0.4.0.bn1.running_mean']" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"learn = cnn_learner(data, models.resnet34)\n", | |
"list(learn.model.state_dict().keys())[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [
"# Collect the (name, module) pairs of every conv / 2d batchnorm layer in each model,\n", | |
"# in tree order, so the two lists can be paired up positionally.\n", | |
"orig_layers,new_layers = [\n", | |
"    [(name,layer) for name,layer in iter_model_tree(net, visit_root=False)\n", | |
"     if isinstance(layer,(nn.Conv2d,nn.BatchNorm2d))]\n", | |
"    for net in (orig_mdl,learn.model)]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 19, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"[('conv1', '0.0'),\n", | |
" ('bn1', '0.1'),\n", | |
" ('layer1.0.conv1', '0.4.0.conv1'),\n", | |
" ('layer1.0.bn1', '0.4.0.bn1'),\n", | |
" ('layer1.0.conv2', '0.4.0.conv2'),\n", | |
" ('layer1.0.bn2', '0.4.0.bn2'),\n", | |
" ('layer1.1.conv1', '0.4.1.conv1'),\n", | |
" ('layer1.1.bn1', '0.4.1.bn1'),\n", | |
" ('layer1.1.conv2', '0.4.1.conv2'),\n", | |
" ('layer1.1.bn2', '0.4.1.bn2')]" | |
] | |
}, | |
"execution_count": 19, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Layers are paired purely by position, and zip() silently stops at the shorter list,\n", | |
"# so fail loudly if the two models do not expose the same number of layers.\n", | |
"if len(orig_layers) != len(new_layers):\n", | |
"    raise ValueError(f'Layer count mismatch: {len(orig_layers)} vs {len(new_layers)}')\n", | |
"# Map each torchvision layer name to the name of the matching fastai layer.\n", | |
"mapping = {n1:n2 for (n1,_),(n2,_) in zip(orig_layers,new_layers)}\n", | |
"list(mapping.items())[:10]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 24, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [ | |
"orig_state = torch.utils.model_zoo.load_url(torchvision.models.resnet.model_urls['resnet34'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 31, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Rename each pretrained entry from '<torchvision layer>.<param>' to '<fastai layer>.<param>'.\n", | |
"# rpartition splits on the last '.'; a key without any '.' gets an empty layer name and is\n", | |
"# skipped, as is any entry whose layer has no counterpart in `mapping`.\n", | |
"new_state = {f'{mapping[layer]}.{param}': tensor\n", | |
"             for (layer,_,param),tensor in ((key.rpartition('.'),tensor) for key,tensor in orig_state.items())\n", | |
"             if layer in mapping}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 33, | |
"metadata": { | |
"Collapsed": "false" | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"_IncompatibleKeys(missing_keys=['1.2.weight', '1.2.bias', '1.2.running_mean', '1.2.running_var', '1.4.weight', '1.4.bias', '1.6.weight', '1.6.bias', '1.6.running_mean', '1.6.running_var', '1.8.weight', '1.8.bias'], unexpected_keys=[])" | |
] | |
}, | |
"execution_count": 33, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"learn.model.load_state_dict(new_state, strict=False)" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "fastai-dev-venv", | |
"language": "python", | |
"name": "fastai-dev-venv" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.3" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment