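"""Minimal instruction fine-tuning of Llama-2-7B on teknium/GPT4-LLM-Cleaned,
with bfloat16 autocast, gradient accumulation, and Weights & Biases logging."""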
import argparse
import copy
import torch
import datasets as hfds
import transformers
from tqdm.auto import tqdm
import wandb
args = argparse.Namespace()
args.seed = 42
args.run_name = "run_0"
args.model_name = "meta-llama/Llama-2-7b-hf"
args.dataset = "teknium/GPT4-LLM-Cleaned"
args.eval_size = 1000
args.dtype = torch.bfloat16
# args.dtype = torch.float32
args.model_max_length = 512
args.train_batch_size = 8
args.gradient_accumulation_steps = 16
args.eval_batch_size = 16
args.eval_steps = 50
args.lr = 2e-5
args.num_epochs = 5
args.num_workers = 4
args.device = "cuda"
# args.device = "cpu"
use_autocast = args.dtype != torch.float32
print(f"use_autocast: {use_autocast}")
transformers.set_seed(args.seed)
# ======================== setup wandb ======================
run = wandb.init(
    project="minimal-finetuning",
    name=args.run_name,
    config=vars(args),
)
# ======================== setup dataset ======================
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss; masked positions do not contribute to the loss
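# Alpaca-style prompt templates; only the response (plus eos) is supervised, the prompt tokens are masked out in collate_fn below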
PROMPT_FORMAT = (
    """Below is an instruction that describes a task. """
    """Write a response that appropriately completes the request.\n\n"""
    """### Instruction:\n{instruction}\n\n### Response:"""
)
PROMPT_WITH_INPUT_FORMAT = (
    """Below is an instruction that describes a task, paired with an input that provides further context. """
    """Write a response that appropriately completes the request.\n\n"""
    """### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"""
)
TARGET_FORMAT = """{output}{eos_token}"""
def _preprocess_train_example(example, eos_token):
    if example['input']:
        prompt_text = PROMPT_WITH_INPUT_FORMAT.format_map(example)
    else:
        prompt_text = PROMPT_FORMAT.format_map(example)
    input_text = prompt_text + TARGET_FORMAT.format_map({**example, "eos_token": eos_token})
    return {"prompt_text": prompt_text, "input_text": input_text}
ds = hfds.load_dataset(args.dataset)['train']
ds = ds.train_test_split(args.eval_size, seed=args.seed)
ds = hfds.DatasetDict({'train': ds['train'], 'eval': ds['test']})
print(ds)
tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_name, model_max_length=args.model_max_length, use_fast=False)
if not tokenizer.pad_token:
    # the Llama tokenizer ships without a pad token; reuse eos for padding
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
ds = ds.map(_preprocess_train_example, batched=False, desc="preprocessing", num_proc=args.num_workers, fn_kwargs={"eos_token": tokenizer.eos_token})
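# after mapping, each example additionally carries "prompt_text" and "input_text" (prompt + output + eos_token)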
# ======================== setup dataloaders======================
def collate_fn(examples):
    # tokenize the prompt alone and the full prompt+response text with identical padding/truncation
    # settings, so the prompt length can be measured for label masking
    prompt_text_enc = tokenizer(
        [example["prompt_text"] for example in examples],
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    input_text_enc = tokenizer(
        [example["input_text"] for example in examples],
        return_tensors="pt",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
    )
    labels = copy.deepcopy(input_text_enc["input_ids"])
    for i in range(len(examples)):
        num_prompt_tokens = prompt_text_enc["input_ids"][i].ne(tokenizer.pad_token_id).sum()
        labels[i][:num_prompt_tokens] = IGNORE_INDEX  # ignore all tokens in the prompt
        labels[i][input_text_enc["attention_mask"][i] == 0] = IGNORE_INDEX  # ignore all pad tokens
    return {**input_text_enc, "labels": labels}
dl_train = torch.utils.data.DataLoader(ds['train'], batch_size=args.train_batch_size, collate_fn=collate_fn, num_workers=args.num_workers, shuffle=True, drop_last=True)
dl_eval = torch.utils.data.DataLoader(ds['eval'], batch_size=args.eval_batch_size, collate_fn=collate_fn, num_workers=args.num_workers)
# ======================== setup model ======================
model = transformers.LlamaForCausalLM.from_pretrained(
    args.model_name,
    torch_dtype=args.dtype,
    device_map='auto',
)
if use_autocast:
    # keep norm and embedding modules in float32 for numerical stability; autocast runs the rest in bf16
    for name, module in model.named_modules():
        if ('norm' in name) or ('embed' in name):
            # print(f'using float32: {name}')
            module.to(torch.float32)
# ======================== setup eval loop ======================
def eval():
    # uses the module-level `device` set below; eval() is only called after it is defined
    with torch.inference_mode():
        losses = []
        for batch in tqdm(dl_eval, desc="eval_step", leave=False, position=1):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=use_autocast, dtype=args.dtype):
                output = model(**batch)
                loss = output.loss
            losses.append(loss.item())
        eval_loss = torch.tensor(losses).mean().item()
    return eval_loss
# ======================== run train loop ======================
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
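# plain AdamW over all model parameters (full fine-tuning)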
device = args.device
total_train_steps = (len(dl_train) * args.num_epochs) // args.gradient_accumulation_steps
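# effective batch size per optimizer step: train_batch_size * gradient_accumulation_steps = 8 * 16 = 128 sequences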
step = 0
train_losses = {}
eval_losses = {}
loss_accum = 0.0
# do eval before training
eval_loss = eval()
eval_losses[step] = eval_loss
wandb.log({'eval/loss': eval_loss}, step=step)
# do training
with tqdm(total=total_train_steps, desc="steps", position=0) as pbar:
    for epoch in range(args.num_epochs):
        for i, batch in enumerate(dl_train):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.cuda.amp.autocast(enabled=use_autocast, dtype=args.dtype):
                output = model(**batch)
                loss = output.loss
            # scale the loss so gradients average over the accumulation window
            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            loss_accum += loss.item()
            # take an optimizer step once every gradient_accumulation_steps micro-batches
            if (i + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                step += 1
                pbar.update(1)
                pbar.write(f"step: {step:05d}\ttrain_loss: {loss_accum}")
                wandb.log({'train/loss': loss_accum}, step=step)
                train_losses[step] = loss_accum
                loss_accum = 0.0
                if step % args.eval_steps == 0:
                    eval_loss = eval()
                    eval_losses[step] = eval_loss
                    pbar.write(f"step: {step:05d}\teval_loss: {eval_loss}")
                    wandb.log({'eval/loss': eval_loss}, step=step)
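# optional (not in the original script): persist the fine-tuned weights and tokenizer with the standard
# transformers save_pretrained API; the "checkpoints/{run_name}" output path is an arbitrary choice
model.save_pretrained(f"checkpoints/{args.run_name}")
tokenizer.save_pretrained(f"checkpoints/{args.run_name}")
run.finish()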