LoRA Training
! pip install -q "datasets==2.15.0" transformers peft
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# LOAD DATASET
# The dair-ai/emotion dataset has three splits
splits = ["train", "test", "validation"]
data_splits = {}
for split, ds in zip(splits, load_dataset("dair-ai/emotion", split=splits)):
    data_splits[split] = ds
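
# Quick check (an illustrative addition, not in the original gist): the emotion
# dataset ships 16,000 train and 2,000 test/validation rows, each with a `text`
# string and an integer `label`.
for name, ds in data_splits.items():
    print(name, ds.num_rows)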

# INIT TOKENIZER
# GPT-2 ships without a padding token, so reuse EOS for padding
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def preprocess_function(x):
    # Tokenize without padding; DataCollatorWithPadding pads each batch later.
    # (return_tensors="pt" inside .map() would store an extra batch dimension
    # that the collator cannot pad.)
    return tokenizer(x['text'], truncation=True)

def format_dataset(ds):
    ds = ds.remove_columns('text')  # the collator can't pad raw string columns
    ds = ds.rename_column('label', 'labels')
    ds.set_format('torch', columns=['labels', 'input_ids', 'attention_mask'])
    return ds

# PROCESS DATA SPLITS
test_split = data_splits["test"].map(preprocess_function, batched=True)
train_split = data_splits["train"].map(preprocess_function, batched=True)
val_split = data_splits["validation"].map(preprocess_function, batched=True)

test_split = format_dataset(test_split)
train_split = format_dataset(train_split)
val_split = format_dataset(val_split)
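
# Sanity check (an illustrative addition): each processed row should now hold
# torch tensors with no batch dimension, ready for the collator to pad.
example = train_split[0]
print(example['input_ids'].shape, example['attention_mask'].shape, example['labels'])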

# Setting up model for sequence classification
label2id = {'sadness': 0, 'joy': 1, 'love': 2, 'anger': 3, 'fear': 4, 'surprise': 5}
id2label = {0: 'sadness', 1: 'joy', 2: 'love', 3: 'anger', 4: 'fear', 5: 'surprise'}
model = AutoModelForSequenceClassification.from_pretrained(
    "gpt2",
    num_labels=6,
    label2id=label2id,
    id2label=id2label,
    pad_token_id=tokenizer.eos_token_id,
)
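
# Note (added for clarity): GPT-2 has no pretrained classification head, so
# transformers initializes the `score` projection randomly and warns about it.
# Predictions before training are therefore roughly chance (~1/6 accuracy).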

# Testing out the model with no training
def get_label(outputs):
    logits = outputs.logits
    probabilities = torch.softmax(logits, dim=-1)
    return torch.argmax(probabilities, dim=-1).item()

test_row = test_split[0]
outputs = model(
    input_ids=test_row['input_ids'].unsqueeze(0),  # add a batch dimension
    attention_mask=test_row['attention_mask'].unsqueeze(0),
)
predicted_label = get_label(outputs)
print(f"(predicted label: {predicted_label}, actual label: {test_row['labels']})")
# (predicted label: 4, actual label: 0)

# Ensure base parameters start trainable; get_peft_model will freeze
# everything except the LoRA weights (and biases, per the config below)
for param in model.parameters():
    param.requires_grad = True

# CREATE PEFT CONFIG
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,  # rank of the low-rank update matrices; controls how many parameters are trained
    lora_alpha=16,  # scaling factor for the updates (effective scale is lora_alpha / r)
    fan_in_fan_out=True,  # required for GPT-2's Conv1D layers, which store weights transposed
    bias="lora_only",  # train biases only on the adapted modules
    lora_dropout=0.01,  # dropout on the LoRA path to reduce overfitting
    target_modules=['c_attn', 'c_proj'],  # attention and projection layers to adapt
)

lora_model = get_peft_model(model, config)
lora_model.print_trainable_parameters()  # prints directly; returns None
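
# Optional verification (an assumption, not part of the original gist): list
# the injected adapter modules to confirm target_modules matched as intended.
adapter_layers = [n for n, _ in lora_model.named_modules() if n.endswith('lora_A.default')]
print(len(adapter_layers), adapter_layers[0])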

# CREATE TRAINER
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": (predictions == labels).mean()}

training_args = TrainingArguments(
    output_dir="./data/gpt2-lora",
    overwrite_output_dir=True,
    evaluation_strategy="epoch",
    num_train_epochs=2,
    per_device_train_batch_size=32,  # increase if GPU memory allows
    per_device_eval_batch_size=32,   # increase if GPU memory allows
    learning_rate=2e-5,  # conservative; LoRA fine-tuning often tolerates higher rates
    save_strategy="epoch",
)

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=val_split,  # hold out test_split for a final evaluation
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

### EVALUATE MODEL WITHOUT TRAINING
trainer.evaluate()
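
# The gist stops at the untrained baseline. A minimal sketch of the remaining
# steps (assumed, not in the original): fine-tune, re-evaluate, and save just
# the adapter weights (a few MB, versus roughly 500 MB for full fp32 GPT-2).
trainer.train()
trainer.evaluate()
lora_model.save_pretrained("./data/gpt2-lora-adapter")  # illustrative path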