Breakdown for dispatcher registrations
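The script below tallies PyTorch dispatcher registrations: total schema registrations (torch._C._dispatch_get_registrations_for_dispatch_key() with no key), impl registrations per DispatchKey, schema registrations per namespace, and impls registered through the Python torch.library API.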
import torch
from torchgen.model import DispatchKey
num_impl_registrations = 0
registration_dict = {}
failed_keys = set()
# === Collect number of kernel registrations for each dispatch key ===
for key in DispatchKey:
    try:
        num_kernel_registrations = len(torch._C._dispatch_get_registrations_for_dispatch_key(str(key)))
        num_impl_registrations += num_kernel_registrations
        registration_dict[key] = num_kernel_registrations
    except Exception:
        # Sanity-checking the failed keys shows they are keys we expect to have no registrations (e.g. NestedTensorHPU)
        failed_keys.add(str(key))
        continue
num_schema_registrations = len(torch._C._dispatch_get_registrations_for_dispatch_key())
print(f"Registrations [{num_schema_registrations}] Schema Registrations")
print(f"Registrations [{num_impl_registrations}] Impl Registrations")
# === Sort dispatch keys by number of impl registrations ===
long_tail = 0
for key in sorted(registration_dict, key=registration_dict.get, reverse=True):
    if registration_dict[key] > 50:
        print(f"Impl Registrations [{registration_dict[key]}] {key}")
    else:
        long_tail += registration_dict[key]
print(f"Impl Registrations [{long_tail}] The Rest")
schemas = torch._C._dispatch_get_registrations_for_dispatch_key()
schema_dict = {}
for schema in schemas:
    ns, schema = schema.split('::')
    if ns not in schema_dict:
        schema_dict[ns] = 1
    else:
        schema_dict[ns] += 1
# === Sort namespaces by number of schema registrations ===
long_tail = 0
for key in sorted(schema_dict, key=schema_dict.get, reverse=True):
    if schema_dict[key] > 20:
        print(f"Schema Registrations [{schema_dict[key]}] {key}")
    else:
        long_tail += schema_dict[key]
print(f"Schema Registrations [{long_tail}] The Rest")
# === Obtain stats of impls registered using Python library ===
python_impls = {}
for k in torch.library._impls:
    ns, name, dk = k.split("/")
    if dk in python_impls:
        python_impls[dk].append(k)
    else:
        python_impls[dk] = [k]
for k in python_impls:
    print(f"{k} [{len(python_impls[k])}] python impl")
# === These were obtained by hacking torch.Library to add defs to a global dict ===
print("prims [124] python schema")
print("the rest [2] python schema")