Skip to content

Instantly share code, notes, and snippets.

@thomasbrandon
Last active November 14, 2019 21:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thomasbrandon/f03aa9e1d7856ab5086f135e83a12b42 to your computer and use it in GitHub Desktop.
Save thomasbrandon/f03aa9e1d7856ab5086f135e83a12b42 to your computer and use it in GitHub Desktop.
Match layers between fastai and torchvision models
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"from fastai.vision import *\n",
"import torchvision"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"data = (ImageList.from_folder(untar_data(URLs.IMAGEWOOF_320))\n",
" .filter_by_rand(0.1).split_by_folder(valid='val')\n",
" .label_from_folder().transform(size=224).databunch(num_workers=0))"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"def iter_model_tree(model:nn.Module, pre_order:bool=True, visit_root:bool=True):\n",
"    \"\"\"Iterate a model tree depth-first, yielding (name,module) pairs.\n",
"\n",
"    Names are dotted child-name paths relative to `model` (the same form as the\n",
"    module prefixes of `state_dict` keys); `model` itself is named ''.\n",
"    pre_order: yield each module before its children (True) or after them (False).\n",
"    visit_root: whether `model` itself is yielded (first in pre-order, last in post-order).\n",
"    \"\"\"\n",
"    mod_path = [model]                          # modules from the root down to the current node\n",
"    iter_path = [iter(model.named_children())]  # remaining-children iterator for each entry of mod_path\n",
"    name_path = []                              # child names from the root to the current node\n",
"    if visit_root and pre_order: yield '',model\n",
"    while mod_path:\n",
"        child = next(iter_path[-1], None)  # named_children yields tuples, so None means exhausted\n",
"        if child is None:\n",
"            # Current module has no children left: it is finished, step back up to its parent.\n",
"            mod = mod_path.pop()\n",
"            iter_path.pop()\n",
"            # Root finished: it is yielded (at most once) after the loop, not here.\n",
"            if not mod_path: break\n",
"            if not pre_order: yield '.'.join(name_path), mod\n",
"            name_path.pop()\n",
"            continue\n",
"        # Descend into the next child.\n",
"        name,mod = child\n",
"        name_path.append(name)\n",
"        mod_path.append(mod)\n",
"        iter_path.append(iter(mod.named_children()))\n",
"        if pre_order: yield '.'.join(name_path), mod\n",
"    if visit_root and not pre_order: yield '',model"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"orig_mdl = torchvision.models.resnet34()"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"['conv1.weight',\n",
" 'bn1.weight',\n",
" 'bn1.bias',\n",
" 'bn1.running_mean',\n",
" 'bn1.running_var',\n",
" 'bn1.num_batches_tracked',\n",
" 'layer1.0.conv1.weight',\n",
" 'layer1.0.bn1.weight',\n",
" 'layer1.0.bn1.bias',\n",
" 'layer1.0.bn1.running_mean']"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(orig_mdl.state_dict().keys())[:10]"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"['0.0.weight',\n",
" '0.1.weight',\n",
" '0.1.bias',\n",
" '0.1.running_mean',\n",
" '0.1.running_var',\n",
" '0.1.num_batches_tracked',\n",
" '0.4.0.conv1.weight',\n",
" '0.4.0.bn1.weight',\n",
" '0.4.0.bn1.bias',\n",
" '0.4.0.bn1.running_mean']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn = cnn_learner(data, models.resnet34)\n",
"list(learn.model.state_dict().keys())[:10]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"orig_layers,new_layers = (\n",
" [(n,mod)\n",
" for n,mod in iter_model_tree(mdl, visit_root=False)\n",
" if isinstance(mod,(nn.Conv2d,nn.BatchNorm2d))\n",
" ] for mdl in (orig_mdl,learn.model))"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"[('conv1', '0.0'),\n",
" ('bn1', '0.1'),\n",
" ('layer1.0.conv1', '0.4.0.conv1'),\n",
" ('layer1.0.bn1', '0.4.0.bn1'),\n",
" ('layer1.0.conv2', '0.4.0.conv2'),\n",
" ('layer1.0.bn2', '0.4.0.bn2'),\n",
" ('layer1.1.conv1', '0.4.1.conv1'),\n",
" ('layer1.1.bn1', '0.4.1.bn1'),\n",
" ('layer1.1.conv2', '0.4.1.conv2'),\n",
" ('layer1.1.bn2', '0.4.1.bn2')]"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mapping = {n1:n2 for (n1,m1),(n2,m2) in zip(orig_layers,new_layers)}\n",
"list(mapping.items())[:10]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"orig_state = torch.utils.model_zoo.load_url(torchvision.models.resnet.model_urls['resnet34'])"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {
"Collapsed": "false"
},
"outputs": [],
"source": [
"new_state = {mapping[n[:p]]+n[p:]: s\n",
" for n,p,s in ((n,n.rfind('.'),s) for n,s in orig_state.items())\n",
" if n[:p] in mapping}"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {
"Collapsed": "false"
},
"outputs": [
{
"data": {
"text/plain": [
"_IncompatibleKeys(missing_keys=['1.2.weight', '1.2.bias', '1.2.running_mean', '1.2.running_var', '1.4.weight', '1.4.bias', '1.6.weight', '1.6.bias', '1.6.running_mean', '1.6.running_var', '1.8.weight', '1.8.bias'], unexpected_keys=[])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"learn.model.load_state_dict(new_state, strict=False)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fastai-dev-venv",
"language": "python",
"name": "fastai-dev-venv"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment