Skip to content

Instantly share code, notes, and snippets.

@una-dinosauria
Created November 4, 2020 03:33
Show Gist options
  • Save una-dinosauria/e528b91de3ca9ab108cbf00aba3d9c2a to your computer and use it in GitHub Desktop.
Save una-dinosauria/e528b91de3ca9ab108cbf00aba3d9c2a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"resnet50_small_blocks.pth\n",
"bits: 42705152\n",
"Byts: 5338144.0\n",
" KB: 5213.03125\n",
" MB: 5.09\n",
"\n",
"resnet50_large_blocks.pth\n",
"bits: 26710272\n",
"Byts: 3338784.0\n",
" KB: 3260.53125\n",
" MB: 3.18\n",
"\n",
"resnet18_small_blocks.pth\n",
"bits: 12927232\n",
"Byts: 1615904.0\n",
" KB: 1578.03125\n",
" MB: 1.54\n",
"\n",
"resnet50_semisup_small_blocks.pth\n",
"bits: 43655424\n",
"Byts: 5456928.0\n",
" KB: 5329.03125\n",
" MB: 5.20\n",
"\n",
"resnet18_large_blocks.pth\n",
"bits: 8634624\n",
"Byts: 1079328.0\n",
" KB: 1054.03125\n",
" MB: 1.03\n",
"\n",
"mask_r_cnn.pth\n",
"bits: 55743008\n",
"Byts: 6967876.0\n",
" KB: 6804.56640625\n",
" MB: 6.65\n"
]
}
],
"source": [
"import glob\n",
"import os\n",
"from collections import OrderedDict\n",
"from typing import Dict\n",
"\n",
"import numpy as np\n",
"import torch\n",
"\n",
"\n",
"def force_float_type(tensor: torch.Tensor, half: bool) -> torch.Tensor:\n",
"    \"\"\"Return `tensor` cast to float16 when `half` is true, otherwise to float32.\"\"\"\n",
"    return tensor.half() if half else tensor.float()\n",
"\n",
"def get_bits_float(tensor: torch.Tensor) -> int:\n",
"    \"\"\"Compute the bits taken by a float tensor\n",
"\n",
"    Only float16 (16 bits/element) and float32 (32 bits/element) are supported.\n",
"\n",
"    Raises:\n",
"        ValueError: if the tensor has any other dtype.\n",
"    \"\"\"\n",
"    bits_per_element = {torch.float16: 16, torch.float32: 32}.get(tensor.dtype)\n",
"    if bits_per_element is None:\n",
"        raise ValueError(f\"Expected a float16 or float32 tensor, got {tensor.dtype}\")\n",
"    # numel() is a plain Python int (1 for 0-d tensors); np.prod(shape) would return a\n",
"    # numpy scalar, and the float 1.0 for an empty shape.\n",
"    return bits_per_element * tensor.numel()\n",
"\n",
"def get_bits_codes(codes: torch.Tensor, k: int) -> int:\n",
"    \"\"\"Computes bits taken by an integer tensor of codes, each one of `k` possible values\n",
"\n",
"    Every element is counted as log2(k) bits. This is exact when k is a power of two;\n",
"    otherwise the fractional total is truncated to an int.\n",
"\n",
"    Raises:\n",
"        ValueError: if k < 1, for which log2(k) is -inf or undefined.\n",
"    \"\"\"\n",
"    if k < 1:\n",
"        raise ValueError(f\"Number of possible code values must be at least 1, got {k}\")\n",
"    return int(codes.numel() * np.log2(k))\n",
"\n",
"def print_size_nicely(model: str, size_bits: int) -> None:\n",
"    \"\"\"Print the size of a model nicely to the eyes\n",
"\n",
"    Params:\n",
"        model: Label printed on the first line (here, the checkpoint file name)\n",
"        size_bits: Size in bits; also reported in bytes, KB and MB (using 1 KB = 1024 bytes)\n",
"    \"\"\"\n",
"    size_bytes = size_bits / 8\n",
"    print(f\"{model}\")\n",
"    print(f\"bits: {size_bits}\")\n",
"    print(f\"Byts: {size_bytes}\")\n",
"    print(f\" KB: {size_bytes / 1024}\")\n",
"    print(f\" MB: {size_bytes / 1024 / 1024:.2f}\")\n",
"\n",
"\n",
"def get_bgd_bits(model_dict: Dict, as_saved: bool = False, half_codebooks: bool = True, half_weights: bool = False) -> Dict:\n",
"    \"\"\"Returns a dictionary with the bits taken by every part of a BGD model\n",
"\n",
"    Params:\n",
"        model_dict: Model dictionary as returned by `torch.load()`. Each value is a dict that may hold\n",
"            'weight', 'bias' and/or 'centroids' (plus 'assignments') tensors; the special key \"biases\"\n",
"            instead maps names to bias tensors of compressed layers (mask-rcnn).\n",
"        as_saved: If true, nothing is casted and everything is counted as it is saved\n",
"        half_codebooks: If true, codebooks are counted as being in float16; else, they are counted as float32\n",
"        half_weights: If true, weights and biases are counted as being in float16; else, they are counted as float32\n",
"\n",
"    Returns:\n",
"        An OrderedDict mapping `<key>.weight`, `<key>.bias`, `<key>.codebook` and `<key>.codes_matrix`\n",
"        (and, for the \"biases\" entry, each bias name) to its size in bits.\n",
"\n",
"    Raises:\n",
"        ValueError: if an entry has no 'weight', 'bias' or 'centroids' and is not the \"biases\" entry.\n",
"    \"\"\"\n",
"    def maybe_cast(tensor: torch.Tensor, half: bool) -> torch.Tensor:\n",
"        # Keep the stored dtype when counting as saved; otherwise count in the requested precision.\n",
"        return tensor if as_saved else force_float_type(tensor, half)\n",
"\n",
"    bits_dict = OrderedDict()\n",
"    for (k, params) in model_dict.items():\n",
"        if 'weight' in params:\n",
"            bits_dict[k + \".weight\"] = get_bits_float(maybe_cast(params['weight'], half_weights))\n",
"\n",
"        if 'bias' in params:\n",
"            bits_dict[k + \".bias\"] = get_bits_float(maybe_cast(params['bias'], half_weights))\n",
"\n",
"        if 'centroids' in params:\n",
"            centroids = params['centroids']\n",
"            bits_dict[k + \".codebook\"] = get_bits_float(maybe_cast(centroids, half_codebooks))\n",
"\n",
"            # Count codes: log2(number of centroids) bits per element of the assignments tensor\n",
"            bits_dict[k + \".codes_matrix\"] = get_bits_codes(params['assignments'], centroids.shape[0])\n",
"\n",
"        if 'weight' not in params and 'bias' not in params and 'centroids' not in params:\n",
"            # For mask-rcnn, the biases of compressed layers are stored separately under the \"biases\" key\n",
"            if k != \"biases\":\n",
"                raise ValueError(f\"Unrecognized model entry {k!r}: no 'weight', 'bias' or 'centroids'\")\n",
"            for bk, bias in params.items():\n",
"                # Honor `as_saved` here too, exactly like the per-layer biases above.\n",
"                bits_dict[bk] = get_bits_float(maybe_cast(bias, half_weights))\n",
"\n",
"    return bits_dict\n",
"\n",
"bgd_paths = glob.glob(\"*.pth\")\n",
"for bgd_path in bgd_paths:\n",
"    # NOTE: torch.load unpickles the file -- only run this on trusted checkpoints.\n",
"    # map_location=\"cpu\" lets GPU-saved checkpoints load without CUDA; the bit counts\n",
"    # depend only on dtypes and shapes, so they are unaffected.\n",
"    bgd_model = torch.load(bgd_path, map_location=\"cpu\")\n",
"    # Count codebooks as float16 and weights/biases as float32, whatever dtype they were saved in\n",
"    bgd_bits_dict = get_bgd_bits(bgd_model, as_saved=False, half_codebooks=True, half_weights=False)\n",
"    bgd_bits = sum(bgd_bits_dict.values())\n",
"    print()\n",
"    print_size_nicely(bgd_path, bgd_bits)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"half_codebooks: False, half_weights: False\n",
"mask_r_cnn.pth\n",
"bits: 57442848\n",
"Byts: 7180356.0\n",
" KB: 7012.06640625\n",
" MB: 6.85\n",
"\n",
"half_codebooks: False, half_weights: True\n",
"mask_r_cnn.pth\n",
"bits: 55909136\n",
"Byts: 6988642.0\n",
" KB: 6824.845703125\n",
" MB: 6.66\n",
"\n",
"half_codebooks: True, half_weights: False\n",
"mask_r_cnn.pth\n",
"bits: 55743008\n",
"Byts: 6967876.0\n",
" KB: 6804.56640625\n",
" MB: 6.65\n",
"\n",
"half_codebooks: True, half_weights: True\n",
"mask_r_cnn.pth\n",
"bits: 54209296\n",
"Byts: 6776162.0\n",
" KB: 6617.345703125\n",
" MB: 6.46\n"
]
}
],
"source": [
"mask_rcnn_path = \"mask_r_cnn.pth\"\n",
"# NOTE: torch.load unpickles the file -- only load trusted checkpoints.\n",
"bgd_model = torch.load(mask_rcnn_path, map_location=\"cpu\")\n",
"\n",
"# Report the model size for every combination of float16/float32 codebooks and weights\n",
"for hc in [False, True]:\n",
"    for hw in [False, True]:\n",
"        bgd_bits_dict = get_bgd_bits(bgd_model, as_saved=False, half_codebooks=hc, half_weights=hw)\n",
"        bgd_bits = sum(bgd_bits_dict.values())\n",
"        print()\n",
"        print(f\"half_codebooks: {hc}, half_weights: {hw}\")\n",
"        # Label with the file loaded in this cell, not the `bgd_path` loop variable\n",
"        # left over from the previous cell (whose value depends on glob order).\n",
"        print_size_nicely(mask_rcnn_path, bgd_bits)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment