Skip to content

Instantly share code, notes, and snippets.

@sayakpaul
Last active January 22, 2024 04:24
Show Gist options
  • Save sayakpaul/9ec9eb1e915a587184e3029c116f803b to your computer and use it in GitHub Desktop.
Save sayakpaul/9ec9eb1e915a587184e3029c116f803b to your computer and use it in GitHub Desktop.
PoC of a utility that gets the total parameter count of a diffusers pipeline without hefty downloads.
from huggingface_hub import (
hf_hub_download,
list_files_info,
parse_safetensors_file_metadata,
)
import json
import argparse
def get_files_info(args):
    """List the files of the Hub repository named by ``args.repo_id``.

    Args:
        args: Object exposing a ``repo_id`` attribute (e.g. the parsed CLI
            namespace).

    Returns:
        A list with one entry per item reported by ``list_files_info``
        (called with ``expand=True``).
    """
    # NOTE: huggingface_hub emits a FutureWarning that `list_files_info` is
    # deprecated (removal announced for 0.23) in favour of `list_repo_tree`
    # and `get_paths_info`.
    # The result is materialized so callers get a plain, re-iterable list.
    return list(list_files_info(args.repo_id, expand=True))
def get_model_index(args):
    """Download and parse ``model_index.json`` of the repo ``args.repo_id``.

    Args:
        args: Object exposing a ``repo_id`` attribute (e.g. the parsed CLI
            namespace).

    Returns:
        The decoded JSON content of the repository's ``model_index.json``.
    """
    model_index_path = hf_hub_download(
        repo_id=args.repo_id, filename="model_index.json"
    )
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale encoding (which differs e.g. on Windows).
    with open(model_index_path, encoding="utf-8") as f:
        return json.load(f)
def _is_component_weight_file(path, components):
    """Return True if *path* is a non-variant safetensors file inside a component folder."""
    folder, _, remainder = path.partition("/")
    # Root-level files (no folder part) and folders that are not pipeline
    # components are never counted.
    if not remainder or folder not in components:
        return False
    filename = path.rsplit("/", 1)[-1]
    # Exactly one dot keeps "<name>.safetensors" and rejects names with extra
    # dotted parts such as "<name>.fp16.safetensors" or
    # "<name>.safetensors.index.json".
    return filename.endswith(".safetensors") and filename.count(".") == 1


def _count_float_params(parameter_count):
    """Sum the counts of all floating-point dtypes in a dtype-name -> count mapping."""
    # Dtype names starting with "F" / "BF" (F32, F16, BF16, ...) are counted.
    # Integer/bool entries are skipped, as the original F32-only lookup did.
    # NOTE(review): presumably those are non-trainable buffers; confirm.
    return sum(
        count
        for dtype, count in parameter_count.items()
        if dtype.startswith(("F", "BF"))
    )


def get_params(args):
    """Count the parameters of the pipeline stored in the repo ``args.repo_id``.

    The component names are the keys of ``model_index.json`` whose value is a
    list. For every safetensors file located in one of those component
    folders, the per-dtype parameter counts reported by
    ``parse_safetensors_file_metadata`` are summed.

    Args:
        args: Object exposing a ``repo_id`` attribute (e.g. the parsed CLI
            namespace).

    Returns:
        The total number of floating-point parameters as an ``int``; ``0``
        when no matching file is found.
    """
    model_index = get_model_index(args)
    # A set gives exact, O(1) folder-name matching instead of substring tests.
    components = {
        name for name, spec in model_index.items() if isinstance(spec, list)
    }

    total_params = 0
    for file in get_files_info(args):
        if not _is_component_weight_file(file.path, components):
            continue
        metadata = parse_safetensors_file_metadata(
            repo_id=args.repo_id, filename=file.path
        )
        # Do not index ["F32"] directly: half-precision checkpoints have no
        # such key and would raise KeyError.
        total_params += _count_float_params(metadata.parameter_count)
    return total_params
