Skip to content

Instantly share code, notes, and snippets.

@ericflo
Created December 27, 2023 05:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ericflo/70d569d7e2db2a7f6cac924d535b1c99 to your computer and use it in GitHub Desktop.
from datasets import load_dataset
from trl import SFTTrainer
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TrainingArguments,
HfArgumentParser,
)
from peft import LoraConfig
import torch
def make_formatting_func(template_tokenizer_name):
    """Build a per-example formatter that renders chat messages to text.

    The tokenizer is loaded once here (only its chat template is used) and
    captured by the closure, so it is not re-created for every dataset row.

    Args:
        template_tokenizer_name: Hub id of the tokenizer whose chat template
            should be applied (e.g. an OpenHermes-style template).

    Returns:
        A callable mapping ``{"messages": [...]}`` to a single template-rendered
        string, suitable for ``SFTTrainer``'s ``formatting_func``.
    """
    chat_tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_name)

    def format_example(example):
        # tokenize=False: return the rendered text; the trainer tokenizes later.
        return chat_tokenizer.apply_chat_template(example["messages"], tokenize=False)

    return format_example
def main(
    template_tokenizer_name="teknium/OpenHermes-2.5-Mistral-7B",
    model_name="mistralai/Mistral-7B-v0.1",
    dataset_name="ericflo/unnaturalhermes-reflections-100k",
    context_length=32768,
):
    """LoRA-fine-tune a causal LM with TRL's SFTTrainer.

    Training hyperparameters (output dir, batch size, eval schedule, ...) are
    taken from the command line as HF ``TrainingArguments``; see the usage
    string under ``__main__``.

    Args:
        template_tokenizer_name: tokenizer whose chat template formats examples.
        model_name: base model to fine-tune.
        dataset_name: HF dataset with ``messages`` and ``metadata`` columns.
        context_length: max packed sequence length for training.
    """
    # All run configuration comes from CLI flags via TrainingArguments.
    arg_parser = HfArgumentParser(TrainingArguments)
    training_args = arg_parser.parse_args_into_dataclasses()[0]

    # Keep only prompt-version-3 rows generated by a *ixtral model
    # ("ixtral" matches either capitalization of Mixtral).
    def _keep(row):
        meta = row["metadata"]
        return meta["prompt_version"] == 3 and "ixtral" in meta["model"]

    raw_dataset = load_dataset(dataset_name, split="train")
    # NOTE(review): no seed is passed, so the 500-row eval split is not
    # reproducible across runs — confirm whether that is intended.
    splits = raw_dataset.filter(_keep).train_test_split(test_size=500).with_format("torch")

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        device_map="auto",
    )

    # LoRA on the MLP projections only (attention modules untouched).
    lora_config = LoraConfig(
        r=64,
        lora_alpha=256,
        lora_dropout=0.05,
        target_modules=["gate_proj", "down_proj", "up_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

    trainer = SFTTrainer(
        base_model,
        args=training_args,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        formatting_func=make_formatting_func(template_tokenizer_name),
        max_seq_length=context_length,
        peft_config=lora_config,
        packing=True,
    )
    trainer.train()
    trainer.save_model("final")
if __name__ == "__main__":
    # The string below is a no-op at runtime; it documents a typical invocation.
    """
    python train.py \
    --output_dir mistral-7b-reflect \
    --report_to wandb \
    --bf16 True \
    --gradient_checkpointing True \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --logging_steps 1 \
    --do_eval True \
    --evaluation_strategy steps \
    --eval_steps 20
    """
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment