Last active
May 8, 2024 04:20
-
-
Save mlazos/94b01f5f9c3386987d8712455014d371 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import torch._dynamo as torchdynamo | |
import torch._inductor | |
import time | |
import torch._inductor.config as config | |
from torch._dynamo.utils import cprofile_wrapper | |
from apex.optimizers import FusedAdam, FusedSGD | |
# Inductor configuration for this script: turn the C++ wrapper off and turn
# CUDA graphs on for Triton.
config.cpp_wrapper = False
config.triton.cudagraphs = True

# torch.optim classes for which opt_exec_benchmark additionally times an apex
# fused optimizer (FusedAdam for Adam/AdamW, FusedSGD otherwise).
apex_opts = {torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD}
def _read_model_names(path):
    """Return the non-empty lines of the text file at *path*, in file order.

    Empty entries are dropped: a file that ends with a newline would otherwise
    contribute a trailing "" name, which passes the ``name in <list>`` checks
    in get_model and is iterated like a real model by the driver loop.
    """
    with open(path, "r") as handle:
        return [name for name in handle.read().split("\n") if name]


# One model name per line in each file; the list a name appears in decides
# which loader get_model dispatches to.
hf_models = _read_model_names("hf_models.txt")
timm_models = _read_model_names("timm_models.txt")
tb_models = _read_model_names("tb_models.txt")
def get_model(name):
    """Build the model called *name* using the loader of the list it is in.

    The lists are checked in the order hf_models, timm_models, tb_models and
    the first match wins.

    Args:
        name: model identifier as listed in one of the model-name files.

    Returns:
        Whatever the selected loader (get_hf_model, get_timm_model or
        get_torchbench_model) returns.

    Raises:
        ValueError: if *name* is in none of the three lists.
    """
    if name in hf_models:
        return get_hf_model(name)
    if name in timm_models:
        return get_timm_model(name)
    if name in tb_models:
        return get_torchbench_model(name)
    # Raising a plain string is itself a TypeError in Python 3 ("exceptions
    # must derive from BaseException"), which hid the real problem.
    raise ValueError(f"No valid model in hf, timm, or tb: {name!r}")
# Module-level log opened in append mode. opt_exec_benchmark prints a header
# line to it for every run, and the driver code at the bottom of the script
# closes it.
out_file = open(file="bench.txt", mode="a")
def bench(f, file_name, model_name, iters=100, warmup=5, profile=False):
    """Time or profile repeated calls of *f* on the current CUDA device.

    *f* is first called *warmup* times. Then, depending on *profile*:

    * ``profile=True``: one further call is recorded with
      ``torch.profiler.profile`` and exported as a Chrome trace to
      ``{model_name}_{file_name}.json``. *iters* is ignored.
    * ``profile=False``: *f* is called *iters* times between two
      ``torch.cuda.synchronize()`` calls. The mean time per call is printed
      and appended to ``bench_{file_name}.txt`` as
      ``{model_name},{us_per_iter},us``.

    Args:
        f: zero-argument callable to measure.
        file_name: tag used in the output file names.
        model_name: label written in front of the measurement.
        iters: number of timed calls (timing mode only).
        warmup: number of untimed calls made before measuring.
        profile: select profiling instead of timing.

    Returns:
        Mean microseconds per call in timing mode, ``None`` in profile mode.
    """
    for _ in range(warmup):
        f()

    if profile:
        torch.cuda.synchronize()
        with torch.profiler.profile() as prof:
            f()
            # Wait for the GPU before the profiler context closes.
            torch.cuda.synchronize()
        prof.export_chrome_trace(f"{model_name}_{file_name}.json")
        return None

    torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution; time.time() follows the
    # wall clock and can jump while the loop is running.
    begin = time.perf_counter()
    for _ in range(iters):
        f()
    torch.cuda.synchronize()
    us_per_iter = (time.perf_counter() - begin) * 1e6 / iters

    res = f"{model_name},{us_per_iter},us"
    print(res)
    # Named differently from the module-level out_file on purpose, so the two
    # handles cannot be confused.
    with open(f"bench_{file_name}.txt", "a") as results_file:
        print(res, file=results_file)
    return us_per_iter
def get_torchbench_model(name):
    """Instantiate the torchbenchmark model *name* and return its module.

    Imports ``torchbenchmark.models.{name}`` (after making the hard-coded
    torchbenchmark checkout importable), builds its ``Model`` with
    ``test="train"``, ``device="cuda"`` and ``batch_size=16``, and returns the
    first element of ``get_module()``.

    Args:
        name: sub-module name under ``torchbenchmark.models``.

    Returns:
        The first item of the pair returned by ``Model.get_module()``.

    Raises:
        ImportError: if the model module cannot be imported.
    """
    import importlib
    import sys

    torchbench_root = "/scratch/mlazos/torchbenchmark"
    # Add the checkout only once; appending on every call grows sys.path by
    # one duplicate entry per model loaded.
    if torchbench_root not in sys.path:
        sys.path.append(torchbench_root)

    module = importlib.import_module(f"torchbenchmark.models.{name}")
    benchmark = module.Model(test="train", device="cuda", batch_size=16)
    model, _ = benchmark.get_module()
    return model
def get_hf_model(name):
    """Create the ``transformers`` class *name* from its default config.

    The class is looked up by attribute name on the ``transformers`` package
    and a config is created by calling its ``config_class`` with no arguments.
    The resulting model is moved to CUDA as float32.

    Args:
        name: attribute name of the model class in ``transformers``.

    Returns:
        The instantiated model on device ``"cuda"`` with dtype float32.
    """
    import importlib

    transformers = importlib.import_module("transformers")
    model_cls = getattr(transformers, name)
    default_config = model_cls.config_class()

    # Classes whose module path contains "auto" are built through from_config;
    # every other class is called with the config directly.
    is_auto_class = "auto" in model_cls.__module__
    build = model_cls.from_config if is_auto_class else model_cls
    return build(default_config).to(device="cuda", dtype=torch.float32)
def get_timm_model(name):
    """Create the timm model *name* and return it on CUDA as float32."""
    import timm

    return timm.create_model(name).to(device="cuda", dtype=torch.float32)
def opt_exec_benchmark(model_name, profile, opt_ctor, **kwargs):
    """Benchmark ``opt_ctor``'s ``step()`` on the parameters of one model.

    Up to three variants are measured with bench(), each for 250 iterations:

    1. the eager optimizer (file tag ``{opt}_eager``),
    2. the same optimizer class with ``step`` compiled through
       ``torchdynamo.optimize("inductor")`` (``{opt}_dynamo``),
    3. if *opt_ctor* is in apex_opts, the apex FusedAdam / FusedSGD
       counterpart (``{opt}_apex``).

    No backward pass is run: each parameter's ``.grad`` is filled with
    ``torch.rand_like`` and marked as a static address for dynamo.

    Args:
        model_name: name passed to get_model.
        profile: forwarded to bench(); selects profiling instead of timing.
        opt_ctor: optimizer class, e.g. ``torch.optim.Adam``.
        **kwargs: extra keyword arguments forwarded to the optimizer
            constructors. ``foreach`` is removed before the apex optimizer is
            built.
    """
    opt_name = opt_ctor.__name__
    print(
        f"Running tracing benchmark on {model_name} with {opt_name}",
        file=out_file,
    )
    print(model_name)
    model = get_model(model_name)

    # Only the first `param_limit` parameters are benchmarked.
    param_limit = 1000
    params = list(model.parameters())[:param_limit]
    print(f"num_params:{len(params)}")

    for p in params:
        p.grad = torch.rand_like(p)
        torch._dynamo.mark_static_address(p.grad)

    opt_eager = opt_ctor(params, **kwargs, lr=0.01)
    # Two steps up front, before any timing starts.
    opt_eager.step()
    opt_eager.step()
    bench(
        lambda: opt_eager.step(),
        file_name=f"{opt_name}_eager",
        model_name=model_name,
        profile=profile,
        iters=250,
    )

    # Used as a context manager so the previous grad mode is restored on exit.
    # Calling set_grad_enabled(False) as a plain function left autograd
    # disabled for every later call and for the rest of the process.
    with torch.set_grad_enabled(False):
        # NOTE(review): the compiled optimizer gets weight_decay=0.01 while the
        # eager one above uses the constructor default; confirm this is
        # intended, since it makes the two runs do different work.
        opt = opt_ctor(params, **kwargs, lr=0.01, weight_decay=0.01)
        opt_step = torchdynamo.optimize("inductor")(opt.step)
        bench(
            lambda: opt_step(),
            file_name=f"{opt_name}_dynamo",
            model_name=model_name,
            profile=profile,
            iters=250,
        )

        if opt_ctor in apex_opts:
            # Adam signature:
            # params, lr=0.001, bias_correction=True, betas=(0.9, 0.999), eps=1e-08, adam_w_mode=True, weight_decay=0.0, amsgrad=False, set_grad_none=True
            kwargs.pop("foreach", None)
            if opt_ctor in (torch.optim.Adam, torch.optim.AdamW):
                apex_opt = FusedAdam(
                    params,
                    **kwargs,
                    lr=0.01,
                    betas=(0.9, 0.999),
                    eps=1e-8,
                    weight_decay=0.0 if opt_ctor != torch.optim.AdamW else 0.01,
                    adam_w_mode=opt_ctor == torch.optim.AdamW,
                    amsgrad=False,
                    capturable=True,
                    set_grad_none=False,
                )
            else:
                apex_opt = FusedSGD(
                    params,
                    **kwargs,
                    lr=0.01,
                    set_grad_none=False,
                )
            bench(
                lambda: apex_opt.step(),
                file_name=f"{opt_name}_apex",
                model_name=model_name,
                profile=profile,
                iters=250,
            )

        # Reset dynamo before the next model/optimizer combination.
        torch._dynamo.reset()
# Driver: run every optimizer against every Hugging Face model.
# With profile=True, bench() exports profiler traces instead of timings.
profile = True
optims = [
    torch.optim.Adam,
    torch.optim.AdamW,
    torch.optim.SGD,
]
# NOTE(review): `models` is not referenced by the loop below, which iterates
# hf_models directly; confirm which list is meant to be benchmarked.
models = timm_models
import itertools

try:
    for optim in optims:
        for model in hf_models:  # itertools.chain(hf_models, timm_models, tb_models):
            opt_exec_benchmark(
                model,
                profile,
                optim,
                foreach=True,
            )
            print("", file=out_file)
finally:
    # Close the shared log even when a benchmark raises part-way through.
    out_file.close()
Hi @mlazos, in this script, could you tell me how the gradients of the parameters get populated?
See line 118: each parameter's gradient is simply populated with `torch.rand_like`.
ah thank you very much
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Hi @mlazos, in this script, could you tell me how the gradients of the parameters get populated?