@NaxAlpha
Last active January 23, 2024 02:28
Softformer - an attention-free, softmax-based transformer for causal language modeling.
# model.py
import torch
import torch.nn as nn
import torch.nn.functional as F


def cum_softmax(x, dim=1):  # <- main novelty
    z = x.exp()
    d = z.cumsum(dim)
    return z / d
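# For x of shape [batch, seq, dim] and dim=1, cum_softmax normalizes exp(x[t])
# by the running sum of exp(x[i]) over i <= t, independently per feature
# channel: a causal softmax that only sees the prefix. The first position
# therefore always maps to 1, and later positions are scaled down as the
# running sum grows.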
class SoftBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj = nn.Linear(dim, dim)
        self.soft = nn.Linear(dim, dim)

    def forward(self, x):
        # x: [batch, seq, dim]
        x = self.norm(x)
        p = self.proj(x)  # value-like projection
        s = self.soft(x)  # score-like projection
        s = cum_softmax(s, dim=1)  # causal, data-dependent gate per position
        y = p * s
        # the cumulative sum aggregates gated values over the prefix only,
        # giving an attention-free causal mixing step
        return y.cumsum(dim=1)
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, 2 * hidden_dim),
            nn.GLU(),  # gated linear unit halves the last dim back to hidden_dim
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, dim),
            nn.GELU(),
        )

    def forward(self, x):
        return self.net(x)
class Softformer(nn.Module):
    def __init__(self, dim, depth, hidden_dim=...):
        super().__init__()
        if hidden_dim is ...:  # Ellipsis means "not given": default to 2x width
            hidden_dim = dim * 2
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers += [
                SoftBlock(dim),
                FeedForward(dim, hidden_dim),
            ]

    def forward(self, x):
        # x: [batch, seq, dim]
        for layer in self.layers:
            x = x + layer(x)  # residual connection around every block
        return x
class SoftRegressor(nn.Module):
    def __init__(self, max_ctx, vocab_size, emb_dim, depth, hidden_dim=...):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, max_norm=1)
        self.pos = nn.Embedding(max_ctx, emb_dim, max_norm=1)
        self.net = Softformer(emb_dim, depth, hidden_dim)
        self.end = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.LayerNorm(emb_dim),
        )

    def forward(self, x, y=None):
        # x: [batch, seq], y: [batch, seq]
        _, seq = x.shape
        x = self.emb(x) + self.pos(torch.arange(seq, device=x.device))
        x = self.net(x)
        x = self.end(x)
        x = x @ self.emb.weight.t()  # weight-tied output projection to vocab logits
        if y is None:
            return x
        loss = F.cross_entropy(x.view(-1, x.shape[-1]), y.reshape(-1))
        return loss, x
    @torch.no_grad()
    def generate(self, x, max_len=100, temp=1.0):
        # x: [batch, seq]
        is_train = self.training
        self.eval()
        while x.shape[1] < max_len:
            y = self.forward(x)
            y = y[:, -1, :] / temp  # last-position logits, temperature-scaled
            y = y.softmax(dim=-1)
            y = torch.multinomial(y, 1)  # sample one token per batch element
            x = torch.cat([x, y], dim=1)
        self.train(is_train)
        return x
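

# Illustrative smoke test (not part of the original gist): a tiny configuration
# with made-up sizes, just to show the expected shapes and how generate() is
# called. Guarded so it does not run when the trainer imports this module.
if __name__ == "__main__":
    model = SoftRegressor(max_ctx=32, vocab_size=100, emb_dim=16, depth=2)
    tokens = torch.randint(0, 100, (2, 8))          # [batch=2, seq=8]
    logits = model(tokens)                          # [2, 8, 100]
    loss, _ = model(tokens[:, :-1], tokens[:, 1:])  # next-token loss
    sample = model.generate(tokens, max_len=16)     # [2, 16]
    print(logits.shape, loss.item(), sample.shape)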


# Training script (a separate file in the gist): imports SoftRegressor from the
# model.py code above.
import time
import random

import torch
from torch.utils.data import DataLoader

import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPT2TokenizerFast

from model import SoftRegressor
class Trainer:
    def __init__(self):
        self.dataset = load_dataset(
            "the_pile",
            name="all",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=1000)
        self.tokenizer: GPT2TokenizerFast = GPT2TokenizerFast.from_pretrained("gpt2")
        self.max_tokens = 512
        self.dataset = self.dataset.map(
            self.tokenize,
            batched=True,
            batch_size=64,
        )
        self.loader = DataLoader(
            self.dataset,
            batch_size=8,
            num_workers=8,
        )
        self.model = model = SoftRegressor(
            max_ctx=self.max_tokens,
            vocab_size=self.tokenizer.vocab_size,
            emb_dim=1024,
            depth=24,
        ).cuda()
        self.opt = torch.optim.Adam(
            model.parameters(),
            lr=1e-4,
            weight_decay=1e-2,
        )
        num_params = sum(p.numel() for p in model.parameters())
        emb_params = list(model.emb.parameters()) + list(model.pos.parameters())
        emb_params = sum(p.numel() for p in emb_params)
        non_emb_params = num_params - emb_params
        print(f"num params: {num_params}")
        print(f"emb params: {emb_params}")
        print(f"non emb params: {non_emb_params}")
    def tokenize(self, examples):
        N = len(examples["text"])
        out = self.tokenizer(examples["text"])
        # join all documents into one token stream, separated by eos
        res = []
        for inp in out["input_ids"]:
            res += inp + [self.tokenizer.eos_token_id]
        # sample N random windows of max_tokens + 1 tokens (inputs plus shifted targets)
        exp = []
        for i in range(N):
            j = random.randint(0, len(res) - self.max_tokens - 1)
            exp.append(res[j : j + self.max_tokens + 1])
        return {"input_ids": torch.tensor(exp)}
    def train(self):
        wandb.init(
            project="softformer",
            entity="nax-autify",
        )
        prog = tqdm(self.loader)
        for i, batch in enumerate(prog):
            batch = batch["input_ids"].cuda()
            self.opt.zero_grad()
            # teacher forcing: positions 0..n-1 predict positions 1..n
            loss, _ = self.model(batch[:, :-1], batch[:, 1:])
            loss.backward()
            self.opt.step()
            prog.set_description(f"loss: {loss.item():.3f}")
            wandb.log({"loss": loss.item()}, step=i)
            if i % 100 == 0:
                torch.save(self.model.state_dict(), "model.pt")
            if i % 1000 == 0:
                # sample a few unconditional generations and log them as HTML
                x = torch.tensor([[self.tokenizer.eos_token_id]] * 8).cuda()
                t0 = time.time()
                y = self.model.generate(x, max_len=self.max_tokens).tolist()
                t1 = time.time()
                t = [self.tokenizer.decode(z) for z in y]
                t = "<hr>".join(f"<p>{c}</p>" for c in t)
                html = (
                    """
                    <style>
                    html, body {
                        padding: 0;
                        margin: 0;
                        width: 100%;
                        height: 100%;
                    }
                    p {
                        font-family: 'Verdana', sans-serif;
                    }
                    </style>
                    """
                    + t
                )
                wandb.log({"samples": wandb.Html(html)}, step=i)
                print(f"Generated in {t1-t0:.3f}s")


if __name__ == "__main__":
    trainer = Trainer()
    trainer.train()
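
# To reproduce (assumed layout, not stated in the gist): save the first block
# as model.py and this script alongside it, make sure a CUDA GPU and a wandb
# login are available, then run the script with Python.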