@AmdSampsa · Last active May 2, 2025
timm minimal repro for inductor dashboard
from pprint import pprint
import torch
import os
#import timm
# from timm.models import dpn107
from timm.models._factory import create_model
# from torch._inductor import config as inductor_config
from torch._dynamo.utils import same
#"""# some random stuff from the test suite
torch.use_deterministic_algorithms(True) # depends on the model
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.allow_tf32 = False
torch.backends.cudnn.benchmark = False
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(False)
# inductor_config.fallback_random = True
device="cuda:0"
### define gold-standard precision
# gs_dtype=torch.float32
gs_dtype=torch.bfloat16
### what we're comparing against
dtype=torch.bfloat16
# dtype=torch.float32
### is the comparison done against an inductor-compiled model or not?
do_compile=True
# do_compile=False
inductor_options={}
#inductor_options["triton.cudagraphs"] =True
#inductor_options["inductor_config.fallback_random"] = True # used by the script..?
def compare_torch(tensor1: torch.Tensor, tensor2: torch.Tensor) -> dict:
    if tensor1.shape != tensor2.shape:
        raise ValueError(f"Shape mismatch: {tensor1.shape} vs {tensor2.shape}")
    t1 = tensor1.detach()
    t2 = tensor2.detach()
    abs_diff = torch.abs(t1 - t2)
    max_diff = torch.max(abs_diff).item()
    mean_diff = torch.mean(abs_diff).item()
    l2_diff = torch.sqrt(torch.mean(torch.square(abs_diff))).item()
    return {
        'max_diff': max_diff,
        'mean_diff': mean_diff,
        'l2_diff': l2_diff,
        'shape': tuple(t1.shape),
        'arr1_mean': torch.mean(t1).item(),
        'arr2_mean': torch.mean(t2).item(),
        'arr1_std': torch.std(t1).item(),
        'arr2_std': torch.std(t2).item(),
    }
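# quick self-check of compare_torch (illustrative addition, not in the original
# gist): identical tensors must give zero diffs
_chk = compare_torch(torch.ones(4), torch.ones(4))
assert _chk['max_diff'] == 0.0 and _chk['l2_diff'] == 0.0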
torch.manual_seed(42)
print("creating the model")
kwargs={'in_chans': 3, 'drop_rate': 0.0}
model_gs=create_model("dpn107", pretrained=True, **kwargs).to(device=device, dtype=gs_dtype) # model with the gold-standard dtype
model=create_model("dpn107", pretrained=True, **kwargs).to(device=device, dtype=dtype)
## TODO: timm has a get_sample_tensor function?
t = torch.randn((8, 3, 224, 224), device=device) # NOTE: batch size 8
# Scale towards the target range (-1.7266 to 1.7266); the extra 0.5 factor
# keeps the values in roughly the inner half of that range
target_min = -1.7266
target_max = 1.7266
t = t * (target_max - target_min) / (t.max() - t.min()) * 0.5 + (target_min + target_max) / 2
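# optional sanity check (illustrative addition, not in the original gist):
# after rescaling, the values should stay within the target range
assert target_min <= t.min().item() and t.max().item() <= target_max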
# Convert to the dtypes under comparison
t_gs = t.clone().to(dtype=gs_dtype)
t2 = t.clone().to(dtype=dtype)
t3 = t.clone().to(dtype=dtype)
"""
# Verify the characteristics
print(f"Min: {t.min()}")
print(f"Max: {t.max()}")
print(f"Mean: {t.mean()}")
"""
print("running the gold-standard model with", gs_dtype)
r_gs=model_gs(t_gs)
print("will compare to", dtype)
if do_compile:
    print("comparing against compiled model")
    compiled_model = torch.compile(
        model,
        backend="inductor",
        options=inductor_options,
    )
    print("running the compiled model (first run takes ages..)")
    r2 = compiled_model(t2)
    print("replaying the model (fast?)")
    r3 = compiled_model(t3)
else:
    r2 = model(t2)
    r3 = model(t3)
print()
print("comparison to first run")
pprint(compare_torch(r_gs,r2))
print()
print("comparison to second run")
pprint(compare_torch(r_gs,r3))
ok = same(r_gs, r2, tol=0.001, equal_nan=True, exact_dtype=True)
print("same() within tol=0.001:", ok)
# note: for two eager runs tol=0 suffices
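# illustrative extra check (assumption, not part of the original script):
# a relative-error view is often more informative than max_diff for bfloat16
rel = (r_gs.float() - r2.float()).abs() / r_gs.float().abs().clamp_min(1e-6)
print("max relative diff:", rel.max().item())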
### MI300
"""two float32 eager runs
-> OK with deterministic algorithms, goes to zero
"""
"""float32 eager vs. float32 compile
'max_diff': 4.634261131286621e-06,
'mean_diff': 8.079827011897578e-07,
"""
"""float32 eager and bfloat16 compile
'max_diff': 0.03494948148727417,
'mean_diff': 0.005384537391364574,
"""
"""bfloat16 eager and bfloat16 compile ## NOTE: this is the test done by the benchmark suite
'max_diff': 0.0263671875,
'mean_diff': 0.004791259765625,
## -> function same() complains (see above)
"""
### A100
"""float32 eager and bfloat16 compile
'max_diff': 0.024984657764434814,
'mean_diff': 0.004882916808128357,
"""
"""bfloat16 eager and bfloat16 compile
'max_diff': 0.02734375,
'mean_diff': 0.00439453125
## -> the same issue again: function same() complains
"""