import os
import json
import random
import textwrap
import re
import math
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset
from torch.optim.adamw import AdamW
from torch.optim.sgd import SGD
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import datasets
from tqdm import tqdm

PRECISION = torch.bfloat16
USERNAME = "exampleuser"


class LoRA(nn.Module):
    # Low-rank adapter: the effective weight is (A @ B) * (alpha / r). B starts
    # at zero, so training begins with the adapter contributing nothing.
    def __init__(self, dim, dim_out, r=8, alpha=None):
        super().__init__()
        alpha = alpha if alpha is not None else r
        self.scale = alpha / r
        self.A = nn.Parameter(
            nn.init.kaiming_uniform_(torch.randn(dim, r).to(PRECISION), a=math.sqrt(5))
        )
        self.B = nn.Parameter(torch.zeros(r, dim_out).to(PRECISION))

    @property
    def weight(self):
        return (self.A @ self.B) * self.scale

    def forward(self, x):
        return x @ self.weight
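

# Rough shape-check sketch (not executed by this script): the adapter maps a
# hidden-size-wide activation through a rank-r bottleneck and back, so its
# output matches the activation it is added to. The 2048 below assumes
# pythia-1.4b-deduped's hidden size; other checkpoints differ.
#
#   lora = LoRA(dim=2048, dim_out=2048, r=8, alpha=16.0)
#   x = torch.randn(4, 32, 2048, dtype=PRECISION)
#   lora(x).shape  # torch.Size([4, 32, 2048])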


class LoRAForward(nn.Module):
    # Wraps a whole transformer block: the LoRA projection of the block's input
    # is added to the block's hidden-state output, and any extra outputs
    # (attention weights, cached key/values) are passed through unchanged.
    def __init__(self, original_layer, lora):
        super().__init__()
        self.original_layer = original_layer
        self.lora = lora

    def forward(self, x, *args, **kwargs):
        output = self.original_layer(x, *args, **kwargs)
        prev_output = output[0]
        lora_output = self.lora(x).view(prev_output.shape)
        merged_output = prev_output + lora_output
        return merged_output, *output[1:]
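

# Manual wrapping sketch (insert_lora_layers below does this for every block):
# swap one GPT-NeoX block for a LoRAForward that shares its input. Note this
# adapts the whole block's input-to-output mapping, rather than the more common
# per-projection (query/key/value) LoRA placement.
#
#   block = model.gpt_neox.layers[0]
#   adapter = LoRA(dim=config.hidden_size, dim_out=config.hidden_size, r=8)
#   model.gpt_neox.layers[0] = LoRAForward(block, adapter)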


class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=280):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        input_text = self.texts[idx]
        tokenized = self.tokenizer(
            input_text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        return tokenized


class ThePileDataset(IterableDataset):
    def __init__(self, tokenizer, max_length=280, buffer_size=10_000, seed=42):
        self.the_pile = datasets.load_dataset(
            "EleutherAI/the_pile_deduplicated",
            split="train",
            streaming=True,
        ).shuffle(buffer_size=buffer_size, seed=seed)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # return len(self.the_pile)
        return 134318121

    def __iter__(self):
        for row in self.the_pile:
            yield self.tokenizer(
                row["text"],
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt",
            )


def insert_lora_layers(model, config, r=8, alpha=None):
    for i in range(config.num_hidden_layers):
        layer = model.gpt_neox.layers[i]
        # Get input and output dimensions for the LoRA layer
        dim_in = config.hidden_size
        dim_out = config.hidden_size
        lora = LoRA(dim=dim_in, dim_out=dim_out, r=r, alpha=alpha)
        # Replace the existing layer with a LoRAForward container holding the
        # original layer and the new LoRA layer
        model.gpt_neox.layers[i] = LoRAForward(layer, lora)
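

# Sanity-check sketch (assumes freeze_model below has already been applied to
# the base model): after insertion, only the adapters' A and B matrices should
# require gradients, i.e. num_layers * 2 * hidden_size * r parameters.
#
#   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   print(trainable)  # e.g. 24 * 2 * 2048 * 8 = 786432 for pythia-1.4b, r=8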


def extract_lora_weights(model):
    lora_weights = {}
    for idx, layer in enumerate(model.gpt_neox.layers):
        if isinstance(layer, LoRAForward):
            lora_weights[f"lora_{idx}_A"] = layer.lora.A
            lora_weights[f"lora_{idx}_B"] = layer.lora.B
    return lora_weights


def load_lora_weights(model, lora_weights):
    for idx, layer in enumerate(model.gpt_neox.layers):
        if isinstance(layer, LoRAForward):
            layer.lora.A = lora_weights[f"lora_{idx}_A"]
            layer.lora.B = lora_weights[f"lora_{idx}_B"]
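

# Checkpoint round-trip sketch: only the adapter matrices are persisted, so the
# files stay small relative to the frozen base model. The filename below is
# illustrative.
#
#   torch.save(extract_lora_weights(model), "lora_only.pt")
#   load_lora_weights(model, torch.load("lora_only.pt"))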


def freeze_model(model):
    for param in model.parameters():
        param.requires_grad = False


def generate_sample(
    model,
    tokenizer,
    prompt,
    min_length=50,
    max_length=280,
    top_k=50,
    top_p=0.95,
    temperature=0.8,
):
    with torch.no_grad():
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
        generated = model.generate(
            input_ids,
            max_length=max_length,
            min_length=min_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            pad_token_id=tokenizer.eos_token_id,
        )
        resp = tokenizer.decode(generated[0], skip_special_tokens=True)
        return resp


def print_demo_text(model, tokenizer):
    # Generate three samples, drawing a distinct temperature for each prompt.
    temperatures = [0.7, 0.8, 0.9]
    prompts = [
        f"{USERNAME}: I just",
        f"{USERNAME}: I think",
        f"{USERNAME}: When it comes to",
    ]
    for i, prompt in enumerate(prompts, start=1):
        temperature = random.choice(temperatures)
        temperatures.remove(temperature)
        generated_text = generate_sample(
            model, tokenizer, prompt, temperature=temperature
        )
        print(
            f"Generated text {i} [{temperature}]:\n"
            f"{textwrap.indent(textwrap.fill(generated_text), '    ')}"
        )
    print("-----")


if __name__ == "__main__":
    # Load the pretrained model and config
    model_name = "EleutherAI/pythia-1.4b-deduped"
    # model_name = "EleutherAI/pythia-160m-deduped"
    # model_name = "EleutherAI/pythia-70m-deduped"
    config = AutoConfig.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, config=config, torch_dtype=PRECISION
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    should_train = True

    # Freeze the pretrained model parameters
    freeze_model(model)

    # Insert LoRA layers
    rank = 8
    lora_alpha = 16.0
    insert_lora_layers(model, config, r=rank, alpha=lora_alpha)

    device = torch.device("cuda")
    # device = torch.device("mps")
    model = model.to(device)

    batch_accumulation = 48
    # lr = 8e-4
    # lr = 0.0064
    lr = 2e-4
    # lr = 1e-4
    # lr = 5e-5
    # lr = 1e-5
    start_epoch, num_epochs = 0, 200
    batch_size = 20

    model_save_dir = (
        f"model_checkpoints_{re.sub('[^a-zA-Z0-9]+', '_', model_name.split('/')[-1])}"
    )
    os.makedirs(model_save_dir, exist_ok=True)

    # Resume from the most recent LoRA checkpoint, if any
    latest_model = (
        sorted(os.listdir(model_save_dir), reverse=True)[0]
        if os.listdir(model_save_dir)
        else None
    )
    if latest_model:
        model_path = os.path.join(model_save_dir, latest_model)
        start_epoch = (
            int(os.path.splitext(os.path.basename(model_path))[0].split("_")[-2]) + 1
        )
        lora_weights = torch.load(model_path)
        load_lora_weights(model, lora_weights)
        print(f"Loaded LoRA weights from {model_path}; resuming at epoch {start_epoch}")
        print_demo_text(model, tokenizer)

    if should_train:
        with open("tweets.json", "r", encoding="utf-8") as f:
            tweets = [f'{USERNAME}: {t["tweet"]["full_text"]}' for t in json.load(f)]
        random.seed(42)
        random.shuffle(tweets)
        split_idx = int(len(tweets) * 0.9)
        # Keep a 10% held-out split; it is not evaluated in this script.
        texts, tests = tweets[:split_idx], tweets[split_idx:]

        dataset = TextDataset(texts, tokenizer)
        # dataset = ThePileDataset(tokenizer)
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=not isinstance(dataset, IterableDataset),
        )

        optimizer = AdamW(model.parameters(), lr=lr)
        # optimizer = SGD(model.parameters(), lr=lr)

        model.train()
        accum_loss = 0.0
        accum_count = 0
        prev_loss = 0.0
        n_step = 0
        for epoch in range(start_epoch, num_epochs):
            # Wrap the dataloader with tqdm for a progress bar
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}")
            for i, batch in enumerate(progress_bar):
                n_step += 1
                # Tokenizer output is [batch, 1, seq_len]; squeeze(1) keeps the
                # batch dimension even if the final batch holds a single sample.
                input_ids = batch["input_ids"].squeeze(1).to(device)
                attention_mask = batch["attention_mask"].squeeze(1).to(device)
                outputs = model(
                    input_ids=input_ids, attention_mask=attention_mask, labels=input_ids
                )
                loss = outputs.loss / batch_accumulation
                loss.backward()
                accum_loss += loss.item()
                accum_count += 1
                if ((i + 1) % batch_accumulation == 0) or (i + 1 == len(dataloader)):
                    optimizer.step()
                    optimizer.zero_grad()
                    prev_loss = accum_loss
                    accum_loss = 0.0
                    accum_count = 0
                    # print_demo_text(model, tokenizer)
                    # model_save_path = os.path.join(
                    #     model_save_dir,
                    #     f"epoch_{str(epoch).zfill(4)}_{str(n_step).zfill(10)}.pt",
                    # )
                    # torch.save(extract_lora_weights(model), model_save_path)
                    # print(f"LoRA weights saved at: {model_save_path}")
                progress_bar.set_postfix(
                    {
                        "Prev Loss": prev_loss,
                        "Loss": (
                            (accum_loss * batch_accumulation)
                            / max(float(accum_count), 1.0)
                        ),
                        "Accum": (batch_accumulation - ((i + 1) % batch_accumulation)),
                    }
                )
            # Save the LoRA weights after each epoch
            model_save_path = os.path.join(
                model_save_dir,
                f"epoch_{str(epoch).zfill(4)}_{str(n_step).zfill(10)}.pt",
            )
            lora_weights = extract_lora_weights(model)
            torch.save(lora_weights, model_save_path)
            print(f"LoRA weights saved at: {model_save_path}")
            print_demo_text(model, tokenizer)
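

# Expected on-disk layout when run as-is (illustrative; assumes tweets.json is
# a Twitter archive export whose records look like {"tweet": {"full_text": ...}},
# which is the shape the loading code above expects):
#
#   tweets.json                                 training text, one record per tweet
#   model_checkpoints_pythia_1_4b_deduped/      one LoRA-only checkpoint per epoch,
#       epoch_####_##########.pt                named epoch_<epoch:04d>_<n_step:010d>.pt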