# pyreft LoReFT continued pretraining using completion
#
# Source: GitHub gist thistleknot/923d0a77d4c3e4a9eef9d2619fadb078
# (last active April 27, 2024 02:42).
#
# NOTE: GitHub flagged this file as containing bidirectional Unicode text that
# may be interpreted or compiled differently than it appears. Review it in an
# editor that reveals hidden Unicode characters.
import pyreft
import torch
import transformers
from datasets import load_dataset
from pyreft import (
    LoreftIntervention,
    ReftConfig,
    ReftDataCollator,
    ReftSupervisedDataset,
    ReftTrainerForCausalLM,
    get_reft_model,
)
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
# Load the first 10% of the SQuAD v2 training split from Hugging Face datasets.
squad_v2 = load_dataset("squad_v2", split="train[:10%]")

# Build parallel lists of prompts (context + question) and completions
# (answer header + first reference answer).
#
# The completion begins with "\n\n" so the question and the "Answer:" header
# stay separated once prompt and completion are joined into one training text.
# NOTE(review): this assumes the data module concatenates prompt and completion
# with no separator (pyreft's make_last_position_supervised_data_module builds
# `prompt + output`) — confirm for the installed pyreft version.
#
# SQuAD v2 contains unanswerable questions (empty answer list); their
# completion is the empty string.
# NOTE(review): consider an explicit target (e.g. "unanswerable") so those
# prompts teach the model something more informative than an empty answer.
inputs = [
    "Context:\n\n" + ex["context"] + "\n\n" + "Question:\n\n" + ex["question"]
    for ex in squad_v2
]
outputs = [
    ("\n\nAnswer:\n\n" + ex["answers"]["text"][0]) if ex["answers"]["text"] else ""
    for ex in squad_v2
]
# Step 1: Load the pretrained model (bfloat16, on the GPU) and its tokenizer.
model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)

# Batched training pads sequences to a common length, which needs a pad token.
# Llama-style tokenizers may not define one; fall back to the unknown token,
# as the pyreft examples do. An existing pad token is left untouched.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token
# Step 2: Define one LoReFT intervention on the block output of every layer.
# The rank appears both in the config dict and in the intervention itself;
# a single constant keeps the two in sync.
LOW_RANK_DIMENSION = 4
layers = range(model.config.num_hidden_layers)
representations = [
    {
        "layer": layer,
        "component": "block_output",
        "low_rank_dimension": LOW_RANK_DIMENSION,
        "intervention": pyreft.LoreftIntervention(
            embed_dim=model.config.hidden_size,
            low_rank_dimension=LOW_RANK_DIMENSION,
        ),
    }
    for layer in layers
]
# Step 3: Wrap the base model with the configured interventions on the GPU and
# report how many parameters are trainable.
reft_config = pyreft.ReftConfig(representations=representations)
reft_model = pyreft.get_reft_model(model, reft_config, set_device="cuda")
reft_model.print_trainable_parameters()
# Step 4: Build the supervised dataset/collator. Interventions are applied at
# the last position of each prompt. num_interventions must equal the number of
# interventions configured above (one per layer), hence len(layers).
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, inputs, outputs, num_interventions=len(layers))
# Step 5: Train the ReFT model.
# NOTE(review): 2e-5 is a full-fine-tuning-scale learning rate; pyreft's own
# examples train interventions with rates on the order of 1e-3. Confirm this
# is intentional if the loss barely moves.
training_args = transformers.TrainingArguments(
    num_train_epochs=1,  # adjust as needed
    output_dir="./reft_squad_model",
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    logging_steps=10,
    report_to=[],  # disable all integrations, including TensorBoard
)
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)
trainer.train()
# Step 6: Save the trained ReFT model locally.
reft_model.set_device("cpu")  # move the model to CPU before saving
reft_model.save(
    save_directory="./reft_squad_model",
    save_to_hf_hub=False,
    # NOTE(review): presumably only used when save_to_hf_hub=True.
    hf_repo_name="your_reft_squad_model",
)
print("Model training and saving completed.")
# To comment on this gist, sign in to GitHub (or sign up for a free account).