""" | |
A much shorter version of train.py for benchmarking | |
""" | |
import os | |
from contextlib import nullcontext | |
import numpy as np | |
import time | |
import torch | |
from model import GPTConfig, GPT | |
# -----------------------------------------------------------------------------
batch_size = 8
block_size = 1024
bias = True
seed = 1337
device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.
dtype = 'float32' # 'float32', 'bfloat16', or 'float16'
compile = True # use PyTorch 2.0 to compile the model to be faster
exec(open('configurator.py').read()) # overrides from command line or config file
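# for example, a hypothetical invocation like `python bench.py --batch_size=16 --compile=False`
# would override the defaults above, assuming configurator.py follows the usual
# nanoGPT convention of parsing --key=value pairs and config files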
# -----------------------------------------------------------------------------
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
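# ctx is entered around every forward pass below; with a reduced-precision
# dtype ('bfloat16' or 'float16') autocast runs the forward ops in that dtype
# while the parameters themselves stay in float32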
# data loading init
real_data = False
if real_data:
    dataset = 'openwebtext'
    data_dir = os.path.join('data', dataset)
    train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
    def get_batch(split):
        data = train_data # note: split is ignored in this benchmarking script
        ix = torch.randint(len(data) - block_size, (batch_size,))
        x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
        y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
        x, y = x.to(device), y.to(device)
        return x, y
else:
    # alternatively, if fixed data is desired, so as to not care about data loading
    x = torch.randint(50257, (batch_size, block_size), device=device) # 50257 = GPT-2 vocab size
    y = torch.randint(50257, (batch_size, block_size), device=device)
    get_batch = lambda split: (x, y)
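# with real_data = False the same (x, y) pair is replayed on every step, which
# keeps data-loading cost out of the measurement entirely; the printed loss
# will quickly memorize this single batch, so treat it only as a sanity check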
# model init
gptconf = GPTConfig(
    block_size = block_size, # how far back does the model look? i.e. context size
    n_layer = 12, n_head = 12, n_embd = 768, # size of the model (the GPT-2 124M config)
    dropout = 0, # for determinism
    bias = bias,
)
model = GPT(gptconf)
model.to(device)
optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95))
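# in nanoGPT, configure_optimizers builds an AdamW optimizer that applies
# weight decay selectively (e.g. to linear weights but not to biases or
# layernorm parameters)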
if compile:
    print("Compiling model...")
    model = torch.compile(model) # pytorch 2.0
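# note: the first forward/backward through a compiled model triggers the actual
# compilation and is much slower; the warm-up repetitions and burn-in stage
# below exist largely to absorb that one-time cost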
def timed(fn):
    # time fn() on the GPU using CUDA events; elapsed_time reports milliseconds
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize() # wait for the GPU to finish before reading the timer
    return result, start.elapsed_time(end)
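# timed() assumes a CUDA device; for device='cpu', a plain time.perf_counter()
# around fn() would be the simpler equivalent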
profile = False # use pytorch profiler, or just simple benchmarking?
if profile:
    # useful docs on pytorch profiler:
    # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html
    # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile
    wait, warmup, active = 5, 5, 5
    num_steps = wait + warmup + active
    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),
        record_shapes=True,
        profile_memory=True,
        with_stack=True, # incurs an additional overhead, disable if not needed
        with_flops=True,
        with_modules=False, # only for torchscript models atm
    ) as prof:
        for k in range(num_steps):
            X, Y = get_batch('train')
            with ctx:
                logits, loss = model(X, Y)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")
            prof.step() # notify the profiler at end of each step
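    # to view the collected trace: `tensorboard --logdir=./bench_log`
    # (requires the TensorBoard profiler plugin, `pip install torch-tb-profiler`)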
else:
    # simple benchmarking
    torch.cuda.synchronize()
    model = model.eval()
    with ctx:
        X, Y = get_batch('train')
        # run the forward pass several times and keep only the last timing, so
        # that te reflects a warmed-up (and, if enabled, compiled) model
        for _ in range(5):
            _, te = timed(lambda: model(X, Y))
    # note: gradients are not disabled here, so the timed forward still builds
    # the autograd graph; wrapping it in torch.no_grad() would measure pure inference
    print("TE forward pass (ms) ", te)
    model = model.train()
    for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark
        t0 = time.time()
        for k in range(num_steps):
            X, Y = get_batch('train')
            with ctx:
                logits, loss = model(X, Y)
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
            lossf = loss.item()
            print(f"{k}/{num_steps} loss: {lossf:.4f}")
        torch.cuda.synchronize()
        t1 = time.time()
        if stage == 1:
            print(f"time per iteration: {(t1-t0)/num_steps*1000:.4f}ms")