Created
November 25, 2018 08:29
-
-
Save noklam/68dda5c92d1b8f2033aa6222f1e0cfad to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"<module 'fastai' from '/home/nok/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/__init__.py'>" | |
] | |
}, | |
"execution_count": 1, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"from fastai import * # Quick access to most common functionality\n", | |
"from fastai.text import * # Quick access to NLP functionality\n", | |
"%load_ext autoreload\n", | |
"%autoreload 2\n", | |
"from fastai.callbacks import *\n", | |
"import fastai; fastai" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "# Working Example: CNN (resnet18) on MNIST_TINY" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Input Size override by Learner.data.train_dl\n", | |
"Input Size passed in: 64 \n", | |
"\n", | |
"====================================================================================================\n", | |
"Layer (type) Output Shape Param # \n", | |
"====================================================================================================\n", | |
"Conv2d [64, 64, 14, 14] 9408 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 64, 14, 14] 128 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 64, 14, 14] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"MaxPool2d [64, 64, 7, 7] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 64, 7, 7] 36864 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 64, 7, 7] 128 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 64, 7, 7] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 64, 7, 7] 36864 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 64, 7, 7] 128 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 64, 7, 7] 36864 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 64, 7, 7] 128 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 64, 7, 7] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 64, 7, 7] 36864 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 64, 7, 7] 128 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 128, 4, 4] 73728 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 128, 4, 4] 256 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 128, 4, 4] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 128, 4, 4] 147456 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 128, 4, 4] 256 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 128, 4, 4] 8192 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 128, 4, 4] 256 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 128, 4, 4] 147456 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 128, 4, 4] 256 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 128, 4, 4] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 128, 4, 4] 147456 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 128, 4, 4] 256 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 256, 2, 2] 294912 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 256, 2, 2] 512 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 256, 2, 2] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 256, 2, 2] 589824 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 256, 2, 2] 512 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 256, 2, 2] 32768 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 256, 2, 2] 512 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 256, 2, 2] 589824 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 256, 2, 2] 512 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 256, 2, 2] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 256, 2, 2] 589824 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 256, 2, 2] 512 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 512, 1, 1] 1179648 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 512, 1, 1] 1024 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 512, 1, 1] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 512, 1, 1] 2359296 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 512, 1, 1] 1024 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 512, 1, 1] 131072 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 512, 1, 1] 1024 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 512, 1, 1] 2359296 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 512, 1, 1] 1024 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 512, 1, 1] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Conv2d [64, 512, 1, 1] 2359296 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm2d [64, 512, 1, 1] 1024 \n", | |
"____________________________________________________________________________________________________\n", | |
"AdaptiveAvgPool2d [64, 512, 1, 1] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"AdaptiveMaxPool2d [64, 512, 1, 1] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Lambda [64, 1024] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm1d [64, 1024] 2048 \n", | |
"____________________________________________________________________________________________________\n", | |
"Dropout [64, 1024] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Linear [64, 512] 524800 \n", | |
"____________________________________________________________________________________________________\n", | |
"ReLU [64, 512] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"BatchNorm1d [64, 512] 1024 \n", | |
"____________________________________________________________________________________________________\n", | |
"Dropout [64, 512] 0 \n", | |
"____________________________________________________________________________________________________\n", | |
"Linear [64, 2] 1026 \n", | |
"____________________________________________________________________________________________________\n", | |
"Total params: 11705410\n" | |
] | |
} | |
], | |
"source": [ | |
"from fastai.vision import *\n", | |
"path = untar_data(URLs.MNIST_TINY)\n", | |
"data = ImageDataBunch.from_folder(path, ds_tfms=(rand_pad(2, 28), []), bs=64)\n", | |
"learn = create_cnn(data, models.resnet18, metrics=accuracy)\n", | |
"model_summary(learn)" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
    "# Does Not Work for Text (language model learner)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Input Size override by Learner.data.train_dl\n", | |
"Input Size passed in: 95 \n", | |
"\n" | |
] | |
}, | |
{ | |
"ename": "AttributeError", | |
"evalue": "'NoneType' object has no attribute 'shape'", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m-------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", | |
"\u001b[0;32m<ipython-input-5-cd498f9c27c3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_lm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextLMDataBunch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'texts.csv'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mlearn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlanguage_model_learner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_lm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpretrained_model\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | |
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36mmodel_summary\u001b[0;34m(m, n)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmodel_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mCollection\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;34m\"Print a summary of `m` using a char length of `n`.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayers_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0mheader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"Layer (type)\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Output Shape\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Param #\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"=\"\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36mlayers_info\u001b[0;34m(m)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_layer_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflatten_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mlayers_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLearner\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mlayers_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0mlayer_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnamedtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Layer_Information'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'Layer'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'OutputSize'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Params'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_info\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36mparams_size\u001b[0;34m(m, size)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mhooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhooks_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mhooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhooks_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'shape'" | |
] | |
} | |
], | |
"source": [ | |
"path = untar_data(URLs.IMDB_SAMPLE)\n", | |
"data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')\n", | |
"learn = language_model_learner(data_lm, pretrained_model=None)\n", | |
"model_summary(learn)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"m = learn" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\u001b[0;31mSignature:\u001b[0m \u001b[0mhook_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCollection\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mfastai\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mHooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n", | |
"\u001b[0;31mSource:\u001b[0m \n", | |
"\u001b[0;32mdef\u001b[0m \u001b[0mhook_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mCollection\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mHooks\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mHooks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtotal_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mFile:\u001b[0m ~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\n", | |
"\u001b[0;31mType:\u001b[0m function\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"from fastai.callbacks.hooks import hook_params, total_params\n", | |
"??hook_params" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"\u001b[0;31mSignature:\u001b[0m \u001b[0mtotal_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n", | |
"\u001b[0;31mSource:\u001b[0m \n", | |
"\u001b[0;32mdef\u001b[0m \u001b[0mtotal_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"weight\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"size\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"bias\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"size\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"all_weights\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall_weights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0;31m# print('Params', params, 'layer:', m)\u001b[0m\u001b[0;34m\u001b[0m\n", | |
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", | |
"\u001b[0;31mFile:\u001b[0m ~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\n", | |
"\u001b[0;31mType:\u001b[0m function\n" | |
] | |
}, | |
"metadata": {}, | |
"output_type": "display_data" | |
} | |
], | |
"source": [ | |
"??total_params" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Layer: Embedding(6144, 400, padding_idx=1) None Nothing Stored\n", | |
"Layer: Embedding(6144, 400, padding_idx=1) None Nothing Stored\n", | |
"Layer: LSTM(400, 1150) None Nothing Stored\n", | |
"Layer: LSTM(1150, 1150) None Nothing Stored\n", | |
"Layer: LSTM(1150, 400) None Nothing Stored\n", | |
"Layer: RNNDropout() 0 torch.Size([69, 64, 400])\n", | |
"Layer: RNNDropout() 0 torch.Size([69, 64, 1150])\n", | |
"Layer: RNNDropout() 0 torch.Size([69, 64, 1150])\n", | |
"Layer: RNNDropout() None Nothing Stored\n", | |
"Layer: Linear(in_features=400, out_features=6144, bias=True) 2463744 torch.Size([4416, 6144])\n", | |
"Layer: RNNDropout() 0 torch.Size([69, 64, 400])\n" | |
] | |
} | |
], | |
"source": [ | |
"d= m.data.one_batch()\n", | |
"hooks_outputs = hook_outputs(flatten_model(m.model))\n", | |
"hooks_params = hook_params(flatten_model(m.model))\n", | |
"hooks = zip(hooks_outputs, hooks_params)\n", | |
"m.model.eval()(d[0])\n", | |
"for o ,layer in zip(zip(hooks_outputs, hooks_params), flatten_model(m.model)):\n", | |
" try:\n", | |
" print('Layer:', layer, end=' ')\n", | |
"# print('Weight Shape:', layer.weight.numel(), end=' ')\n", | |
"# print('Weight Shape:', layer.bias_hh_l0.numel(), end=' ')\n", | |
" print(o[1].stored, end=' ')\n", | |
" print(o[0].stored.shape)\n", | |
" except:\n", | |
" print('Nothing Stored')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.0" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment