Last active
May 22, 2024 19:52
-
-
Save mlazos/94b01f5f9c3386987d8712455014d371 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch._dynamo as torchdynamo | |
import torch._inductor | |
import time | |
import torch._inductor.config as config | |
from torch._dynamo.utils import cprofile_wrapper | |
from apex.optimizers import FusedAdam, FusedSGD | |
# Inductor settings used when compiling the optimizer step.
config.cpp_wrapper = False
config.triton.cudagraphs = True

# torch.optim classes for which an apex fused counterpart (FusedAdam /
# FusedSGD) is benchmarked as well.
apex_opts = {torch.optim.Adam, torch.optim.AdamW, torch.optim.SGD}
def _read_model_list(path):
    """Return the model names listed in ``path``, one per line, in file order.

    Each line is stripped of surrounding whitespace (including a ``\\r`` left
    by CRLF line endings) and blank lines are skipped. This matters because a
    plain ``read().split("\\n")`` turns a trailing newline into an empty-string
    entry, which the benchmark loop would then try to load as a model.
    """
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


hf_models = _read_model_list("hf_models.txt")
timm_models = _read_model_list("timm_models.txt")
tb_models = _read_model_list("tb_models.txt")
def get_model(name):
    """Build model ``name`` using the suite whose model list contains it.

    The suites are checked in order: Hugging Face (``hf_models``), timm
    (``timm_models``), then TorchBench (``tb_models``).

    Raises:
        ValueError: If ``name`` appears in none of the model lists.
    """
    if name in hf_models:
        return get_hf_model(name)
    if name in timm_models:
        return get_timm_model(name)
    if name in tb_models:
        return get_torchbench_model(name)
    # Raising a plain str is itself a TypeError in Python 3; raise a real
    # exception that says which name failed.
    raise ValueError(f"No valid model in hf, timm, or tb: {name!r}")
# Shared append-mode log for benchmark headers and separators.
out_file = open("bench.txt", mode="a")
def bench(f, file_name, model_name, iters=100, warmup=5, profile=False, sync=None):
    """Time (or profile) repeated calls to ``f`` after a warmup phase.

    Args:
        f: Zero-argument callable to benchmark (e.g. an optimizer ``step``).
        file_name: Output-file suffix. Timings are appended to
            ``bench_{file_name}.txt``; profiler traces are written to
            ``{model_name}_{file_name}.json``.
        model_name: Label printed with each timing and used in the trace name.
        iters: Number of timed calls (unused when ``profile`` is True).
        warmup: Number of untimed calls made first, e.g. so that compilation
            happens outside the measured region.
        profile: If True, record a single call under ``torch.profiler`` and
            export a Chrome trace instead of timing.
        sync: Zero-argument callable that blocks until queued device work has
            finished. Defaults to ``torch.cuda.synchronize``; pass a no-op to
            benchmark CPU-only code.

    Returns:
        Average wall-clock microseconds per call, or None when profiling.

    Raises:
        ValueError: If ``iters`` is not positive in timing mode.
    """
    if sync is None:
        sync = torch.cuda.synchronize
    if not profile and iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    for _ in range(warmup):
        f()

    if profile:
        sync()
        with torch.profiler.profile() as prof:
            f()
            sync()
        prof.export_chrome_trace(f"{model_name}_{file_name}.json")
        return None

    sync()
    # perf_counter is monotonic and high-resolution; time.time() is neither.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    # Device work may still be queued; wait for it before reading the clock.
    sync()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters

    res = f"{model_name},{us_per_iter},us"
    print(res)
    # Separate handle name so the module-level ``out_file`` is not shadowed.
    with open(f"bench_{file_name}.txt", "a") as log_file:
        print(res, file=log_file)
    return us_per_iter
def get_torchbench_model(name, root="/scratch/mlazos/torchbenchmark", batch_size=16):
    """Load TorchBench model ``name`` for training on CUDA and return its module.

    Args:
        name: Sub-module of ``torchbenchmark.models`` to import.
        root: Directory placed on ``sys.path`` so ``torchbenchmark`` is
            importable. It is added at most once, however many models load.
        batch_size: Batch size passed to the benchmark's ``Model`` constructor.

    Returns:
        The module returned by ``Model.get_module()`` (its example inputs are
        discarded).
    """
    import importlib
    import sys

    # Appending unconditionally would add one duplicate sys.path entry per call.
    if root not in sys.path:
        sys.path.append(root)
    module = importlib.import_module(f"torchbenchmark.models.{name}")
    benchmark = module.Model(test="train", device="cuda", batch_size=batch_size)
    model, _ = benchmark.get_module()
    return model
def get_hf_model(name):
    """Instantiate ``transformers.<name>`` from its default config, on CUDA in float32.

    The class's ``config_class`` is constructed with no arguments. Classes
    defined in a module whose name contains ``"auto"`` are built with
    ``from_config``; all others are called with the config directly.
    """
    import importlib

    transformers_mod = importlib.import_module("transformers")
    cls = getattr(transformers_mod, name)
    # Named model_config so the module-level Inductor ``config`` isn't shadowed.
    model_config = cls.config_class()
    if "auto" in cls.__module__:
        built = cls.from_config(model_config)
    else:
        built = cls(model_config)
    return built.to(device="cuda", dtype=torch.float32)
def get_timm_model(name):
    """Create timm model ``name`` with default arguments, on CUDA in float32."""
    import timm

    return timm.create_model(name).to(device="cuda", dtype=torch.float32)
def _make_apex_optimizer(opt_ctor, params, **kwargs):
    """Return the apex fused counterpart of ``opt_ctor`` (Adam/AdamW or SGD)."""
    # Drop torch.optim's ``foreach`` flag before forwarding kwargs to apex.
    kwargs.pop("foreach", None)
    if opt_ctor in (torch.optim.Adam, torch.optim.AdamW):
        is_adamw = opt_ctor is torch.optim.AdamW
        # FusedAdam signature:
        # params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08,
        # adam_w_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True
        return FusedAdam(
            params,
            **kwargs,
            lr=0.01,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.01 if is_adamw else 0.0,
            adam_w_mode=is_adamw,
            amsgrad=False,
            capturable=True,
            set_grad_none=False,
        )
    return FusedSGD(params, **kwargs, lr=0.01, set_grad_none=False)


def opt_exec_benchmark(model_name, profile, opt_ctor, **kwargs):
    """Benchmark ``opt_ctor.step`` on one model: eager, compiled, and apex.

    Up to the first 1000 parameters of the model receive random gradients.
    ``bench`` then measures the optimizer step on those same parameters:

    1. eager PyTorch;
    2. compiled with TorchDynamo + Inductor;
    3. the apex fused equivalent, only when ``opt_ctor`` is in ``apex_opts``.

    All three variants use the same hyperparameters so the timings compare
    the same update rule.

    Args:
        model_name: Name resolved through ``get_model``.
        profile: Forwarded to ``bench``. When True, profiler traces are
            exported instead of timings.
        opt_ctor: Optimizer class, e.g. ``torch.optim.Adam``.
        **kwargs: Extra constructor arguments for the torch optimizers
            (e.g. ``foreach=True``). ``foreach`` is not passed to apex.
    """
    print(
        f"Running tracing benchmark on {model_name} with {opt_ctor.__name__}",
        file=out_file,
    )
    print(model_name)
    try:
        model = get_model(model_name)
        param_limit = 1000
        params = list(model.parameters())[:param_limit]
        print(f"num_params:{len(params)}")
        for p in params:
            p.grad = torch.rand_like(p)
            # Tell dynamo this grad's storage will not move between calls, so
            # the compiled step (cudagraphs is enabled via config) can reuse it.
            torch._dynamo.mark_static_address(p.grad)

        opt_eager = opt_ctor(params, **kwargs, lr=0.01)
        # Two extra untimed steps, presumably so optimizer state exists first.
        opt_eager.step()
        opt_eager.step()
        bench(
            opt_eager.step,
            file_name=f"{opt_ctor.__name__}_eager",
            model_name=model_name,
            profile=profile,
            iters=250,
        )

        # Autograd stays disabled for the compiled and apex runs, but only for
        # the duration of this call rather than for the rest of the process.
        with torch.no_grad():
            # Same constructor arguments as the eager optimizer.
            opt = opt_ctor(params, **kwargs, lr=0.01)
            opt_step = torchdynamo.optimize("inductor")(opt.step)
            bench(
                opt_step,
                file_name=f"{opt_ctor.__name__}_dynamo",
                model_name=model_name,
                profile=profile,
                iters=250,
            )

            if opt_ctor in apex_opts:
                apex_opt = _make_apex_optimizer(opt_ctor, params, **kwargs)
                bench(
                    apex_opt.step,
                    file_name=f"{opt_ctor.__name__}_apex",
                    model_name=model_name,
                    profile=profile,
                    iters=250,
                )
    finally:
        # Drop compiled artifacts even if a run failed, so the next model
        # starts from a clean dynamo cache.
        torch._dynamo.reset()
# Script driver: run every optimizer over every Hugging Face model.
profile = True
optims = [
    torch.optim.Adam,
    torch.optim.AdamW,
    torch.optim.SGD,
]
models = timm_models
import itertools

try:
    for optim in optims:
        # To cover every suite, iterate
        # itertools.chain(hf_models, timm_models, tb_models) instead.
        for model in hf_models:
            opt_exec_benchmark(
                model,
                profile,
                optim,
                foreach=True,
            )
        # Blank line in the shared log between optimizers.
        print("", file=out_file)
finally:
    # Close the shared log even if a benchmark raises part-way through.
    out_file.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Look in scheduler.py and lowering.py (search for the foreach ops). There isn't much benefit to fusing those kernels, because the step-increment kernel does very little work. You could fuse it and broadcast the step value, but the improvement would likely be minor: the step increment costs little more than a kernel launch, whereas the other kernels are memory-bound on loading the parameters.