if __name__ == "__main__":
    # CLI entry point: read the repo id, count the pipeline's parameters and
    # report the total in billions.
    parser = argparse.ArgumentParser(
        description=(
            "Count the parameters of a diffusers pipeline hosted on the "
            "Hugging Face Hub."
        )
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        help="Repo id to count the params for.",
    )
    args = parser.parse_args()

    params = get_params(args)
    print(f"{args.repo_id} has {params / 1e9} billion parameters.")
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "eb50220e-ea67-45d4-be7b-01f824118dcd",
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import hf_hub_download, list_files_info\n",
"import safetensors\n",
"import json"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f59a730d-361c-4a85-a49b-b989a3198343",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/sayakpaul/miniconda3/envs/diffusers/lib/python3.9/site-packages/huggingface_hub/utils/_deprecation.py:131: FutureWarning: 'list_files_info' (from 'huggingface_hub.hf_api') is deprecated and will be removed from version '0.23'. Use `list_repo_tree` and `get_paths_info` instead.\n",
" warnings.warn(warning_message, FutureWarning)\n"
]
},
{
"data": {
"text/plain": [
"[RepoFile(path='.gitattributes', size=1548, blob_id='55d2855c5be698e0572b9f42af95f06bfd5fb002', lfs=None, last_commit={'oid': 'd8c8d262a38ad3da5e673f048db73a504f2a92a9', 'title': 'Upload v1-5-pruned.ckpt', 'date': datetime.datetime(2022, 10, 20, 12, 18, 52, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None}),\n",
" RepoFile(path='README.md', size=14462, blob_id='103bad8a83037abcdb9d24f2b21eba8792b33a5d', lfs=None, last_commit={'oid': '889b629140e71758e1e0006e355c331a5744b4bf', 'title': 'Update README.md', 'date': datetime.datetime(2022, 12, 19, 15, 29, 28, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None}),\n",
" RepoFile(path='feature_extractor/preprocessor_config.json', size=342, blob_id='5294955ff7801083f720b34b55d0f1f51313c5c5', lfs=None, last_commit={'oid': '7621c1d34cb8951ae5277a5d1c07c431e46abb48', 'title': 'add diffusers weights', 'date': datetime.datetime(2022, 10, 20, 9, 30, 42, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None}),\n",
" RepoFile(path='model_index.json', size=541, blob_id='daf7e2e2dfc64fb437a2b44525667111b00cb9fc', lfs=None, last_commit={'oid': 'aa9ba505e1973ae5cd05f5aedd345178f52f8e6a', 'title': 'Fix deprecation warning by changing `CLIPFeatureExtractor` to `CLIPImageProcessor`. (#124)', 'date': datetime.datetime(2023, 5, 5, 10, 38, 18, tzinfo=datetime.timezone.utc)}, security={'safe': True, 'av_scan': {'virusFound': False, 'virusNames': None}, 'pickle_import_scan': None})]"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"repo_id = \"runwayml/stable-diffusion-v1-5\"\n",
"\n",
"files_info = list_files_info(repo_id, expand=True)\n",
"files_info = list(files_info)\n",
"files_info[:4]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "49f21f4a-d9f5-46f5-ad90-89e18cc4eb50",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['feature_extractor',\n",
" 'safety_checker',\n",
" 'scheduler',\n",
" 'text_encoder',\n",
" 'tokenizer',\n",
" 'unet',\n",
" 'vae']"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model_index_path = hf_hub_download(repo_id=repo_id, filename=\"model_index.json\")\n",
"\n",
"with open(model_index_path) as f:\n",
" model_index = json.load(f)\n",
"\n",
"components = [k for k in model_index.keys() if isinstance(model_index[k], list)]\n",
"components"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0b3f9f8c-50ef-4099-a7a8-e2db07961a03",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total params: 1.370216895 billion.\n"
]
}
],
"source": [
"from huggingface_hub import parse_safetensors_file_metadata\n",
"total_params = 0\n",
"\n",
"for file in files_info:\n",
" if (\n",
" len(file.path.split(\"/\")[-1].split(\".\")) == 2\n",
" and \"safetensors\" in file.path\n",
" and any(k in file.path for k in components)\n",
" ):\n",
" total_params += parse_safetensors_file_metadata(\n",
" repo_id=repo_id, filename=file.path\n",
" ).parameter_count[\"F32\"]\n",
"\n",
"print(f\"Total params: {total_params / 1e9} billion.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
@sayakpaul
Copy link
Author

Some results:

runwayml/stable-diffusion-v1-5

python get_params_diffusers_diffusers.py

runwayml/stable-diffusion-v1-5 has 1.370216895 billion parameters.

stabilityai/stable-diffusion-xl-base-1.0

python get_params_diffusers_diffusers.py --repo_id=stabilityai/stable-diffusion-xl-base-1.0

stabilityai/stable-diffusion-xl-base-1.0 has 3.55249173 billion parameters.

stabilityai/stable-video-diffusion-img2vid-xt

python get_params_diffusers_diffusers.py --repo_id=stabilityai/stable-video-diffusion-img2vid-xt

stabilityai/stable-video-diffusion-img2vid-xt has 2.254442729 billion parameters.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment