Created
October 20, 2023 16:59
-
-
Save mikaylagawarecki/ef0cf6e1c2effd5ffc03eaf4461f9692 to your computer and use it in GitHub Desktop.
Breakdown for dispatcher registrations
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from collections import Counter

import torch
from torchgen.model import DispatchKey
# === Collect number of kernel registrations for each dispatch key ===
# num_impl_registrations: total kernel (impl) registrations over all keys.
# registration_dict: DispatchKey -> number of kernels registered for that key.
# failed_keys: names of the keys for which the dispatcher query raised.
num_impl_registrations = 0
registration_dict = {}
failed_keys = set()
for key in DispatchKey:
    try:
        # Only this query can fail; keep it alone inside the try.
        num_kernel_registrations = len(
            torch._C._dispatch_get_registrations_for_dispatch_key(str(key))
        )
    except Exception:
        # Sanity checking the failed keys we see that they are keys that we
        # expect no registrations for (e.g. NestedTensorHPU).
        # Catch Exception rather than using a bare except, so Ctrl-C /
        # SystemExit still abort the script.
        # NOTE(review): presumably the C++ side surfaces unknown key names as
        # RuntimeError — confirm and narrow this handler further.
        failed_keys.add(str(key))
        continue
    num_impl_registrations += num_kernel_registrations
    registration_dict[key] = num_kernel_registrations

# The query with no dispatch-key argument is used here as the count of
# schema registrations.
num_schema_registrations = len(torch._C._dispatch_get_registrations_for_dispatch_key())
print(f"Registrations [{num_schema_registrations}] Schema Registrations")
print(f"Registrations [{num_impl_registrations}] Impl Registrations")
# === Sort dispatch keys by number of impl registrations ===
# Keys with more than 50 impl registrations are reported one per line, largest
# first; all remaining keys are folded into a single "The Rest" total.
long_tail = 0
ranked = sorted(registration_dict.items(), key=lambda item: item[1], reverse=True)
for key, count in ranked:
    if count <= 50:
        long_tail += count
        continue
    print(f"Impl Registrations [{count}] {key}")
print(f"Impl Registrations [{long_tail}] The Rest")
# === Count schema registrations per namespace ===
# Operator names are expected to look like "<namespace>::<name>". Partition on
# the first "::" instead of unpacking split(): an unexpected name can no longer
# raise ValueError, and the loop variable is no longer rebound mid-iteration.
schemas = torch._C._dispatch_get_registrations_for_dispatch_key()
schema_dict = Counter(name.partition("::")[0] for name in schemas)
# === Sort namespaces by number of schema registrations ===
# Namespaces with more than 20 schemas are listed individually (largest first);
# the remaining ones are summed into "The Rest".
long_tail = 0
for key in sorted(schema_dict, key=schema_dict.get, reverse=True):
    if schema_dict[key] > 20:
        print(f"Schema Registrations [{schema_dict[key]}] {key}")
    else:
        long_tail += schema_dict[key]
print(f"Schema Registrations [{long_tail}] The Rest")
# === Obtain stats of impls registered using Python library ===
# Each entry of torch.library._impls is unpacked as three "/"-separated fields
# (namespace, op name, dispatch key); entries are grouped by dispatch key and
# the size of each group is reported.
python_impls = {}
for entry in torch.library._impls:
    _ns, _name, dispatch_key = entry.split("/")
    python_impls.setdefault(dispatch_key, []).append(entry)
for dispatch_key, entries in python_impls.items():
    print(f"{dispatch_key} [{len(entries)}] python impl")
# === Python schema counts ===
# These numbers were obtained by hacking torch.Library to add defs to a global
# dict. They are hard-coded literals, so they do not track the installed torch.
for label, count in (("prims", 124), ("the rest", 2)):
    print(f"{label} [{count}] python schema")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment