Skip to content

Instantly share code, notes, and snippets.

@Laeeth
Forked from NaxAlpha/long_gpt.py
Created April 13, 2023 04:17
Show Gist options
  • Star 4 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Laeeth/ecaa9cfb982b58fe8c0680057b0af50d to your computer and use it in GitHub Desktop.
Save Laeeth/ecaa9cfb982b58fe8c0680057b0af50d to your computer and use it in GitHub Desktop.
Training script for LongGPT: fine-tunes GPT-2 Medium (~355M parameters) on The Pile dataset with an 8k-token (8192) context window (requires > 16 GB RAM).
import time
from contextlib import suppress
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cuda as cuda
from torch.utils.data import DataLoader, IterableDataset
import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2TokenizerFast
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Attention
# Keep a handle on the stock GPT2Attention._attn before it is monkey-patched,
# so generate_samples() can temporarily restore it and then re-install
# _attn_wrapper afterwards.
_attn_orig = GPT2Attention._attn
# Inline CSS prepended to the HTML sample panel logged to wandb: removes page
# padding/margins, fills the frame, and sets the font used for <p> samples.
WANDB_STYLE = "\n".join(
    [
        "",
        "<style>",
        "html, body {",
        "padding: 0;",
        "margin: 0;",
        "width: 100%;",
        "height: 100%;",
        "}",
        "p {",
        "font-family: 'Verdana', sans-serif;",
        "}",
        "</style>",
        "",
    ]
)
# patch GPT2Attention to use flash_sdp, disable it when doing the inference
def _attn_wrapper(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
raise NotImplementedError("head_mask is not implemented for flash_sdp")
is_causal = attention_mask is None
with cuda.sdp_kernel(
enable_flash=True,
enable_math=False,
enable_mem_efficient=False,
):
attn_out = F.scaled_dot_product_attention(
query=query.half(),
key=key.half(),
value=value.half(),
is_causal=is_causal,
attn_mask=attention_mask,
dropout_p=self.attn_dropout.p,
).float()
return attn_out, None
def closest_power_of_2(x):
    """Return the smallest power of two that is greater than or equal to ``x``.

    Meaningful for integers ``x >= 1`` (1 -> 1, 3 -> 4, 8 -> 8, 9 -> 16).
    For ``x == 0`` the result is 2, not 1.
    """
    exponent = (x - 1).bit_length()
    return 1 << exponent
def make_model(pretrained_name, max_tokens):
    """Load a pretrained GPT-2 LM onto the GPU and stretch it to ``max_tokens`` context.

    Side effect: assigns ``_attn_wrapper`` to ``GPT2Attention._attn`` at class
    level, so every GPT-2 attention layer in the process uses the flash-SDP path.

    The pretrained position embeddings are tiled end-to-end until they cover
    ``max_tokens`` positions, and each block's causal-mask buffer is rebuilt as
    a ``max_tokens x max_tokens`` lower-triangular boolean matrix.

    Args:
        pretrained_name: Hugging Face model id or path (e.g. ``"gpt2-medium"``).
        max_tokens: new maximum sequence length.

    Returns:
        The patched ``GPT2LMHeadModel``, on CUDA.
    """
    model = GPT2LMHeadModel.from_pretrained(pretrained_name).cuda()
    GPT2Attention._attn = _attn_wrapper
    model.config.update(
        dict(
            n_ctx=max_tokens,
            n_positions=max_tokens,
        )
    )

    # Tile the pretrained position embeddings. Ceil-divide and trim so this also
    # works when max_tokens is not a multiple of the pretrained length; the old
    # floor division produced too few (or zero) rows in that case.
    old_weight = model.transformer.wpe.weight.data
    n_old, n_embd = old_weight.shape
    n_repeats = -(-max_tokens // n_old)
    wpe = nn.Embedding(
        max_tokens, n_embd, device=old_weight.device, dtype=old_weight.dtype
    )
    with torch.no_grad():
        wpe.weight.copy_(old_weight.repeat(n_repeats, 1)[:max_tokens])
    model.transformer.wpe = wpe

    # Enlarge each block's causal-mask buffer to the new context length,
    # building it directly on the model's device rather than CPU-then-copy.
    for block in model.transformer.h:
        block.attn.bias.data = torch.tril(
            torch.ones(
                (max_tokens, max_tokens), dtype=torch.bool, device=old_weight.device
            )
        ).view(1, 1, max_tokens, max_tokens)
    return model
class DatasetWrapper(IterableDataset):
    """Stream The Pile as fixed-length windows of GPT-2 token ids.

    Documents from a shuffled streaming split (10,000-example shuffle buffer)
    are tokenized, each followed by an EOS token, concatenated, and cut into
    consecutive non-overlapping 1-D tensors of exactly ``max_tokens`` ids.
    A trailing partial window is never emitted.
    """

    def __init__(self, max_tokens=2**12):
        if max_tokens < 1:
            # A non-positive window size would make __iter__ loop forever.
            raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")
        self.tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = max_tokens

    def __iter__(self):
        # NOTE(review): with DataLoader(num_workers > 1) every worker runs this
        # method independently. Whether workers see disjoint data depends on
        # whether the installed `datasets` version shards streaming datasets
        # per torch worker automatically — confirm, or samples may repeat.
        stream = load_dataset(
            "the_pile",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=10_000)
        window = self.max_tokens
        buffer = []
        for sample in stream:
            buffer.extend(self.tokenizer(sample["text"])["input_ids"])
            buffer.append(self.tokenizer.eos_token_id)
            # `>=` so a buffer holding exactly one full window is emitted now;
            # the old `>` dropped such a window if the stream ended there.
            start = 0
            while len(buffer) - start >= window:
                yield torch.tensor(buffer[start : start + window])
                start += window
            # Trim consumed tokens once per document instead of re-copying the
            # remaining buffer after every window.
            del buffer[:start]
class Trainer:
    """Fine-tune GPT-2 Medium on The Pile with an 8192-token context.

    Owns the streaming dataset and loader, the compiled model, the optimizer
    and the AMP grad scaler. Logs per-micro-batch loss and, every 1000
    micro-batches, generated samples to wandb. Gradient accumulation grows
    during training: after each optimizer step the number of micro-batches
    per step becomes ``max(1, closest_power_of_2(i + 1) // 32)``.
    """

    def __init__(self):
        self.max_tokens = 2**13  # context length (8192 tokens)
        self.grad = 1  # micro-batches accumulated per optimizer step
        self.step = 0  # 1-based index of the current micro-batch
        self.dataset = DatasetWrapper(self.max_tokens)
        self.tokenizer = self.dataset.tokenizer
        self.loader = DataLoader(
            self.dataset,
            batch_size=1,
            num_workers=8,
        )
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = model = make_model("gpt2-medium", self.max_tokens)
        # NOTE(review): betas=(0.9, 0.95) with weight_decay=0.1 are the usual GPT
        # *AdamW* settings; optim.Adam adds weight_decay to the gradient as a
        # coupled L2 term instead of decoupled decay — confirm Adam is intended.
        self.opt = optim.Adam(
            params=model.parameters(),
            lr=5e-6,
            weight_decay=1e-1,
            betas=(0.9, 0.95),
            fused=True,
        )
        self.model = torch.compile(model)

    def train_step(self, batch):
        """Run forward and scaled backward on one micro-batch (no optimizer step).

        The loss is divided by ``self.grad`` before backward so gradients
        accumulated over ``self.grad`` micro-batches are averaged, not summed.

        Returns:
            The unscaled loss as a detached tensor, for logging.
        """
        batch = batch.cuda()
        with torch.autocast(device_type="cuda", enabled=True):
            loss = self.model(batch, labels=batch).loss
        # Only backward sees the accumulation scaling. Previously the scaled value
        # was also returned and logged, so the logged loss shrank artificially
        # each time self.grad doubled. Backward runs outside autocast.
        self.scaler.scale(loss / self.grad).backward()
        return loss.detach()

    def generate_samples(self, n_samples=8):
        """Sample ``n_samples`` sequences from an EOS prompt and log them to wandb.

        Temporarily restores the stock ``GPT2Attention._attn`` and switches the
        model to eval mode. Both are undone in ``finally``, so an exception
        during generation cannot leave the model half-switched.
        """
        from html import escape

        GPT2Attention._attn = _attn_orig  # back to faster but more memory consuming
        model = self.model
        try:
            prompts = torch.tensor([[self.tokenizer.eos_token_id]] * n_samples).cuda()
            t0 = time.time()
            model.eval()
            outputs = model.generate(
                inputs=prompts,
                max_length=self.max_tokens,
                do_sample=True,
            ).tolist()
            t1 = time.time()
        finally:
            model.train()
            GPT2Attention._attn = _attn_wrapper
        # Generated text comes from a model trained on web data: escape it so
        # stray markup or scripts render as text instead of as live HTML.
        texts = [escape(self.tokenizer.decode(ids)) for ids in outputs]
        body = "<hr>".join(f"<p>{text}</p>" for text in texts)
        wandb.log({"samples": wandb.Html(WANDB_STYLE + body)}, step=self.step)
        print(f"Generated in {t1-t0:.3f}s")

    def train(self):
        """Run the training loop over ``self.loader`` until the stream ends."""
        wandb.init(
            project="long-gptx",
            entity="_",
        )
        prog = tqdm(self.loader)
        self.opt.zero_grad()
        for i, batch in enumerate(prog):
            self.step = i + 1
            loss = self.train_step(batch).item()  # single host sync per micro-batch
            prog.set_description(f"loss: {loss:.3f}")
            wandb.log(
                {
                    "loss": loss,
                    "grad": self.grad,
                },
                step=i,
            )
            # Step once self.grad micro-batches have accumulated, then re-derive
            # the accumulation size for the next window.
            if i % self.grad == 0:
                self.scaler.step(self.opt)
                self.scaler.update()
                self.opt.zero_grad()
                self.grad = max(1, closest_power_of_2(i + 1) // 32)
            # Checkpoint upload, currently disabled:
            # if i % 1000 == 0:
            #     with suppress(Exception):
            #         self.model.save_pretrained(
            #             "_",
            #             push_to_hub=True,
            #             max_shard_size="500MB",
            #         )
            if i % 1000 == 0:
                self.generate_samples(16)
if __name__ == "__main__":
    # Script entry point: construct the trainer (dataset, tokenizer, model,
    # optimizer) and run the training loop.
    trainer = Trainer()
    trainer.train()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment