pyreft LoReFT continued pretraining using completions
import torch
import transformers
import pyreft
from datasets import load_dataset
# Load the first 10% of the SQuAD v2 training split from the Hugging Face Hub
squad_v2 = load_dataset("squad_v2", split="train[:10%]")

# Build prompt/completion pairs: context + question as the input, answer as the output.
# SQuAD v2 includes unanswerable questions; those rows get an empty completion.
inputs = ["Context:\n\n" + row['context'] + "\n\nQuestion:\n\n" + row['question'] for row in squad_v2]
outputs = [("Answer:\n\n" + row['answers']['text'][0]) if len(row['answers']['text']) > 0 else "" for row in squad_v2]
# Step 1: Load the pretrained model and tokenizer
model_name_or_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map='cuda')
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)
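# Assumption worth guarding against: like many Llama-family tokenizers, TinyLlama's
# may ship without a pad token, which batched training needs; reusing the eos token
# is a common fallback.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token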
# model.half()  # alternative: fp16 instead of the bfloat16 used above
# Step 2: Define a rank-4 LoReFT intervention on the block output of every layer
layers = range(model.config.num_hidden_layers)
representations = [{
    "layer": l,
    "component": "block_output",
    "low_rank_dimension": 4,
    "intervention": pyreft.LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=4)
} for l in layers]
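# Rough parameter budget (illustrative, assuming each LoReFT intervention learns two
# rank-r-by-hidden-size matrices plus a bias, i.e. about 2*d*r + r parameters per
# layer): the trainable share stays tiny relative to the 1.1B-parameter base model.
d, r = model.config.hidden_size, 4
print(f"expecting roughly {len(layers) * (2 * d * r + r):,} trainable intervention parameters")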
# Step 3: ReFT configuration and model setup
reft_config = pyreft.ReftConfig(representations=representations)
reft_model = pyreft.get_reft_model(model, reft_config, set_device='cuda')
reft_model.print_trainable_parameters()  # only the intervention weights are trainable
# Step 4: Prepare the data, intervening on the last token position of each prompt
data_module = pyreft.make_last_position_supervised_data_module(
    tokenizer, model, inputs, outputs, num_interventions=len(layers))
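# Illustrative check, assuming the data module is a dict of Trainer keyword arguments
# (train_dataset, eval_dataset, data_collator) that gets unpacked into the trainer below.
print(data_module.keys())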
# Step 5: Train the model
training_args = transformers.TrainingArguments(
    num_train_epochs=1,  # adjust as needed
    output_dir="./reft_squad_model",
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    logging_steps=10,
    report_to=[],  # disable all logging integrations, including TensorBoard
)
trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model, tokenizer=tokenizer, args=training_args, **data_module)
trainer.train()
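# Quick post-training inference check. This is a sketch following pyreft's documented
# generate() interface: the intervention is applied at the last prompt token (matching
# make_last_position_supervised_data_module above), with the location repeated once
# per intervention since one was attached to every layer.
test_prompt = tokenizer(inputs[0], return_tensors="pt").to('cuda')
base_unit_location = test_prompt["input_ids"].shape[-1] - 1  # last token position
_, reft_response = reft_model.generate(
    test_prompt,
    unit_locations={"sources->base": (None, [[[base_unit_location]]] * len(layers))},
    intervene_on_prompt=True,
    max_new_tokens=64,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(reft_response[0], skip_special_tokens=True))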
# Step 6: Save the trained interventions (pyreft saves only these, not the base model)
reft_model.set_device("cpu")  # move the interventions to CPU before saving
reft_model.save(
    save_directory="./reft_squad_model",
    save_to_hf_hub=False,  # set True to push to the Hugging Face Hub
    hf_repo_name="your_reft_squad_model"
)
print("Model training and saving completed.")