Last active
January 22, 2024 04:24
-
-
Save sayakpaul/9ec9eb1e915a587184e3029c116f803b to your computer and use it in GitHub Desktop.
Proof of concept: a utility to get the total parameter count of a diffusers pipeline without hefty downloads.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from huggingface_hub import ( | |
hf_hub_download, | |
list_files_info, | |
parse_safetensors_file_metadata, | |
) | |
import json | |
import argparse | |
def get_files_info(args):
    """Return file-info entries for every file in the Hub repo ``args.repo_id``.

    Callers only read each entry's ``path``. The extended listing
    (``expand=True``: last commit, security scan) is therefore not requested:
    it costs more server-side and is paginated in smaller pages.

    NOTE(review): newer huggingface_hub releases deprecate ``list_files_info``
    in favour of ``list_repo_tree(..., recursive=True)``; confirm against the
    pinned library version.

    Args:
        args: Parsed CLI namespace; only ``args.repo_id`` is read.

    Returns:
        A list of the entries yielded by ``list_files_info``, materialised so
        the caller can iterate it without further network access.
    """
    return list(list_files_info(args.repo_id))
def get_model_index(args):
    """Download and parse the pipeline's ``model_index.json``.

    This is the only file from the repo that is actually downloaded
    (``hf_hub_download`` stores it in the local Hub cache).

    Args:
        args: Parsed CLI namespace; only ``args.repo_id`` is read.

    Returns:
        The decoded JSON object from ``model_index.json``.
    """
    model_index_path = hf_hub_download(
        repo_id=args.repo_id, filename="model_index.json"
    )
    # JSON is UTF-8 by specification; don't depend on the platform's locale
    # default encoding (e.g. cp1252 on Windows).
    with open(model_index_path, encoding="utf-8") as f:
        return json.load(f)
def get_params(args):
    """Return the total parameter count of a diffusers pipeline on the Hub.

    Only ``model_index.json`` is downloaded. Each weight file is measured via
    ``parse_safetensors_file_metadata``, which reads the safetensors header
    instead of the weights themselves.

    A repo file is counted only when both conditions hold:

    * Its top-level folder is exactly a pipeline component name, i.e. a
      ``model_index.json`` key whose value is a list. An exact match is
      required because a substring test would also accept look-alike folders
      (e.g. ``vae_1_0/`` matches ``vae``) and count that model twice.
      Root-level files are never counted.
    * Its file name is ``<stem>.safetensors`` with no other dot. This skips
      variant copies such as ``*.fp16.safetensors`` but keeps shards such as
      ``*-00001-of-00002.safetensors``.

    Components that ship only non-safetensors weights (e.g. ``.bin``) are
    not counted.

    Args:
        args: Parsed CLI namespace; only ``args.repo_id`` is read here.

    Returns:
        The number of tensor elements in the matched files, summed over every
        dtype reported in each file's ``parameter_count``.
    """
    model_index = get_model_index(args)
    # Only list-valued entries are components; scalar entries such as
    # ``_class_name`` are pipeline metadata. A set gives O(1) lookups below.
    components = {
        name for name, spec in model_index.items() if isinstance(spec, list)
    }

    total_params = 0
    for file in get_files_info(args):
        parts = file.path.split("/")
        # Require "<component>/..." so root-level checkpoints are skipped.
        if len(parts) < 2 or parts[0] not in components:
            continue
        filename = parts[-1]
        if not filename.endswith(".safetensors") or filename.count(".") != 1:
            continue
        metadata = parse_safetensors_file_metadata(
            repo_id=args.repo_id, filename=file.path
        )
        # parameter_count maps dtype -> element count. Sum every dtype rather
        # than only "F32", so weights stored as F16/BF16 are counted instead
        # of raising KeyError.
        total_params += sum(metadata.parameter_count.values())
    return total_params
if __name__ == "__main__":
    # CLI entry point: count the parameters of one Hub pipeline and print the
    # total in billions.
    parser = argparse.ArgumentParser(
        description=(
            "Count the total parameters of a diffusers pipeline on the "
            "Hugging Face Hub without downloading its weights."
        )
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="runwayml/stable-diffusion-v1-5",
        help="Repo id to count the params for.",
    )
    args = parser.parse_args()
    params = get_params(args)
    # 1e9 converts the raw count to billions.
    print(f"{args.repo_id} has {params / 1e9} billion parameters.")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Some results:
runwayml/stable-diffusion-v1-5
stabilityai/stable-diffusion-xl-base-1.0
stabilityai/stable-video-diffusion-img2vid-xt