@NaxAlpha
Last active August 4, 2023 04:47
Train a semantic text compressor, potentially useful for very long context language modeling
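The script below attaches low-rank (LoRA) adapters to a frozen Pythia-1.4B model and trains them, together with two learned marker embeddings, to squeeze a short crop of text (up to roughly 30 tokens) into the single hidden state produced at a "compress" marker position. A second "decode" marker then prompts the same frozen model to reconstruct the original tokens from that one vector, and the cross-entropy of the reconstruction is the training loss. Data is streamed from the deduplicated Pile, and losses plus side-by-side original/reconstruction samples are logged to Weights & Biases.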
import random
from time import sleep
from functools import partial
from threading import Thread, Lock
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn
from torch.utils.data import DataLoader, IterableDataset
import wandb
from tqdm import tqdm
from datasets import load_dataset
from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXAttention
MODEL_NAME = "EleutherAI/pythia-1.4b-deduped-v0"
WANDB_STYLE = """
<style>
html, body {
padding: 0;
margin: 0;
width: 100%;
height: 100%;
}
p {
font-family: 'Verdana', sans-serif;
}
table {
border-collapse: collapse;
width: 100%;
}
td, th {
border: 1px solid #999999;
text-align: left;
padding: 8px;
}
tr:nth-child(even) {
background-color: #eeeeee;
}
pre {
white-space: pre-wrap;
}
</style>
"""
class LoRALinear(nn.Module):
def __init__(self, finp, fout, r=4):
super().__init__()
self.finp = finp
self.fout = fout
self.r = r
self.fc1 = nn.Linear(finp, r)
self.fc2 = nn.Linear(r, fout)
self.fc2.weight.data.zero_()
self.fc2.bias.data.zero_()
def forward(self, x):
return self.fc2(self.fc1(x))
class LoRAWrapper(nn.Module):
def __init__(self, main, lora):
super().__init__()
self.main = main
self.lora = lora
def forward(self, x):
return self.main(x) + self.lora(x)
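# A minimal sanity-check sketch (not part of the original gist, never called):
# because fc2 in LoRALinear is zero-initialised, a freshly wrapped layer behaves
# exactly like the frozen layer it wraps, and training can only move the output
# through the low-rank path.
def _lora_identity_check():
    base = nn.Linear(8, 8)
    wrapped = LoRAWrapper(base, LoRALinear(8, 8, r=2))
    x = torch.randn(2, 8)
    assert torch.allclose(wrapped(x), base(x))  # zero-init LoRA adds nothing yet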
class Conceptor(nn.Module):
def __init__(self):
super().__init__()
self.base_model = GPTNeoXForCausalLM.from_pretrained(
MODEL_NAME,
)
self.hh = hh = self.base_model.config.hidden_size
self.embeddings = nn.Embedding(2, hh)
self.base_model.requires_grad_(False)
self.loras = self.make_lora()
def save(self, path):
sd = dict(
loras=self.loras.state_dict(),
embeddings=self.embeddings.state_dict(),
)
torch.save(sd, path)
def make_lora(self):
layers = []
for module in self.base_model.modules():
if isinstance(module, GPTNeoXAttention):
lora = LoRALinear(self.hh, 3 * self.hh)
layers.append(lora)
module.query_key_value = LoRAWrapper(module.query_key_value, lora)
lora = LoRALinear(self.hh, self.hh)
layers.append(lora)
module.dense = LoRAWrapper(module.dense, lora)
return nn.ModuleList(layers)
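    # _encode runs the frozen base model over the crop's token embeddings after
    # writing the learned "compress" marker (self.embeddings[0]) into position
    # sizes[b] of each sequence; the final hidden state at that position is the
    # single-vector summary of the preceding tokens.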
def _encode(self, tokens, sizes):
B, T = tokens.size()
indices = torch.arange(B, device=tokens.device)
embeddings = self.base_model.gpt_neox.embed_in(tokens)
# replace size'th token with self.embeddings(0)
embeddings[indices, sizes] = self.embeddings.weight[0]
output = self.base_model.gpt_neox(inputs_embeds=embeddings)
# take size'th token from each sequence
context = output.last_hidden_state[indices, sizes]
return context, embeddings
def forward(self, tokens, sizes):
B, T = tokens.size()
context, embeddings = self._encode(tokens, sizes)
# combine context and self.embeddings(1) and tokens
emb_token = self.embeddings.weight[1].unsqueeze(0).expand(B, -1)
context = torch.cat(
[context[:, None], emb_token[:, None], embeddings[:, :-2]], dim=1
)
# compute logits
logits = self.base_model(inputs_embeds=context).logits
logits = logits[:, 1:].reshape(-1, logits.size(-1))
# compute loss
targets = tokens[:, :-1].contiguous().view(-1)
loss = F.cross_entropy(logits, targets, reduction="none")
loss = loss.reshape(B, -1)
loss_mask = torch.arange(T - 1, device=tokens.device) < sizes[:, None]
loss = loss.masked_fill(~loss_mask, 0.0)
loss = loss.sum(dim=1) / sizes.float()
loss = loss.mean()
return loss
def encode(self, tokens, sizes):
return self._encode(tokens, sizes)[0]
@torch.no_grad()
def sample(self, context, max_tokens=128, temperature=1.0):
B = context.size(0)
emb_token = self.embeddings.weight[1].unsqueeze(0).expand(B, -1)
context = torch.cat([context[:, None], emb_token[:, None]], dim=1)
output_tokens = []
for _ in range(max_tokens):
logits = self.base_model(inputs_embeds=context).logits
logits = logits[:, -1, :] / temperature
token = torch.multinomial(logits.softmax(dim=-1), num_samples=1)
output_tokens.append(token)
token_emb = self.base_model.gpt_neox.embed_in(token)
context = torch.cat([context, token_emb], dim=1)
return torch.cat(output_tokens, dim=-1)
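# Usage sketch (not part of the original gist, and never called here): compress a
# short piece of text into one hidden-state vector and sample a reconstruction
# from it. Restoring trained LoRA and marker-embedding weights first (the inverse
# of Conceptor.save) is assumed; with only the frozen base model the output is noise.
@torch.no_grad()
def compress_and_reconstruct(model, tokenizer, text, device="cuda"):
    crop = tokenizer.encode(text)[:30] + [tokenizer.eos_token_id] * 2
    tokens = torch.tensor([crop], device=device)
    sizes = torch.tensor([tokens.size(1) - 1], device=device)
    memory = model.encode(tokens, sizes)  # (1, hidden_size) summary vector
    out = model.sample(memory, max_tokens=64, temperature=0.1)
    return tokenizer.decode(out[0].tolist())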
class DatasetWrapper(IterableDataset):
def __init__(self, min_tokens=1, max_tokens=32):
self.tokenizer = GPTNeoXTokenizerFast.from_pretrained(MODEL_NAME)
self.min_tokens = min_tokens
self.max_tokens = max_tokens
self._buffer = []
self._min_buffer = 10_000
self._max_buffer = 20_000
self._lock = None
self._thread = None
def _worker(self):
temp_buffer = []
for sample in load_dataset(
"EleutherAI/the_pile_deduplicated",
split="train",
streaming=True,
).shuffle(buffer_size=1000):
text = sample["text"] + "<|endofdoc|>"
tokens = self.tokenizer.encode(text)
temp_buffer.extend(tokens)
# crop into chunks
while len(temp_buffer) >= self.max_tokens:
size = random.randrange(self.min_tokens, self.max_tokens - 1)
crop = temp_buffer[:size] + [self.tokenizer.eos_token_id] * 2
with self._lock:
self._buffer.append(torch.tensor(crop))
temp_buffer = temp_buffer[size:]
sleep(0.001)
# wait for buffer to drain
while len(self._buffer) >= self._max_buffer:
sleep(0.1)
def __iter__(self):
self._lock = Lock()
self._thread = Thread(target=self._worker, daemon=True)
self._thread.start()
while True:
while len(self._buffer) < self._min_buffer:
sleep(0.1)
with self._lock:
idx = random.randrange(len(self._buffer))
sample = self._buffer.pop(idx)
yield sample
def dl_collate_fn(batch, pad_token_id):
lengths = [t.size(0) for t in batch]
tokens = rnn.pad_sequence(
batch,
batch_first=True,
padding_value=pad_token_id,
)
return tokens, torch.tensor(lengths) - 1
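# Illustration (not part of the original gist): each crop ends in two EOS ids, the
# collate function right-pads to the longest crop in the batch, and the returned
# sizes are lengths - 1, i.e. the index of the final EOS, which Conceptor._encode
# later overwrites with its "compress" marker.
def _collate_example(eos=0):
    batch = [
        torch.tensor([11, 12, 13, eos, eos]),          # 3 content tokens
        torch.tensor([21, 22, 23, 24, 25, eos, eos]),  # 5 content tokens
    ]
    tokens, sizes = dl_collate_fn(batch, pad_token_id=eos)
    # tokens.shape == (2, 7); sizes == tensor([4, 6])
    return tokens, sizes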
class Trainer:
def __init__(self):
self.dataset = DatasetWrapper()
self.loader = DataLoader(
self.dataset,
batch_size=32,
num_workers=8,
collate_fn=partial(
dl_collate_fn,
pad_token_id=self.dataset.tokenizer.eos_token_id,
),
)
self.model = model = Conceptor().cuda()
print("Model parameters:", sum(p.numel() for p in model.parameters()))
print(
"Trainable parameters:",
sum(p.numel() for p in model.parameters() if p.requires_grad),
)
self.opt = optim.Adam(
params=model.parameters(),
lr=6e-5,
fused=True,
)
# self.model = torch.compile(model)
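        # The base model was frozen above, so only the LoRA adapters and the two
        # marker embeddings actually receive gradients; the remaining parameters
        # handed to Adam are never updated.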
def train_step(self, tokens, lengths):
self.opt.zero_grad()
loss = self.model(tokens.cuda(), lengths.cuda())
loss.backward()
self.opt.step()
return loss
def _detokenize(self, tokens):
eos = self.dataset.tokenizer.eos_token_id
output = []
for tkn in tokens:
# stop at first EOS in token list
idx = tkn.index(eos) if eos in tkn else len(tkn)
text = self.dataset.tokenizer.decode(tkn[:idx])
output.append(text)
return output
def generate(self, tokens, sizes):
self.model.eval()
mem = self.model.encode(tokens, sizes)
out = self.model.sample(mem, tokens.size(1), temperature=0.1).tolist()
self.model.train()
original = self._detokenize(tokens.tolist())
generated = self._detokenize(out)
table = "<table><tr><th>Original</th><th>Generated</th></tr>"
for o, g in zip(original, generated):
table += f"<tr><td><pre>{o}</pre></td><td><pre>{g}</pre></td></tr>"
table += "</table>"
return table
def train(self):
wandb.init(
project="conceptor",
entity="_",
)
        # Resume from an earlier checkpoint; Conceptor.save() stores only the LoRA
        # and marker-embedding weights, so load those sub-modules explicitly.
        sd = torch.load("model-v3.pt")
        self.model.loras.load_state_dict(sd["loras"])
        self.model.embeddings.load_state_dict(sd["embeddings"])
        del sd
prog = tqdm(self.loader)
for i, (tokens, lengths) in enumerate(prog):
loss = self.train_step(tokens, lengths)
prog.set_description(f"loss: {loss.item():.3f}")
wandb.log(
{
"loss": loss.item(),
"size": lengths.float().mean().item(),
"avgb": tokens.size(1),
},
step=i,
)
if i % 200 == 0:
table = self.generate(tokens.cuda(), lengths.cuda())
wandb.log(dict(diff=wandb.Html(WANDB_STYLE + table)), step=i)
self.model.save("model.pt")
if __name__ == "__main__":
trainer = Trainer()
trainer.train()