Skip to content

Instantly share code, notes, and snippets.

@norabelrose
Created November 8, 2023 07:04
Show Gist options
  • Save norabelrose/bd63da49a301f367c365962cda9e385b to your computer and use it in GitHub Desktop.
Save norabelrose/bd63da49a301f367c365962cda9e385b to your computer and use it in GitHub Desktop.
Training quirky models with DPO
from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Optional, Sequence

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer
def parse_args(argv: Optional[Sequence[str]] = None) -> Namespace:
    """Parse the command line for a DPO training run.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is what running the script does.

    Returns:
        The parsed arguments. Every option beyond ``name``, ``--dataset``,
        ``--lora-modules``, ``--lora-rank`` and ``--model`` defaults to the
        value that used to be hard-coded, so existing invocations are
        unaffected.
    """
    parser = ArgumentParser(description="Train a quirky model with DPO.")
    parser.add_argument(
        "name",
        type=str,
        help="Run name; checkpoints are written to checkpoints/<name>.",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="atmallen/qm_mixture_1.0e_0.5p_finetuning",
        help="Dataset to load; its 'train' and 'validation' splits are used.",
    )
    parser.add_argument(
        "--lora-modules",
        type=str,
        nargs="+",
        default=["gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj"],
        help="Module names that receive LoRA adapters.",
    )
    parser.add_argument(
        "--lora-rank",
        type=int,
        default=8,
        help="LoRA rank; a value <= 0 disables LoRA (no peft_config is passed).",
    )
    parser.add_argument("--model", type=str, default="mistralai/Mistral-7B-v0.1")

    # Formerly hard-coded training hyperparameters; defaults are unchanged.
    parser.add_argument(
        "--batch-size", type=int, default=5, help="Per-device train batch size."
    )
    parser.add_argument(
        "--grad-accum", type=int, default=4, help="Gradient accumulation steps."
    )
    parser.add_argument(
        "--epochs", type=float, default=1.0, help="Number of training epochs."
    )
    parser.add_argument(
        "--warmup-steps", type=int, default=500, help="Learning-rate warmup steps."
    )
    parser.add_argument(
        "--max-length", type=int, default=512, help="Passed to DPOTrainer as max_length."
    )
    parser.add_argument(
        "--max-prompt-length",
        type=int,
        default=128,
        help="Passed to DPOTrainer as max_prompt_length.",
    )
    parser.add_argument(
        "--precision",
        choices=("fp16", "bf16", "fp32"),
        default="fp16",
        help=(
            "Mixed-precision mode for TrainingArguments. fp16 matches the "
            "previous behaviour; bf16 may suit models whose weights load as "
            "bfloat16 under torch_dtype='auto'."
        ),
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Seed used to shuffle every split."
    )
    return parser.parse_args(argv)


def to_preference_pair(example: Dict[str, Any]) -> Dict[str, Any]:
    """Turn one labelled two-choice row into DPO's chosen/rejected pair.

    Args:
        example: A dataset row holding a ``choices`` sequence and an integer
            ``label`` that indexes the preferred choice.

    Returns:
        ``{"chosen": choices[label], "rejected": choices[1 - label]}``.

    Raises:
        ValueError: If there are not exactly two choices or ``label`` is not
            0 or 1. Without this check, ``1 - label`` would silently select a
            wrong rejected answer (e.g. index -1 for ``label == 2``).
    """
    choices = example["choices"]
    label = example["label"]
    if len(choices) != 2 or label not in (0, 1):
        raise ValueError(
            "expected exactly two choices and a 0/1 label, got "
            f"{len(choices)} choice(s) and label={label!r}"
        )
    return {"chosen": choices[label], "rejected": choices[1 - label]}


def load_preference_dataset(name: str, seed: int):
    """Load dataset *name* and reshape every split into DPO rows.

    ``statement`` is renamed to ``prompt``; ``choices``/``label`` are replaced
    by ``chosen``/``rejected`` via :func:`to_preference_pair`; ``true_label``
    is dropped. Every split is then shuffled with *seed*.
    """
    return (
        load_dataset(name)
        .rename_column("statement", "prompt")
        .map(to_preference_pair, remove_columns=["choices", "label", "true_label"])
        .shuffle(seeds=seed)
    )


def main(argv: Optional[Sequence[str]] = None) -> None:
    """Fine-tune ``--model`` on ``--dataset`` with DPO, optionally via LoRA.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    # Pad with the EOS token. NOTE(review): presumably needed because the
    # default model's tokenizer defines no pad token — confirm for others.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    ds = load_preference_dataset(args.dataset, args.seed)

    # Attach LoRA adapters only for a positive rank; otherwise the model is
    # handed to DPOTrainer without a peft_config.
    peft_config = (
        LoraConfig(  # type: ignore
            r=args.lora_rank, target_modules=args.lora_modules
        )
        if args.lora_rank > 0
        else None
    )

    trainer = DPOTrainer(
        model=AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto"),
        args=TrainingArguments(
            f"checkpoints/{args.name}",
            fp16=args.precision == "fp16",
            bf16=args.precision == "bf16",
            gradient_accumulation_steps=args.grad_accum,
            logging_steps=50,
            num_train_epochs=args.epochs,
            per_device_train_batch_size=args.batch_size,
            # Keep the raw prompt/chosen/rejected text columns; they are not
            # model forward() arguments, so the Trainer would otherwise drop them.
            remove_unused_columns=False,
            run_name=args.name,
            warmup_steps=args.warmup_steps,
            weight_decay=0.1,
            # NOTE(review): no evaluation strategy is set, so eval_dataset may
            # never be evaluated during training — confirm this is intended.
        ),
        max_length=args.max_length,
        max_prompt_length=args.max_prompt_length,
        peft_config=peft_config,
        train_dataset=ds["train"],
        eval_dataset=ds["validation"],
        tokenizer=tokenizer,
    )
    trainer.train()


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment