Last active
January 22, 2024 04:24
-
-
Save sayakpaul/9ec9eb1e915a587184e3029c116f803b to your computer and use it in GitHub Desktop.
PoC of a utility to get the total parameter count of a diffusers pipeline without hefty downloads.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from huggingface_hub import ( | |
hf_hub_download, | |
list_files_info, | |
parse_safetensors_file_metadata, | |
) | |
import json | |
import argparse | |
def get_files_info(args):
    """Return the file listing of the Hub repository ``args.repo_id`` as a list.

    Args:
        args: Object exposing a ``repo_id`` attribute (the parsed CLI namespace).

    Returns:
        A list with one entry per item yielded by ``list_files_info``.

    Note:
        ``huggingface_hub`` emits a ``FutureWarning`` for ``list_files_info``:
        it is deprecated and announced for removal in version 0.23, with
        ``list_repo_tree`` / ``get_paths_info`` named as replacements.
    """
    # Wrapped in ``list`` so callers get a concrete, re-iterable sequence.
    return list(list_files_info(args.repo_id, expand=True))
def get_model_index(args):
    """Download and parse ``model_index.json`` from the repo ``args.repo_id``.

    Args:
        args: Object exposing a ``repo_id`` attribute (the parsed CLI namespace).

    Returns:
        The decoded JSON content of ``model_index.json``.
    """
    model_index_path = hf_hub_download(
        repo_id=args.repo_id, filename="model_index.json"
    )
    # Decode explicitly as UTF-8 instead of relying on the platform's locale
    # encoding, which is what ``open`` falls back to when none is given.
    with open(model_index_path, encoding="utf-8") as f:
        return json.load(f)
def get_params(args):
    """Count the parameters of a pipeline from its safetensors metadata.

    The pipeline components are taken to be the ``model_index.json`` entries
    whose value is a list. For each component, the non-variant ``.safetensors``
    files stored under the folder of the same name are inspected through
    ``parse_safetensors_file_metadata`` and their per-dtype parameter counts
    are added up.

    Args:
        args: Object exposing a ``repo_id`` attribute (the parsed CLI namespace).

    Returns:
        int: Total number of floating-point parameters across all matching
        weight files; ``0`` when no file matches.
    """
    model_index = get_model_index(args)
    # A set gives exact, O(1) folder-name matching below.
    components = {
        name for name, spec in model_index.items() if isinstance(spec, list)
    }

    total_params = 0
    for file in get_files_info(args):
        path_parts = file.path.split("/")
        file_name = path_parts[-1]

        # Match on the top-level folder rather than on a substring of the
        # whole path, so that e.g. "vae" does not also select "vae_backup/...".
        if path_parts[0] not in components:
            continue
        # Exactly one dot keeps "model.safetensors" but skips variant files
        # such as "model.fp16.safetensors" and index files such as
        # "model.safetensors.index.json".
        if file_name.count(".") != 1 or not file_name.endswith(".safetensors"):
            continue

        parameter_count = parse_safetensors_file_metadata(
            repo_id=args.repo_id, filename=file.path
        ).parameter_count
        # Sum every floating-point dtype (F32, F16, BF16, F64, F8_*) instead of
        # indexing "F32" only, which raised KeyError for half-precision
        # checkpoints. Integer/bool tensors are left out so that the result for
        # an F32 checkpoint is unchanged.
        total_params += sum(
            count
            for dtype, count in parameter_count.items()
            if dtype.startswith(("F", "BF"))
        )
    return total_params
if __name__ == "__main__":
    # Command-line entry point: parse the target repo id, count its
    # parameters and print the total in billions.
    parser = argparse.ArgumentParser(
        description=(
            "Count the total parameters of a diffusers pipeline hosted on the "
            "Hugging Face Hub without downloading its weights."
        )
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        help="Repo id to count the params for.",
    )
    args = parser.parse_args()
    params = get_params(args)
    print(f"{args.repo_id} has {params / 1e9} billion parameters.")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"id": "eb50220e-ea67-45d4-be7b-01f824118dcd", | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from huggingface_hub import hf_hub_download, list_files_info\n", | |
