Created
July 28, 2022 14:41
-
-
Save ltiao/d0a07a82d2ac25f800fbc081f67b40e6 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "3204663b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"\n", | |
"from transformers import BertModel, BertConfig" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "cf11931b", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Names of the model sizes; each one is a key of the per-size dictionaries defined below.\n", | |
"config_names = ['tiny', 'small', 'base', 'large']\n", | |
"\n", | |
"# Multiplier applied to hidden_size to obtain intermediate_size.\n", | |
"dff_ratio = 4\n", | |
"# Not referenced by any later cell; equals hidden_size / num_attention_heads for every size below.\n", | |
"n_head_ratio = 64" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "a6c5612d", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Value passed to BertConfig as num_hidden_layers, keyed by model size.\n", | |
"num_hidden_layers = {\n", | |
"    'tiny': 4,\n", | |
"    'small': 6,\n", | |
"    'base': 12,\n", | |
"    'large': 24,\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "e06d7541", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Value passed to BertConfig as hidden_size, keyed by model size.\n", | |
"hidden_size = {\n", | |
"    'tiny': 256,\n", | |
"    'small': 512,\n", | |
"    'base': 768,\n", | |
"    'large': 1024,\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"id": "29a4c8eb", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Value passed to BertConfig as num_attention_heads, keyed by model size.\n", | |
"num_attention_heads = {\n", | |
"    'tiny': 4,\n", | |
"    'small': 8,\n", | |
"    'base': 12,\n", | |
"    'large': 16,\n", | |
"}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"id": "f8d0370a", | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"dict_keys(['tiny', 'small', 'base', 'large'])" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Display the model-size names that have a hidden_size entry.\n", | |
"hidden_size.keys()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"id": "6b53fbae", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def config_overrides_fn(dct):\n", | |
"    \"\"\"Format a mapping as comma-separated ``key=value`` pairs.\n", | |
"\n", | |
"    Example: ``{'a': 1, 'b': 2}`` becomes ``'a=1,b=2'``.\n", | |
"    \"\"\"\n", | |
"    pairs = [f\"{key}={value}\" for key, value in dct.items()]\n", | |
"    return ','.join(pairs)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"id": "a8726d3c", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Collect, for every model size, the full config plus the override string and parameter count.\n", | |
"dcts = {}\n", | |
"for config_name in config_names:\n", | |
"    # Architecture settings passed to BertConfig for this size.\n", | |
"    kwargs = {\n", | |
"        \"hidden_size\": hidden_size[config_name],\n", | |
"        \"intermediate_size\": dff_ratio * hidden_size[config_name],\n", | |
"        \"num_attention_heads\": num_attention_heads[config_name],\n", | |
"        \"num_hidden_layers\": num_hidden_layers[config_name],\n", | |
"    }\n", | |
"    config = BertConfig(**kwargs)\n", | |
"    # The model is built from the config solely to obtain its parameter count.\n", | |
"    model = BertModel(config)\n", | |
"\n", | |
"    dcts[config_name] = {\n", | |
"        **config.to_dict(),\n", | |
"        \"config_overrides\": config_overrides_fn(kwargs),\n", | |
"        \"num_parameters\": model.num_parameters(),\n", | |
"    }" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"id": "22bc5156", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/html": [ | |
"<div>\n", | |
"<style scoped>\n", | |
" .dataframe tbody tr th:only-of-type {\n", | |
" vertical-align: middle;\n", | |
" }\n", | |
"\n", | |
" .dataframe tbody tr th {\n", | |
" vertical-align: top;\n", | |
" }\n", | |
"\n", | |
" .dataframe thead th {\n", | |
" text-align: right;\n", | |
" }\n", | |
"</style>\n", | |
"<table border=\"1\" class=\"dataframe\">\n", | |
" <thead>\n", | |
" <tr style=\"text-align: right;\">\n", | |
" <th></th>\n", | |
" <th>num_hidden_layers</th>\n", | |
" <th>num_attention_heads</th>\n", | |
" <th>intermediate_size</th>\n", | |
" <th>hidden_size</th>\n", | |
" <th>num_parameters</th>\n", | |
" </tr>\n", | |
" </thead>\n", | |
" <tbody>\n", | |
" <tr>\n", | |
" <th>tiny</th>\n", | |
" <td>4</td>\n", | |
" <td>4</td>\n", | |
" <td>1024</td>\n", | |
" <td>256</td>\n", | |
" <td>11170560</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>small</th>\n", | |
" <td>6</td>\n", | |
" <td>8</td>\n", | |
" <td>2048</td>\n", | |
" <td>512</td>\n", | |
" <td>35068416</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>base</th>\n", | |
" <td>12</td>\n", | |
" <td>12</td>\n", | |
" <td>3072</td>\n", | |
" <td>768</td>\n", | |
" <td>109482240</td>\n", | |
" </tr>\n", | |
" <tr>\n", | |
" <th>large</th>\n", | |
" <td>24</td>\n", | |
" <td>16</td>\n", | |
" <td>4096</td>\n", | |
" <td>1024</td>\n", | |
" <td>335141888</td>\n", | |
" </tr>\n", | |
" </tbody>\n", | |
"</table>\n", | |
"</div>" | |
], | |
"text/plain": [ | |
" num_hidden_layers num_attention_heads intermediate_size hidden_size \\\n", | |
"tiny 4 4 1024 256 \n", | |
"small 6 8 2048 512 \n", | |
"base 12 12 3072 768 \n", | |
"large 24 16 4096 1024 \n", | |
"\n", | |
" num_parameters \n", | |
"tiny 11170560 \n", | |
"small 35068416 \n", | |
"base 109482240 \n", | |
"large 335141888 " | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# One row per model size, one column per config entry.\n", | |
"frame = pd.DataFrame.from_dict(dcts, orient=\"index\")\n", | |
"# Show only the architecture settings and the resulting parameter count.\n", | |
"frame.loc[:, [\n", | |
"    \"num_hidden_layers\",\n", | |
"    \"num_attention_heads\",\n", | |
"    \"intermediate_size\",\n", | |
"    \"hidden_size\",\n", | |
"    \"num_parameters\",\n", | |
"]]" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.13" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment