import torch
import torch._dynamo as torchdynamo
import torch._inductor
import time
import torch._inductor.config as config
from torch._dynamo.utils import cprofile_wrapper
from apex.optimizers import FusedAdam, FusedSGD

config.triton.cudagraphs = True
config.cpp_wrapper = False

# Map eager optimizers to their apex fused counterparts.
opt_to_apex = {
    torch.optim.Adam: FusedAdam,
    torch.optim.SGD: FusedSGD,
    torch.optim.AdamW: FusedAdam,
}
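
# Model lists are read from plain-text files; given the split("\n") parsing
# below, each file is assumed to contain one model name per line.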
with open("hf_models.txt", "r") as f: | |
hf_models = f.read().split("\n") | |
with open("timm_models.txt", "r") as f: | |
timm_models = f.read().split("\n") | |
with open("tb_models.txt", "r") as f: | |
tb_models = f.read().split("\n") | |
def get_model(name):
    if name in hf_models:
        return get_hf_model(name)
    elif name in timm_models:
        return get_timm_model(name)
    elif name in tb_models:
        return get_torchbench_model(name)
    else:
        raise ValueError("No valid model in hf, timm, or tb")
out_file = open("bench.txt", "a")


def bench(f, file_name, model_name, iters=100, warmup=5, profile=False):
    # Time f() over `iters` iterations after warmup, or capture a chrome
    # trace when profile=True.
    for _ in range(warmup):
        f()
    if profile:
        torch.cuda.synchronize()
        with torch.profiler.profile() as prof:
            for _ in range(1):
                f()
            torch.cuda.synchronize()
        prof.export_chrome_trace(f"{model_name}_{file_name}.json")
        return None
    else:
        torch.cuda.synchronize()
        begin = time.time()
        for _ in range(iters):
            f()
        torch.cuda.synchronize()
        us_per_iter = (time.time() - begin) * 1e6 / iters
        res = f"{model_name},{us_per_iter},us"
        print(res)
        with open(f"bench_{file_name}.txt", "a") as bench_file:
            print(res, file=bench_file)
        return us_per_iter
def get_torchbench_model(name):
    import importlib
    import sys

    sys.path.append("/scratch/mlazos/torchbenchmark")
    module = importlib.import_module(f"torchbenchmark.models.{name}")
    benchmark_cls = module.Model
    benchmark = benchmark_cls(test="train", device="cuda", batch_size=16)
    model, _ = benchmark.get_module()
    return model
def get_hf_model(name):
    import importlib

    module = importlib.import_module("transformers")
    model_cls = getattr(module, name)
    config_cls = model_cls.config_class
    config = config_cls()
    if "auto" in model_cls.__module__:
        # Handle auto classes
        model = model_cls.from_config(config).to(device="cuda", dtype=torch.float32)
    else:
        model = model_cls(config).to(device="cuda", dtype=torch.float32)
    return model
def get_timm_model(name):
    import timm

    model = timm.create_model(name)
    return model.to(device="cuda", dtype=torch.float32)
def opt_exec_benchmark(model_name, profile, opt_ctor, **kwargs):
    print(
        f"Running tracing benchmark on {model_name} with {opt_ctor.__name__}",
        file=out_file,
    )
    print(model_name)
    model = get_model(model_name)
    param_limit = 1000
    param_list = list(model.parameters())
    lower = 0
    upper = min(len(param_list), param_limit)
    params = param_list[lower:upper]
    print(f"num_params:{len(params)}")
    # Give each parameter a grad tensor with a static address so cudagraphs
    # can replay optimizer steps against the same memory.
    for p in params:
        p.grad = torch.rand_like(p)
        torch._dynamo.mark_static_address(p.grad)
    opt_eager = opt_ctor(params, **kwargs, lr=0.01)
    opt_eager.step()
    opt_eager.step()
    # bench(
    #     lambda: opt_eager.step(),
    #     file_name=f"{opt_ctor.__name__}_eager",
    #     model_name=model_name,
    #     profile=profile,
    #     iters=100,
    # )
    torch.set_grad_enabled(False)
    opt = opt_ctor(params, **kwargs, lr=0.01)
    opt_step = torchdynamo.optimize("inductor")(opt.step)
    bench(
        lambda: opt_step(),
        file_name=f"{opt_ctor.__name__}_dynamo",
        model_name=model_name,
        profile=profile,
        iters=100,
    )
    # if opt_ctor in opt_to_apex:
    if False:
        # Adam signature:
        # params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, adam_w_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True
        apex_ctor = opt_to_apex[opt_ctor]
        if opt_ctor is torch.optim.AdamW:
            pass
        kwargs.pop("foreach", None)
        apex_opt = apex_ctor(params, **kwargs, lr=0.01, set_grad_none=False)
        # g = torch.cuda.CUDAGraph()
        # with torch.cuda.graph(g):
        #     apex_opt.step()
        bench(
            lambda: apex_opt.step(),
            file_name=f"{opt_ctor.__name__}_apex",
            model_name=model_name,
            profile=profile,
            iters=100,
        )
    torch._dynamo.reset()
profile = True
optims = [
    torch.optim.Adam,
]  # [torch.optim.Adam, torch.optim.Adagrad, torch.optim.AdamW, torch.optim.ASGD, torch.optim.NAdam, torch.optim.RMSprop, torch.optim.Rprop, torch.optim.Adadelta]
models = timm_models

for optim in optims:
    for model in models:
        opt_exec_benchmark(
            model,
            profile,
            optim,
            foreach=True,
        )
        print("", file=out_file)
        break

# opt_exec_benchmark("BartForConditionalGeneration", False, optim, foreach=True)
# opt_exec_benchmark("PegasusForConditionalGeneration", False, optim, foreach=True)
# opt_exec_benchmark("resnet18", profile, optim)
# opt_exec_benchmark("MobileBertForMaskedLM", profile, torch.optim.Adam, foreach=False, lr=0.01)
# print("", file=out_file)
out_file.close()