"import safetensors\n", | |
"import json" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"id": "f59a730d-361c-4a85-a49b-b989a3198343", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/Users/sayakpaul/miniconda3/envs/diffusers/lib/python3.9/site-packages/huggingface_hub/utils/_deprecation.py:131: FutureWarning: 'list_files_info' (from 'huggingface_hub.hf_api') is deprecated and will be removed from version '0.23'. Use `list_repo_tree` and `get_paths_info` instead.\n", | |
" warnings.warn(warning_message, FutureWarning)\n" | |
] | |
}, | |
{ | |
"data": { | |
"text/plain": [ | |
"[RepoFile(path='.gitattributes', size=1548, blob_id='55d2855c5be698e0572b9f42af95f06bfd5fb002', lfs=None, last_commit={'oid': 'd8c8d262a38ad3da5e673f048db73a504f2a92a9', 'title': 'Upload v1-5-pruned.ckpt', 'date': datetime.datetime(2022, 10, 20, 12, 18, 52, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None}),\n", | |
" RepoFile(path='README.md', size=14462, blob_id='103bad8a83037abcdb9d24f2b21eba8792b33a5d', lfs=None, last_commit={'oid': '889b629140e71758e1e0006e355c331a5744b4bf', 'title': 'Update README.md', 'date': datetime.datetime(2022, 12, 19, 15, 29, 28, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None}),\n", | |
" RepoFile(path='feature_extractor/preprocessor_config.json', size=342, blob_id='5294955ff7801083f720b34b55d0f1f51313c5c5', lfs=None, last_commit={'oid': '7621c1d34cb8951ae5277a5d1c07c431e46abb48', 'title': 'add diffusers weights', 'date': datetime.datetime(2022, 10, 20, 9, 30, 42, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None}),\n", | |
" RepoFile(path='model_index.json', size=541, blob_id='daf7e2e2dfc64fb437a2b44525667111b00cb9fc', lfs=None, last_commit={'oid': 'aa9ba505e1973ae5cd05f5aedd345178f52f8e6a', 'title': 'Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`. (#124)', 'date': datetime.datetime(2023, 5, 5, 10, 38, 18, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None})]" | |
] | |
}, | |
"execution_count": 2, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"repo_id = \"runwayml/stable-diffusion-v1-5\"\n", | |
"\n", | |
"files_info = list_files_info(repo_id, expand=True)\n", | |
"files_info = list(files_info)\n", | |
"files_info[:4]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"id": "49f21f4a-d9f5-46f5-ad90-89e18cc4eb50", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['feature_extractor',\n", | |
" 'safety_checker',\n", | |
" 'scheduler',\n", | |
" 'text_encoder',\n", | |
" 'tokenizer',\n", | |
" 'unet',\n", | |
" 'vae']" | |
] | |
}, | |
"execution_count": 3, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"model_index_path = hf_hub_download(repo_id=repo_id, filename=\"model_index.json\")\n", | |
"\n", | |
"with open(model_index_path) as f:\n", | |
" model_index = json.load(f)\n", | |
"\n", | |
"components = [k for k in model_index.keys() if isinstance(model_index[k], list)]\n", | |
"components" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"id": "0b3f9f8c-50ef-4099-a7a8-e2db07961a03", | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"Total params: 1.370216895 billion.\n" | |
] | |
} | |
], | |
"source": [ | |
"from huggingface_hub import parse_safetensors_file_metadata\n", | |
"total_params = 0\n", | |
"\n", | |
"for file in files_info:\n", | |
" if (\n", | |
" len(file.path.split(\"/\")[-1].split(\".\")) == 2\n", | |
" and \"safetensors\" in file.path\n", | |
" and any(k in file.path for k in components)\n", | |
" ):\n", | |
" total_params += parse_safetensors_file_metadata(\n", | |
" repo_id=repo_id, filename=file.path\n", | |
" ).parameter_count[\"F32\"]\n", | |
"\n", | |
"print(f\"Total params: {total_params / 1e9} billion.\")" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3 (ipykernel)", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.9.17" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 5 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some results:
runwayml/stable-diffusion-v1-5
stabilityai/stable-diffusion-xl-base-1.0
stabilityai/stable-video-diffusion-img2vid-xt