
@sayakpaul
Last active January 22, 2024 04:24
PoC of a utility to get the total parameter count of a diffusers pipeline without hefty downloads.
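Why no hefty downloads are needed: a `.safetensors` file starts with an 8-byte little-endian header length, followed by a JSON header that records each tensor's `dtype` and `shape`. Parameter counts can therefore be computed from the header alone, and `parse_safetensors_file_metadata` only needs to read that header from the Hub rather than the weights. The offline sketch below illustrates the idea on an in-memory byte string. `make_fake_safetensors` and `count_params_from_header` are hypothetical helpers written for this illustration, not `huggingface_hub` APIs. The fake file also omits the header padding that real files may contain.

```python
import json
import math
import struct


def count_params_from_header(raw: bytes) -> dict:
    """Return {dtype: parameter count}, reading only the safetensors header.

    `raw` needs just the first 8 + N bytes of a file, where N is the header
    length stored as a little-endian u64 in the first 8 bytes.
    """
    (header_len,) = struct.unpack("<Q", raw[:8])
    header = json.loads(raw[8 : 8 + header_len].decode("utf-8"))
    counts = {}
    for name, info in header.items():
        if name == "__metadata__":  # optional string metadata, not a tensor
            continue
        # Element count of a tensor is the product of its shape (1 for scalars).
        counts[info["dtype"]] = counts.get(info["dtype"], 0) + math.prod(info["shape"])
    return counts


def make_fake_safetensors(tensors: dict) -> bytes:
    """Hypothetical helper: a tiny safetensors-style file with zero-filled data."""
    itemsize = {"F32": 4, "F16": 2, "BF16": 2}
    header = {"__metadata__": {"format": "pt"}}
    offset = 0
    for name, (dtype, shape) in tensors.items():
        nbytes = math.prod(shape) * itemsize[dtype]
        header[name] = {"dtype": dtype, "shape": shape, "data_offsets": [offset, offset + nbytes]}
        offset += nbytes
    header_bytes = json.dumps(header).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(offset)


blob = make_fake_safetensors({
    "linear.weight": ("F32", [4, 3]),
    "linear.bias": ("F32", [4]),
    "norm.weight": ("F16", [4]),
})
print(count_params_from_header(blob))  # {'F32': 16, 'F16': 4}
```

Note that the counts come back per dtype, which is why a total has to sum over all dtypes rather than read a single key such as `"F32"`.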
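from huggingface_hub import (
    hf_hub_download,
    list_repo_files,
    parse_safetensors_file_metadata,
)
import json
import argparse


def get_files_info(args):
    # Paths of every file in the repo (listing only; nothing is downloaded).
    # list_repo_files replaces the deprecated list_files_info.
    return list_repo_files(args.repo_id)


def get_model_index(args):
    # model_index.json is tiny and lists the pipeline's components.
    model_index_path = hf_hub_download(
        repo_id=args.repo_id, filename="model_index.json"
    )
    with open(model_index_path) as f:
        model_index = json.load(f)
    return model_index


def get_params(args):
    model_index = get_model_index(args)
    # Components are the entries whose value is a [library, class_name] list.
    components = [k for k in model_index.keys() if isinstance(model_index[k], list)]
    repo_files = get_files_info(args)

    total_params = 0
    for path in repo_files:
        filename = path.split("/")[-1]
        if (
            # Skip variants such as *.fp16.safetensors so weights aren't counted twice.
            len(filename.split(".")) == 2
            and filename.endswith(".safetensors")
            # Only count files inside a pipeline component folder (unet/, vae/, ...).
            and path.split("/")[0] in components
        ):
            # Only the safetensors header is fetched, not the weights.
            metadata = parse_safetensors_file_metadata(
                repo_id=args.repo_id, filename=path
            )
            # Sum over all dtypes (F32, F16, ...) instead of assuming F32 only.
            total_params += sum(metadata.parameter_count.values())
    return total_params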
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Count the parameters of a diffusers pipeline on the Hub."
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        help="Repo id to count the params for.",
    )
    args = parser.parse_args()
    params = get_params(args)
    print(f"{args.repo_id} has {params / 1e9} billion parameters.")
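The file-selection rule in `get_params` (keep single-suffix `.safetensors` files that belong to a pipeline component, which skips variants such as `.fp16.safetensors`) can be checked offline without touching the Hub. The sketch below isolates that kind of rule as two pure functions. `pipeline_components` and `is_counted` are hypothetical names, the trimmed `model_index` dict and the file paths are illustrative, and matching on the top-level folder name (rather than on a substring of the path) is an assumption of this sketch.

```python
def pipeline_components(model_index: dict) -> list:
    # Sub-models in model_index.json are stored as [library, class_name] lists;
    # scalar entries such as "_class_name" and "_diffusers_version" are skipped.
    return [k for k, v in model_index.items() if isinstance(v, list)]


def is_counted(path: str, components: list) -> bool:
    folder, _, filename = path.rpartition("/")
    return (
        filename.endswith(".safetensors")
        # A single suffix skips variants such as *.fp16.safetensors (no double counting).
        and len(filename.split(".")) == 2
        # The top-level folder must be a pipeline component (unet/, vae/, ...).
        and folder.split("/")[0] in components
    )


index = {
    "_class_name": "StableDiffusionPipeline",
    "unet": ["diffusers", "UNet2DConditionModel"],
    "vae": ["diffusers", "AutoencoderKL"],
}
paths = [
    "unet/diffusion_pytorch_model.safetensors",
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "vae/diffusion_pytorch_model.safetensors",
    "v1-5-pruned.safetensors",
]
comps = pipeline_components(index)
print([p for p in paths if is_counted(p, comps)])
# ['unet/diffusion_pytorch_model.safetensors', 'vae/diffusion_pytorch_model.safetensors']
```

Sharded checkpoints such as `unet/diffusion_pytorch_model-00001-of-00002.safetensors` still have a single suffix, so every shard is counted, while the accompanying `.safetensors.index.json` file is not.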
@sayakpaul
Author

Some results:

runwayml/stable-diffusion-v1-5

python get_params_diffusers_diffusers.py

runwayml/stable-diffusion-v1-5 has 1.370216895 billion parameters.

stabilityai/stable-diffusion-xl-base-1.0

python get_params_diffusers_diffusers.py --repo_id=stabilityai/stable-diffusion-xl-base-1.0

stabilityai/stable-diffusion-xl-base-1.0 has 3.55249173 billion parameters.

stabilityai/stable-video-diffusion-img2vid-xt

python get_params_diffusers_diffusers.py --repo_id=stabilityai/stable-video-diffusion-img2vid-xt

stabilityai/stable-video-diffusion-img2vid-xt has 2.254442729 billion parameters.
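Each result above is a single total. To see how a total splits across components (UNet, text encoders, VAE, ...), one option is to keep the per-file `parameter_count` dicts instead of summing them straight away, then group them by top-level folder. A minimal sketch, assuming a `{path: parameter_count}` mapping has already been collected (for example from `parse_safetensors_file_metadata(...).parameter_count` for each counted file); `breakdown_by_component` is a hypothetical helper and the numbers below are made up.

```python
from collections import defaultdict


def breakdown_by_component(file_counts: dict) -> dict:
    """Collapse {path: {dtype: count}} into {top-level folder: total params}."""
    totals = defaultdict(int)
    for path, per_dtype in file_counts.items():
        totals[path.split("/")[0]] += sum(per_dtype.values())
    return dict(totals)


# Made-up counts, only to show the shape of the input and output.
example = {
    "unet/diffusion_pytorch_model-00001-of-00002.safetensors": {"F32": 600},
    "unet/diffusion_pytorch_model-00002-of-00002.safetensors": {"F32": 400},
    "vae/diffusion_pytorch_model.safetensors": {"F32": 80, "F16": 3},
}
print(breakdown_by_component(example))  # {'unet': 1000, 'vae': 83}
```

Summing the returned values gives back the single pipeline total that the script prints.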
