Skip to content

Instantly share code, notes, and snippets.

@ltiao
Created July 28, 2022 14:41
Show Gist options
  • Save ltiao/d0a07a82d2ac25f800fbc081f67b40e6 to your computer and use it in GitHub Desktop.
Save ltiao/d0a07a82d2ac25f800fbc081f67b40e6 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "3204663b",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"from transformers import BertModel, BertConfig"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "cf11931b",
"metadata": {},
"outputs": [],
"source": [
"# Model sizes to compare; the tables in the following cells are keyed by these names.\n",
"config_names = [\"tiny\", \"small\", \"base\", \"large\"]\n",
"\n",
"# Feed-forward (intermediate) width as a multiple of the hidden size.\n",
"dff_ratio = 4\n",
"# Hidden units per attention head (hidden_size / num_attention_heads for every size below).\n",
"n_head_ratio = 64"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a6c5612d",
"metadata": {},
"outputs": [],
"source": [
"# Number of transformer encoder layers for each model size.\n",
"num_hidden_layers = {\n",
"    \"tiny\": 4,\n",
"    \"small\": 6,\n",
"    \"base\": 12,\n",
"    \"large\": 24,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e06d7541",
"metadata": {},
"outputs": [],
"source": [
"# Hidden (embedding) width for each model size.\n",
"hidden_size = {\n",
"    \"tiny\": 256,\n",
"    \"small\": 512,\n",
"    \"base\": 768,\n",
"    \"large\": 1024,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "29a4c8eb",
"metadata": {},
"outputs": [],
"source": [
"# Derive the head count from the hidden size instead of repeating it by hand:\n",
"# one head per n_head_ratio hidden units (256 -> 4, 512 -> 8, 768 -> 12, 1024 -> 16).\n",
"assert all(size % n_head_ratio == 0 for size in hidden_size.values()), (\n",
"    \"every hidden size must be a multiple of n_head_ratio\"\n",
")\n",
"num_attention_heads = {\n",
"    name: size // n_head_ratio for name, size in hidden_size.items()\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f8d0370a",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"dict_keys(['tiny', 'small', 'base', 'large'])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# The build loop indexes every table by each name in config_names, so fail\n",
"# early (with a clear message) if a table is missing a size or has a stray one.\n",
"assert all(\n",
"    set(table) == set(config_names)\n",
"    for table in (num_hidden_layers, hidden_size, num_attention_heads)\n",
"), \"per-size tables must cover exactly the names in config_names\"\n",
"\n",
"hidden_size.keys()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6b53fbae",
"metadata": {},
"outputs": [],
"source": [
"def config_overrides_fn(dct):\n",
"    \"\"\"Render a mapping as a comma-separated string of ``key=value`` pairs.\n",
"\n",
"    Pairs appear in the mapping's iteration order, and keys and values are\n",
"    formatted as in an f-string. An empty mapping yields an empty string.\n",
"    \"\"\"\n",
"    pairs = [f\"{key}={value}\" for key, value in dct.items()]\n",
"    return ','.join(pairs)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "a8726d3c",
"metadata": {},
"outputs": [],
"source": [
"# Build one record per model size: the full BertConfig plus two derived fields.\n",
"dcts = {}\n",
"for config_name in config_names:\n",
"    width = hidden_size[config_name]\n",
"    kwargs = {\n",
"        \"hidden_size\": width,\n",
"        \"intermediate_size\": dff_ratio * width,\n",
"        \"num_attention_heads\": num_attention_heads[config_name],\n",
"        \"num_hidden_layers\": num_hidden_layers[config_name],\n",
"    }\n",
"    config = BertConfig(**kwargs)\n",
"    # The model is instantiated so that its parameters can be counted.\n",
"    model = BertModel(config)\n",
"\n",
"    # Config fields first, then the override string and the parameter count.\n",
"    dcts[config_name] = {\n",
"        **config.to_dict(),\n",
"        \"config_overrides\": config_overrides_fn(kwargs),\n",
"        \"num_parameters\": model.num_parameters(),\n",
"    }"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "22bc5156",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>num_hidden_layers</th>\n",
" <th>num_attention_heads</th>\n",
" <th>intermediate_size</th>\n",
" <th>hidden_size</th>\n",
" <th>num_parameters</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>tiny</th>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>1024</td>\n",
" <td>256</td>\n",
" <td>11170560</td>\n",
" </tr>\n",
" <tr>\n",
" <th>small</th>\n",
" <td>6</td>\n",
" <td>8</td>\n",
" <td>2048</td>\n",
" <td>512</td>\n",
" <td>35068416</td>\n",
" </tr>\n",
" <tr>\n",
" <th>base</th>\n",
" <td>12</td>\n",
" <td>12</td>\n",
" <td>3072</td>\n",
" <td>768</td>\n",
" <td>109482240</td>\n",
" </tr>\n",
" <tr>\n",
" <th>large</th>\n",
" <td>24</td>\n",
" <td>16</td>\n",
" <td>4096</td>\n",
" <td>1024</td>\n",
" <td>335141888</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" num_hidden_layers num_attention_heads intermediate_size hidden_size \\\n",
"tiny 4 4 1024 256 \n",
"small 6 8 2048 512 \n",
"base 12 12 3072 768 \n",
"large 24 16 4096 1024 \n",
"\n",
" num_parameters \n",
"tiny 11170560 \n",
"small 35068416 \n",
"base 109482240 \n",
"large 335141888 "
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Columns shown in the size-comparison table, in display order.\n",
"summary_columns = [\n",
"    \"num_hidden_layers\",\n",
"    \"num_attention_heads\",\n",
"    \"intermediate_size\",\n",
"    \"hidden_size\",\n",
"    \"num_parameters\",\n",
"]\n",
"\n",
"# One row per model size, one column per config field.\n",
"frame = pd.DataFrame.from_dict(dcts, orient=\"index\")\n",
"frame[summary_columns]"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment