@152334H
Last active December 27, 2023 23:24
NeuralHermes with Unsloth LoRA (DPO). Requires the base model to be hacked from Mistral -> Llama so that FastLlamaModel will load it.
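The gist does not show the hack itself. Below is a minimal sketch of one common way to do it, assuming a local copy of the checkpoint whose config.json is rewritten to advertise the Llama architecture (Mistral-7B is Llama-shaped apart from sliding-window attention); the helper name is hypothetical and not part of the original script.

import json
from pathlib import Path

def hack_mistral_config_to_llama(model_dir: str) -> None:
    # Rewrite config.json so the checkpoint reports itself as a Llama model.
    cfg_path = Path(model_dir) / "config.json"
    cfg = json.loads(cfg_path.read_text())
    cfg["model_type"] = "llama"                    # was "mistral"
    cfg["architectures"] = ["LlamaForCausalLM"]    # was ["MistralForCausalLM"]
    cfg.pop("sliding_window", None)                # Llama configs have no sliding window
    cfg_path.write_text(json.dumps(cfg, indent=2))

# hack_mistral_config_to_llama("./OpenHermes-2.5-Mistral-7B")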
import torch
from unsloth import FastLlamaModel
from transformers import TrainingArguments
from datasets import load_dataset
from trl import DPOTrainer
model_name = "teknium/OpenHermes-2.5-Mistral-7B"
model_name = "./OpenHermes-2.5-Mistral-7B"
new_model = "NeuralHermes-2.5-Mistral-7B"
def chatml_format(example):
    # Format system message (may be empty)
    if len(example['system']) > 0:
        message = {"role": "system", "content": example['system']}
        system = tokenizer.apply_chat_template([message], tokenize=False)
    else:
        system = ""

    # Format instruction
    message = {"role": "user", "content": example['question']}
    prompt = tokenizer.apply_chat_template([message], tokenize=False, add_generation_prompt=True)

    # Format chosen/rejected answers
    chosen = example['chosen'] + "<|im_end|>\n"
    rejected = example['rejected'] + "<|im_end|>\n"

    return {
        "prompt": system + prompt,
        "chosen": chosen,
        "rejected": rejected,
    }
# Load dataset
dataset = load_dataset("Intel/orca_dpo_pairs")['train']
# Save columns
original_columns = dataset.column_names
max_seq_length = 2048
dtype = torch.bfloat16 # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.
device_map = {'': 0} # cannot use the default 'sequential', or the ref model may end up on a different device than the policy model
model, tokenizer = FastLlamaModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    device_map = device_map,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
# from IPython import embed; embed()
tokenizer.pad_token = tokenizer.eos_token # the tokenizer ships without a pad token; reuse EOS
tokenizer.padding_side = "left"
# Format dataset
dataset = dataset.map(
    chatml_format,
    remove_columns=original_columns,
)
print(dataset[1]) # Print sample
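# Illustrative shape of one formatted example (exact strings depend on the
# tokenizer's chat template; the values below are made up):
# {
#   "prompt":   "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\n...<|im_end|>\n<|im_start|>assistant\n",
#   "chosen":   "preferred answer<|im_end|>\n",
#   "rejected": "other answer<|im_end|>\n",
# }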
model = FastLlamaModel.get_peft_model(
    model,
    r = 16,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # currently only supports dropout = 0
    # lora_dropout = 0.05,
    bias = "none", # currently only supports bias = "none"
    # task_type = "CAUSAL_LM",
    use_gradient_checkpointing = True,
    random_state = 3407,
    max_seq_length = max_seq_length,
)
# Reference model for DPO: a frozen copy of the base weights (no LoRA adapters)
ref_model, _ = FastLlamaModel.from_pretrained(
    model_name = model_name,
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    device_map = device_map,
)
# Training arguments
training_args = TrainingArguments(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    max_steps=200,
    save_strategy="no",
    logging_steps=1,
    output_dir=new_model,
    optim="paged_adamw_8bit",
    warmup_steps=100,
    bf16=True,
    report_to="wandb",
)
# Create DPO trainer
dpo_trainer = DPOTrainer(
    model,
    ref_model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    # peft_config=peft_config,
    beta=0.1,
    max_prompt_length=1024,
    max_length=1536,
)
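# For reference, DPO minimizes
#   -log sigmoid( beta * ( log pi_theta(chosen|prompt)   / pi_ref(chosen|prompt)
#                        - log pi_theta(rejected|prompt) / pi_ref(rejected|prompt) ) )
# so beta controls how strongly the policy is pulled away from the frozen
# reference model; 0.1 is TRL's default.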
# Fine-tune model with DPO
dpo_trainer.train()
# Save artifacts
dpo_trainer.model.save_pretrained("final_checkpoint")
tokenizer.save_pretrained("final_checkpoint")
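The script saves only the LoRA adapter. A possible follow-up, not part of the gist, is to merge the adapter into the base weights with peft and save a standalone model; a sketch is below, assuming the saved adapter is a standard PEFT checkpoint and there is enough memory to hold the merged bf16 weights.

# Optional (not in the original gist): merge the adapter into the base model.
from peft import AutoPeftModelForCausalLM
merged = AutoPeftModelForCausalLM.from_pretrained("final_checkpoint", torch_dtype=torch.bfloat16)
merged = merged.merge_and_unload()
merged.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)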