@NaxAlpha
Created April 18, 2023 09:42
A custom GPT-like model that is tiny but can also scale to very long contexts.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cuda as cuda

class NewGELU(nn.Module):
    def forward(self, x):
        # https://github.com/karpathy/nanoGPT/blob/master/model.py#L19
        _sqrt_2_over_pi = math.sqrt(2.0 / math.pi)
        _x_pow_3 = torch.pow(x, 3.0)
        _tanh_inp = _sqrt_2_over_pi * (x + 0.044715 * _x_pow_3)
        return 0.5 * x * (1.0 + torch.tanh(_tanh_inp))

class LayerNorm(nn.Module):
    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class Attention(nn.Module):
    # multi-head causal self-attention
    def __init__(self, d_model, heads):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, N, D = x.shape
        H = self.heads
        dk = self.d_k
        # project and split heads: (B, N, D) -> (B, H, N, dk)
        q = self.q(x).reshape(B, N, H, dk).permute(0, 2, 1, 3)
        k = self.k(x).reshape(B, N, H, dk).permute(0, 2, 1, 3)
        v = self.v(x).reshape(B, N, H, dk).permute(0, 2, 1, 3)
        # disable the math fallback so the flash / memory-efficient SDPA kernels
        # are used, and run them in half precision to keep long contexts cheap
        with cuda.sdp_kernel(enable_math=False):
            a = F.scaled_dot_product_attention(
                q.half(), k.half(), v.half(), is_causal=True
            ).float()
        # merge heads back: (B, H, N, dk) -> (B, N, D)
        a = a.permute(0, 2, 1, 3).reshape(B, N, D)
        return self.fc(a)

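# --- Not part of the original gist: a reference sketch of what the fused kernel
# above computes. Attention.forward casts to half precision and disables the math
# backend, so it assumes a CUDA device with flash / memory-efficient SDPA; this
# plain-PyTorch equivalent runs anywhere and is handy for checking shapes.
def _causal_attention_reference(q, k, v):
    # q, k, v: (B, H, N, dk), the same layout as in Attention.forward
    dk = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(dk)  # (B, H, N, N)
    causal_mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(causal_mask, float("-inf"))
    return scores.softmax(dim=-1) @ v  # (B, H, N, dk)
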
def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    r"""
    Samples from the `Gumbel-Softmax distribution`_ and optionally discretizes.
    You can use this function to replace ``F.gumbel_softmax``.

    Args:
        logits: `[..., num_features]` unnormalized log probabilities
        tau: non-negative scalar temperature
        hard: if ``True``, the returned samples will be discretized as one-hot vectors,
              but will be differentiated as if it is the soft sample in autograd
        dim (int): A dimension along which softmax will be computed. Default: -1.

    Returns:
        Sampled tensor of same shape as `logits` from the Gumbel-Softmax distribution.
        If ``hard=True``, the returned samples will be one-hot, otherwise they will
        be probability distributions that sum to 1 across `dim`.

    .. note::
        This function is here for legacy reasons, may be removed from nn.Functional in the future.

    .. note::
        The main trick for `hard` is to do `y_hard - y_soft.detach() + y_soft`.
        It achieves two things:
        - makes the output value exactly one-hot
          (since we add then subtract y_soft value)
        - makes the gradient equal to y_soft gradient
          (since we strip all other gradients)

    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)

    .. _Gumbel-Softmax distribution:
        https://arxiv.org/abs/1611.00712
        https://arxiv.org/abs/1611.01144
    """

    def _gen_gumbels():
        gumbels = -torch.empty_like(logits).exponential_().log()
        if torch.isnan(gumbels).sum() or torch.isinf(gumbels).sum():
            # to avoid zero in exp output
            gumbels = _gen_gumbels()
        return gumbels

    gumbels = _gen_gumbels()  # ~Gumbel(0, 1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits, tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

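# --- Not part of the original gist: a tiny sketch of the straight-through trick
# described in the docstring above. The hard sample behaves like a one-hot vector
# in the forward pass, while gradients flow to `logits` through the soft sample.
def _straight_through_demo():
    logits = torch.randn(4, 8, requires_grad=True)
    y = gumbel_softmax(logits, tau=1, hard=True)
    assert torch.allclose(y.max(dim=-1).values, torch.ones(4))  # one entry per row is ~1
    assert torch.allclose(y.sum(dim=-1), torch.ones(4))         # ...and the rest are ~0
    loss = (y * torch.randn(8)).sum()  # any downstream loss
    loss.backward()
    assert logits.grad is not None  # gradients reach the logits via the soft sample
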
class Discretizer(nn.Module):
    # maps hidden states to a soft one-hot over a learned vocabulary and back
    # (defined here but not used by GPTa below)
    def __init__(self, d_model, num_tokens):
        super().__init__()
        self.vocb = nn.Embedding(num_tokens, d_model, max_norm=1.0)
        self.prep = nn.Linear(d_model, num_tokens)

    def forward(self, x):
        return gumbel_softmax(self.prep(x), tau=1, dim=-1) @ self.vocb.weight

class Dictionary(nn.Sequential):
    # feed-forward block with a softmax bottleneck: the hidden activation is a
    # probability distribution over `num_tokens` entries of a learned dictionary
    def __init__(self, d_model, num_tokens):
        super().__init__(
            nn.Linear(d_model, d_model),
            NewGELU(),
            LayerNorm(d_model, bias=False),
            nn.Linear(d_model, num_tokens),
            nn.Softmax(dim=-1),
            nn.Linear(num_tokens, d_model),
        )

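# --- Not part of the original gist: why this reads as a dictionary lookup. The
# final Linear sees a softmax distribution, so its output is a probability-weighted
# mix of the columns of its weight matrix (plus bias), i.e. a soft lookup into
# `num_tokens` learned vectors.
def _dictionary_lookup_equivalence(d_model=8, num_tokens=16):
    lut = nn.Linear(num_tokens, d_model)
    probs = torch.randn(2, 5, num_tokens).softmax(dim=-1)
    mix = probs @ lut.weight.T + lut.bias  # explicit weighted sum of columns
    assert torch.allclose(lut(probs), mix, atol=1e-6)
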
class Block(nn.Module):
    # pre-norm transformer block: attention and Dictionary feed-forward, both residual
    def __init__(self, d_model, heads, d_ff):
        super().__init__()
        self.norm1 = LayerNorm(d_model, bias=False)
        self.norm2 = LayerNorm(d_model, bias=False)
        self.dict = Dictionary(d_model, d_ff)
        self.attn = Attention(d_model, heads)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.dict(self.norm2(x))
        return x

class GPTa(nn.Module):
    def __init__(self, d_model, heads, d_ff, num_layers, num_tokens, max_seq_len):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_tokens = num_tokens
        self.max_seq_len = max_seq_len
        self.token_emb = nn.Embedding(num_tokens, d_model, max_norm=1.0)
        self.pos_emb = nn.Embedding(max_seq_len, d_model, max_norm=1.0)
        self.blocks = nn.ModuleList(
            [
                Block(
                    d_model,
                    heads,
                    d_ff,
                )
                for _ in range(num_layers)
            ]
        )
        self.last = nn.Sequential(
            LayerNorm(d_model, bias=False),
            nn.Linear(d_model, num_tokens),
        )

    def param_sum(self, include_embeddings=False):
        # total number of trainable parameters, optionally excluding the
        # token and position embedding tables
        params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        emb_params = list(self.token_emb.parameters()) + list(self.pos_emb.parameters())
        emb_params = sum(p.numel() for p in emb_params if p.requires_grad)
        return params - emb_params if not include_embeddings else params

    def forward(self, x):
        B, N = x.shape
        assert N <= self.max_seq_len, "Sequence length exceeds model capacity"
        token_emb = self.token_emb(x)
        pos_emb = self.pos_emb(torch.arange(N, device=x.device))
        x = token_emb + pos_emb
        for block in self.blocks:
            x = block(x)
        return self.last(x)

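# --- Not part of the original gist: the model ships without a sampling loop, so
# here is a minimal, hypothetical greedy-decoding sketch that just re-runs the
# forward pass and appends the argmax token each step.
@torch.no_grad()
def _greedy_decode(model, prompt_ids, steps):
    # prompt_ids: (1, N) LongTensor on the same device as the model
    x = prompt_ids
    for _ in range(steps):
        x = x[:, -model.max_seq_len:]  # stay within the position table
        logits = model(x)  # (1, N, num_tokens)
        next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
        x = torch.cat([x, next_id], dim=1)
    return x
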
def test():
    seq = 2**10
    model = GPTa(
        d_model=256,
        heads=8,
        d_ff=1024,
        num_layers=16,
        num_tokens=10000,
        max_seq_len=seq,
    ).cuda()
    print(model.param_sum())
    print(model.param_sum(include_embeddings=True))
    x = torch.randint(0, 10000, (1, seq), dtype=torch.long).cuda()
    logits = model(x)
    logits.sum().backward()
    print(logits.shape)

if __name__ == "__main__":
    test()
    input()

import torch
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset
import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer
from model import GPTa
WANDB_STYLE = """
<style>
html, body {
    padding: 0;
    margin: 0;
    width: 100%;
    height: 100%;
}
p {
    font-family: 'Verdana', sans-serif;
}
</style>
"""

class DatasetWrapper(IterableDataset):
    # streams the deduplicated Pile, tokenizes it, and packs the token stream
    # into fixed-length chunks of `max_tokens`
    # note: with num_workers > 1 each DataLoader worker streams the dataset independently
    def __init__(self, max_tokens):
        self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
        self.max_tokens = max_tokens

    def __iter__(self):
        buffer = []
        for sample in load_dataset(
            "EleutherAI/the_pile_deduplicated",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=10_000):
            buffer += self.tokenizer(sample["text"])["input_ids"]
            buffer += [self.tokenizer.eos_token_id]
            while len(buffer) > self.max_tokens:
                yield torch.tensor(buffer[: self.max_tokens])
                buffer = buffer[self.max_tokens :]

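# --- Not part of the original gist: each yielded chunk is exactly `max_tokens`
# long; the trainer below shifts it by one token, so the model sees
# 2**13 - 1 = 8191 positions, which fits max_seq_len = 2**13. A hypothetical
# sanity check that pulls a single chunk from the stream:
def _check_chunk_shape(max_tokens=2**13):
    chunk = next(iter(DatasetWrapper(max_tokens)))
    assert chunk.shape == (max_tokens,)
    x, y = chunk[:-1], chunk[1:]  # same shift as Trainer.train_step
    assert x.shape == y.shape == (max_tokens - 1,)
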
class Trainer:
    def __init__(self):
        self.max_tokens = 2**13  # tokens per packed training chunk (= max_seq_len)
        self.grad = 64  # gradient accumulation steps
        self.step = 0
        self.dataset = DatasetWrapper(self.max_tokens)
        self.tokenizer = self.dataset.tokenizer
        self.loader = DataLoader(
            self.dataset,
            batch_size=2,
            num_workers=8,
        )
        self.scaler = torch.cuda.amp.GradScaler()
        self.model = model = GPTa(
            d_model=256,
            heads=8,
            d_ff=1024,
            num_layers=16,
            num_tokens=50304,  # vocab size padded up to a multiple of 64
            max_seq_len=self.max_tokens,
        ).cuda()
        print("Params:", model.param_sum())
        print("Params (incl. embeddings):", model.param_sum(include_embeddings=True))
        self.opt = optim.AdamW(
            params=model.parameters(),
            lr=6e-4,
            weight_decay=1e-1,
            betas=(0.9, 0.95),
            fused=True,
        )
        self.model = torch.compile(model)

    def train_step(self, batch):
        batch = batch.cuda()
        # next-token prediction: inputs are all tokens but the last, targets are shifted by one
        x, y = batch[:, :-1], batch[:, 1:]
        with torch.autocast(device_type="cuda", enabled=True):
            z = self.model(x)
            y = y.reshape(-1)
            loss = F.cross_entropy(z.view(-1, z.shape[-1]), y)
        # divide by the accumulation factor so accumulated gradients average out
        self.scaler.scale(loss / self.grad).backward()
        return loss

    def train(self):
        wandb.init(
            project="gpt-a",
            entity="_",
        )
        prog = tqdm(self.loader)
        self.opt.zero_grad()

        for i, batch in enumerate(prog):
            self.step = i + 1
            loss = self.train_step(batch)
            prog.set_description(f"loss: {loss.item():.3f}")
            wandb.log(
                {
                    "loss": loss.item(),
                },
                step=i,
            )
            # apply the accumulated gradients every `self.grad` batches
            if i % self.grad == 0:
                self.scaler.step(self.opt)
                self.scaler.update()
                self.opt.zero_grad()
            # periodically checkpoint the model weights
            if i % 500 == 0:
                torch.save(self.model.state_dict(), "model.pt")

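# --- Not part of the original gist: a quick worked number for the schedule above.
# With batch_size=2 sequences per forward pass, gradients accumulated over
# grad=64 passes, and 2**13 - 1 = 8191 positions after the one-token shift,
# each optimizer update sees 2 * 64 * 8191 = 1,048,448 tokens.
def _tokens_per_update(batch_size=2, grad_accum=64, seq_len=2**13 - 1):
    return batch_size * grad_accum * seq_len  # 1,048,448
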
if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()