LoRA (Low-Rank Adaptation)
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
import torch
from trl import SFTTrainer


class Lora:
    @classmethod
    def prepare(cls, model_name):
        """
        Prepare a base model for LoRA fine-tuning (4-bit quantization + k-bit training).
        """
        # Load the base model with 4-bit NF4 quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,  # use bfloat16 on GPUs that support it, otherwise float16
            bnb_4bit_use_double_quant=False,
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            quantization_config=bnb_config,
            device_map={"": 0},
        )
        model = prepare_model_for_kbit_training(model)
        model.config.use_cache = False  # silence the warnings; re-enable for inference!
        model.config.pretraining_tp = 1

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.add_eos_token = True

        return cls(model, tokenizer)

    @classmethod
    def load(cls, model_name, lora_path):
        """
        Load a previously trained LoRA adapter and merge it into the base model.
        """
        base_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            return_dict=True,
            torch_dtype=torch.bfloat16,
            device_map={"": 0},
        )
        model = PeftModel.from_pretrained(base_model, lora_path)
        model = model.merge_and_unload()
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        return cls(model, tokenizer)

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def train(self, dataset, epochs=30, report_to=None):
        """
        Fine-tune with LoRA via SFTTrainer.
        Save the adapter afterwards with trainer.model.save_pretrained(lora_path).
        """
        training_arguments = TrainingArguments(
            output_dir="./results",
            num_train_epochs=epochs,
            per_device_train_batch_size=6,
            gradient_accumulation_steps=2,
            optim="paged_adamw_8bit",
            save_steps=1000,
            logging_steps=max(1, epochs // 5),  # avoid logging_steps=0 for small epoch counts
            learning_rate=2e-4,
            weight_decay=0.001,
            fp16=False,
            bf16=False,
            max_grad_norm=0.3,
            max_steps=-1,
            warmup_ratio=0.3,
            group_by_length=True,
            lr_scheduler_type="linear",
            report_to=report_to,
        )
        peft_config = LoraConfig(
            lora_alpha=8,
            lora_dropout=0.1,
            r=16,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"],
        )
        # Set up supervised fine-tuning (SFT) trainer
        trainer = SFTTrainer(
            model=self.model,
            train_dataset=dataset,
            peft_config=peft_config,
            max_seq_length=None,
            dataset_text_field="text",
            tokenizer=self.tokenizer,
            args=training_arguments,
            packing=False,
        )
        trainer.train()
        return trainer
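
A minimal usage sketch of the class above; the base model and dataset names are placeholders, and any dataset exposing a "text" column should work with the trainer's dataset_text_field setting.

# Usage sketch: names below are placeholders, not part of the gist itself.
from datasets import load_dataset

dataset = load_dataset("mlabonne/guanaco-llama2-1k", split="train")  # placeholder dataset with a "text" field

lora = Lora.prepare("mistralai/Mistral-7B-v0.1")   # placeholder base model
trainer = lora.train(dataset, epochs=3)
trainer.model.save_pretrained("./lora-adapter")    # save the LoRA adapter weights

# Later: reload the adapter merged into the base model for inference
merged = Lora.load("mistralai/Mistral-7B-v0.1", "./lora-adapter")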