Skip to content

Instantly share code, notes, and snippets.

@noklam
Created November 25, 2018 08:29
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save noklam/68dda5c92d1b8f2033aa6222f1e0cfad to your computer and use it in GitHub Desktop.
Save noklam/68dda5c92d1b8f2033aa6222f1e0cfad to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<module 'fastai' from '/home/nok/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/__init__.py'>"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from fastai import * # Quick access to most common functionality\n",
"from fastai.text import * # Quick access to NLP functionality\n",
"%load_ext autoreload\n",
"%autoreload 2\n",
"from fastai.callbacks import *\n",
"import fastai; fastai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Working Example: `model_summary` on a Vision Learner"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input Size override by Learner.data.train_dl\n",
"Input Size passed in: 64 \n",
"\n",
"====================================================================================================\n",
"Layer (type) Output Shape Param # \n",
"====================================================================================================\n",
"Conv2d [64, 64, 14, 14] 9408 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 64, 14, 14] 128 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 64, 14, 14] 0 \n",
"____________________________________________________________________________________________________\n",
"MaxPool2d [64, 64, 7, 7] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 64, 7, 7] 36864 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 64, 7, 7] 128 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 64, 7, 7] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 64, 7, 7] 36864 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 64, 7, 7] 128 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 64, 7, 7] 36864 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 64, 7, 7] 128 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 64, 7, 7] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 64, 7, 7] 36864 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 64, 7, 7] 128 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 128, 4, 4] 73728 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 128, 4, 4] 256 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 128, 4, 4] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 128, 4, 4] 147456 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 128, 4, 4] 256 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 128, 4, 4] 8192 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 128, 4, 4] 256 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 128, 4, 4] 147456 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 128, 4, 4] 256 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 128, 4, 4] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 128, 4, 4] 147456 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 128, 4, 4] 256 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 256, 2, 2] 294912 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 256, 2, 2] 512 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 256, 2, 2] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 256, 2, 2] 589824 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 256, 2, 2] 512 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 256, 2, 2] 32768 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 256, 2, 2] 512 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 256, 2, 2] 589824 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 256, 2, 2] 512 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 256, 2, 2] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 256, 2, 2] 589824 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 256, 2, 2] 512 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 512, 1, 1] 1179648 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 512, 1, 1] 1024 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 512, 1, 1] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 512, 1, 1] 2359296 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 512, 1, 1] 1024 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 512, 1, 1] 131072 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 512, 1, 1] 1024 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 512, 1, 1] 2359296 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 512, 1, 1] 1024 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 512, 1, 1] 0 \n",
"____________________________________________________________________________________________________\n",
"Conv2d [64, 512, 1, 1] 2359296 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm2d [64, 512, 1, 1] 1024 \n",
"____________________________________________________________________________________________________\n",
"AdaptiveAvgPool2d [64, 512, 1, 1] 0 \n",
"____________________________________________________________________________________________________\n",
"AdaptiveMaxPool2d [64, 512, 1, 1] 0 \n",
"____________________________________________________________________________________________________\n",
"Lambda [64, 1024] 0 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm1d [64, 1024] 2048 \n",
"____________________________________________________________________________________________________\n",
"Dropout [64, 1024] 0 \n",
"____________________________________________________________________________________________________\n",
"Linear [64, 512] 524800 \n",
"____________________________________________________________________________________________________\n",
"ReLU [64, 512] 0 \n",
"____________________________________________________________________________________________________\n",
"BatchNorm1d [64, 512] 1024 \n",
"____________________________________________________________________________________________________\n",
"Dropout [64, 512] 0 \n",
"____________________________________________________________________________________________________\n",
"Linear [64, 2] 1026 \n",
"____________________________________________________________________________________________________\n",
"Total params: 11705410\n"
]
}
],
"source": [
"from fastai.vision import *\n",
"\n",
"# Baseline case: build a resnet18 image learner on the MNIST_TINY sample and\n",
"# print its layer-by-layer summary.\n",
"path = untar_data(URLs.MNIST_TINY)\n",
"ds_tfms = (rand_pad(2, 28), [])\n",
"data = ImageDataBunch.from_folder(path, ds_tfms=ds_tfms, bs=64)\n",
"arch = models.resnet18\n",
"learn = create_cnn(data, arch, metrics=accuracy)\n",
"model_summary(learn)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Failing Example: `model_summary` on a Text Learner"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Input Size override by Learner.data.train_dl\n",
"Input Size passed in: 95 \n",
"\n"
]
},
{
"ename": "AttributeError",
"evalue": "'NoneType' object has no attribute 'shape'",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-5-cd498f9c27c3>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mdata_lm\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTextLMDataBunch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'texts.csv'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0mlearn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlanguage_model_learner\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_lm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpretrained_model\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0mmodel_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlearn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36mmodel_summary\u001b[0;34m(m, n)\u001b[0m\n\u001b[1;32m 146\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mmodel_summary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mCollection\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 147\u001b[0m \u001b[0;34m\"Print a summary of `m` using a char length of `n`.\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0minfo\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayers_info\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0mheader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"Layer (type)\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Output Shape\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Param #\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"=\"\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36mlayers_info\u001b[0;34m(m)\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_layer_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mflatten_model\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 141\u001b[0m \u001b[0mlayers_names\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mLearner\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 142\u001b[0;31m \u001b[0mlayers_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_params\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparams_size\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 143\u001b[0m \u001b[0mlayer_info\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnamedtuple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Layer_Information'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m'Layer'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'OutputSize'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Params'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 144\u001b[0m \u001b[0;32mreturn\u001b[0m 
\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlayer_info\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_names\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_sizes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlayers_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36mparams_size\u001b[0;34m(m, size)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mhooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhooks_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mis_listy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[0mhooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhooks_outputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 132\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstored\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mo\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 133\u001b[0m \u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mres\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0moutput_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhooks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'NoneType' object has no attribute 'shape'"
]
}
],
"source": [
"import traceback\n",
"\n",
"# Build a language-model learner on the IMDB sample and try to summarise it.\n",
"path = untar_data(URLs.IMDB_SAMPLE)\n",
"data_lm = TextLMDataBunch.from_csv(path, 'texts.csv')\n",
"learn = language_model_learner(data_lm, pretrained_model=None)\n",
"\n",
"# Known failure: for this learner `model_summary` raises\n",
"# AttributeError: 'NoneType' object has no attribute 'shape', because `params_size`\n",
"# reads `.stored.shape` from an output hook whose `stored` is still None.\n",
"# Catch and display the error so the notebook still runs top to bottom and the\n",
"# diagnostic cells below can execute.\n",
"try:\n",
"    model_summary(learn)\n",
"except AttributeError:\n",
"    traceback.print_exc()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"m = learn"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mhook_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mCollection\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mfastai\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcallbacks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mHooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mSource:\u001b[0m \n",
"\u001b[0;32mdef\u001b[0m \u001b[0mhook_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mCollection\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mHooks\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mHooks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mo\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtotal_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFile:\u001b[0m ~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\n",
"\u001b[0;31mType:\u001b[0m function\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from fastai.callbacks.hooks import hook_params, total_params\n",
"??hook_params"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\u001b[0;31mSignature:\u001b[0m \u001b[0mtotal_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodules\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
"\u001b[0;31mSource:\u001b[0m \n",
"\u001b[0;32mdef\u001b[0m \u001b[0mtotal_params\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mModule\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"weight\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"size\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"bias\"\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"size\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhasattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mm\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"all_weights\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mparams\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mall_weights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;31m# print('Params', params, 'layer:', m)\u001b[0m\u001b[0;34m\u001b[0m\n",
"\u001b[0;34m\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mFile:\u001b[0m ~/learning/fastai_fork/fastai-add-more-support-for-model-summary/tests/fastai/callbacks/hooks.py\n",
"\u001b[0;31mType:\u001b[0m function\n"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"??total_params"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Layer: Embedding(6144, 400, padding_idx=1) None Nothing Stored\n",
"Layer: Embedding(6144, 400, padding_idx=1) None Nothing Stored\n",
"Layer: LSTM(400, 1150) None Nothing Stored\n",
"Layer: LSTM(1150, 1150) None Nothing Stored\n",
"Layer: LSTM(1150, 400) None Nothing Stored\n",
"Layer: RNNDropout() 0 torch.Size([69, 64, 400])\n",
"Layer: RNNDropout() 0 torch.Size([69, 64, 1150])\n",
"Layer: RNNDropout() 0 torch.Size([69, 64, 1150])\n",
"Layer: RNNDropout() None Nothing Stored\n",
"Layer: Linear(in_features=400, out_features=6144, bias=True) 2463744 torch.Size([4416, 6144])\n",
"Layer: RNNDropout() 0 torch.Size([69, 64, 400])\n"
]
}
],
"source": [
"# Replay the hooked forward pass that `params_size` performs, but report each\n",
"# flattened layer on its own line so the layers whose hooks stored nothing are visible.\n",
"layers = flatten_model(m.model)\n",
"xb = m.data.one_batch()[0]\n",
"hooks_outputs = hook_outputs(layers)\n",
"hooks_params = hook_params(layers)\n",
"try:\n",
"    m.model.eval()(xb)\n",
"    for layer, out_hook, param_hook in zip(layers, hooks_outputs, hooks_params):\n",
"        out = out_hook.stored\n",
"        if out is None:\n",
"            # Nothing was recorded for this layer, so there is no output to measure.\n",
"            shape = 'Nothing Stored'\n",
"        else:\n",
"            # Outputs without a `.shape` are named by type instead of being hidden.\n",
"            shape = getattr(out, 'shape', f'(no shape: {type(out).__name__})')\n",
"        print('Layer:', layer, param_hook.stored, shape)\n",
"finally:\n",
"    # Always detach the hooks so re-running this cell does not stack them on the model.\n",
"    hooks_outputs.remove()\n",
"    hooks_params.remove()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment