Skip to content

Instantly share code, notes, and snippets.

@mlazos
Created March 20, 2024 22:36
Show Gist options
  • Save mlazos/1171f035a2392c33778aaa3d7bf24370 to your computer and use it in GitHub Desktop.
Save mlazos/1171f035a2392c33778aaa3d7bf24370 to your computer and use it in GitHub Desktop.
import torch
import torch._dynamo as torchdynamo
import torch._inductor
import time
import torch._inductor.config as config
from torch._dynamo.utils import cprofile_wrapper
from apex.optimizers import FusedAdam, FusedSGD
# Inductor settings, applied once at import time before anything is compiled.
config.cpp_wrapper = False  # keep inductor's Python wrapper rather than the C++ one
config.triton.cudagraphs = True  # let inductor run compiled graphs under CUDA graphs

# torch.optim optimizer class -> apex fused counterpart used for comparison runs.
opt_to_apex = {
    torch_cls: apex_cls
    for torch_cls, apex_cls in (
        (torch.optim.Adam, FusedAdam),
        (torch.optim.SGD, FusedSGD),
        (torch.optim.AdamW, FusedAdam),
    )
}
def _read_model_list(path):
    """Return the model names listed one per line in *path*.

    Lines are stripped of surrounding whitespace (including a ``\\r`` from CRLF
    files) and blank lines are dropped, so a trailing newline at the end of the
    file no longer produces an empty ``""`` model name.
    """
    with open(path, "r", encoding="utf-8") as fh:
        return [line.strip() for line in fh if line.strip()]


hf_models = _read_model_list("hf_models.txt")
timm_models = _read_model_list("timm_models.txt")
tb_models = _read_model_list("tb_models.txt")
def get_model(name):
    """Build the benchmark model called *name*.

    The name is looked up, in order, in ``hf_models``, ``timm_models`` and
    ``tb_models`` and dispatched to the matching loader.

    Raises:
        ValueError: if *name* is empty or not present in any of the three lists.
    """
    if not name:
        # A blank entry (e.g. from a trailing newline in a list file) is never a model.
        raise ValueError("Empty model name")
    if name in hf_models:
        return get_hf_model(name)
    if name in timm_models:
        return get_timm_model(name)
    # Plain membership test: tb_models is a list, not a callable.
    if name in tb_models:
        return get_torchbench_model(name)
    # Raise a real exception; raising a bare string is a TypeError in Python 3.
    raise ValueError(f"No valid model in hf, timm, or tb: {name!r}")
# Shared log for per-benchmark header lines. Line-buffered so each written line
# reaches disk immediately even if a later benchmark kills the process.
out_file = open("bench.txt", "a", buffering=1, encoding="utf-8")
def bench(f, file_name, model_name, iters=100, warmup=5, profile=False):
    """Time (or profile) the zero-argument callable *f*.

    *f* is first called ``warmup`` times. Then:

    * if *profile* is true, *f* runs once under ``torch.profiler`` and a Chrome
      trace is written to ``{model_name}_{file_name}.json``; returns ``None``.
    * otherwise *f* is called ``iters`` times, and the line
      ``"{model_name},{us_per_iter},us"`` is printed and appended to
      ``bench_{file_name}.txt``.

    Returns:
        Mean wall-clock microseconds per call of *f*, or ``None`` in profile mode.

    Raises:
        ValueError: if *iters* is not positive in timing mode (checked before
            any warmup work is done).
    """
    if not profile and iters <= 0:
        raise ValueError(f"iters must be positive, got {iters}")

    # Decide once whether GPU synchronization is possible, so CPU-only callables
    # can be benchmarked too and the check stays out of the timed region.
    use_cuda = torch.cuda.is_available()

    def sync():
        # CUDA work is asynchronous; wait for it so wall-clock time includes it.
        if use_cuda:
            torch.cuda.synchronize()

    for _ in range(warmup):
        f()

    if profile:
        sync()
        with torch.profiler.profile() as prof:
            f()
            sync()
        prof.export_chrome_trace(f"{model_name}_{file_name}.json")
        return None

    sync()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    sync()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters
    res = f"{model_name},{us_per_iter},us"
    print(res)
    # Named `results` so it does not shadow the module-level `out_file`.
    with open(f"bench_{file_name}.txt", "a", encoding="utf-8") as results:
        print(res, file=results)
    return us_per_iter
def get_torchbench_model(
    name, torchbench_root="/scratch/mlazos/torchbenchmark", batch_size=16
):
    """Build the TorchBench model *name* for training on CUDA and return its module.

    Args:
        name: module name under ``torchbenchmark.models``.
        torchbench_root: directory containing the ``torchbenchmark`` package.
            It is appended to ``sys.path`` only if not already present, so
            repeated calls do not keep growing the import path.
        batch_size: batch size passed to the TorchBench ``Model`` constructor.
    """
    import importlib
    import sys

    if torchbench_root not in sys.path:
        sys.path.append(torchbench_root)
    module = importlib.import_module(f"torchbenchmark.models.{name}")
    benchmark = module.Model(test="train", device="cuda", batch_size=batch_size)
    # get_module() returns a pair; only the module is needed here (the second
    # element is presumably the example inputs -- not used).
    model, _ = benchmark.get_module()
    return model
def get_hf_model(name):
    """Instantiate the HuggingFace ``transformers`` class *name* from its default config.

    The model is built from ``model_cls.config_class()`` (no pretrained weights)
    and moved to CUDA in float32. Classes defined in an ``auto`` module are
    constructed via ``from_config``; all others are called with the config directly.
    """
    import importlib

    transformers = importlib.import_module("transformers")
    model_cls = getattr(transformers, name)
    # Named model_config so it does not shadow the module-level inductor `config`.
    model_config = model_cls.config_class()
    build = model_cls.from_config if "auto" in model_cls.__module__ else model_cls
    return build(model_config).to(device="cuda", dtype=torch.float32)
def get_timm_model(name):
    """Create the timm model *name* and return it on CUDA in float32."""
    from timm import create_model

    net = create_model(name)
    return net.to(dtype=torch.float32, device="cuda")
def opt_exec_benchmark(
    model_name, profile, opt_ctor, *, use_apex=False, bench_eager=False, **kwargs
):
    """Benchmark a dynamo/inductor-compiled ``optimizer.step()`` on a model's params.

    Builds *model_name* via ``get_model``, gives up to 1000 of its parameter
    tensors a random gradient, and times (or profiles) the compiled step of
    ``opt_ctor`` via ``bench``. Dynamo state is reset at the end.

    Args:
        model_name: name resolved by ``get_model``.
        profile: forwarded to ``bench``; when true a Chrome trace is exported
            instead of timing.
        opt_ctor: ``torch.optim`` optimizer class to benchmark.
        use_apex: also benchmark the apex fused optimizer mapped in
            ``opt_to_apex`` (if ``opt_ctor`` has a mapping).
        bench_eager: also benchmark the uncompiled ``step()``.
        **kwargs: extra optimizer constructor arguments; ``lr`` is fixed to 0.01.
    """
    print(
        f"Running tracing benchmark on {model_name} with {opt_ctor.__name__}",
        file=out_file,
    )
    print(model_name)
    model = get_model(model_name)
    # Cap how many parameter tensors the optimizer is given.
    param_limit = 1000
    params = list(model.parameters())[:param_limit]
    print(f"num_params:{len(params)}")
    for p in params:
        # Synthetic gradients: only the optimizer step is being measured.
        p.grad = torch.rand_like(p)
        # The grad tensors are not reassigned across steps here; marking them
        # static tells cudagraphs no extra copy of these inputs is needed.
        torch._dynamo.mark_static_address(p.grad)

    opt_eager = opt_ctor(params, **kwargs, lr=0.01)
    # Two eager steps before compiling (note: they update params in place).
    opt_eager.step()
    opt_eager.step()
    if bench_eager:
        bench(
            opt_eager.step,
            file_name=f"{opt_ctor.__name__}_eager",
            model_name=model_name,
            profile=profile,
            iters=100,
        )

    # Autograd is disabled only for the compiled/apex runs; the context manager
    # restores the caller's grad mode instead of leaving it off globally.
    with torch.no_grad():
        opt = opt_ctor(params, **kwargs, lr=0.01)
        opt_step = torchdynamo.optimize("inductor")(opt.step)
        bench(
            opt_step,
            file_name=f"{opt_ctor.__name__}_dynamo",
            model_name=model_name,
            profile=profile,
            iters=100,
        )

        if use_apex and opt_ctor in opt_to_apex:
            apex_ctor = opt_to_apex[opt_ctor]
            # apex optimizers do not accept torch's `foreach` flag.
            apex_kwargs = {k: v for k, v in kwargs.items() if k != "foreach"}
            if apex_ctor is FusedAdam:
                # FusedAdam defaults to adam_w_mode=True (decoupled weight decay),
                # which matches AdamW but not plain Adam.
                apex_kwargs.setdefault("adam_w_mode", opt_ctor is torch.optim.AdamW)
            apex_opt = apex_ctor(params, **apex_kwargs, lr=0.01, set_grad_none=False)
            bench(
                apex_opt.step,
                file_name=f"{opt_ctor.__name__}_apex",
                model_name=model_name,
                profile=profile,
                iters=100,
            )

    # Drop dynamo's compiled state so the next model starts fresh.
    torch._dynamo.reset()
if __name__ == "__main__":
    profile = True
    # Other candidates: Adagrad, AdamW, ASGD, NAdam, RMSprop, Rprop, Adadelta
    # (all in torch.optim).
    optims = [
        torch.optim.Adam,
    ]
    models = timm_models
    try:
        for optim in optims:
            for model in models:
                # Skip blank entries (e.g. from a trailing newline in the list file).
                if not model:
                    continue
                # Pass the single model name, not the whole list.
                opt_exec_benchmark(
                    model,
                    profile,
                    optim,
                    foreach=True,
                )
                print("", file=out_file)
                # Stop after the first model (quick profiling run); remove this
                # `break` to sweep every model in the list.
                break
            # Ad-hoc single-model runs:
            # opt_exec_benchmark("BartForConditionalGeneration", False, optim, foreach=True)
            # opt_exec_benchmark("PegasusForConditionalGeneration", False, optim, foreach=True)
            # opt_exec_benchmark("resnet18", profile, optim)
            # opt_exec_benchmark("MobileBertForMaskedLM", profile, torch.optim.Adam, foreach=False, lr=0.01)
    finally:
        # Close even if a benchmark raises, so buffered results are flushed.
        out_file.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment