Last active
April 28, 2024 09:53
-
-
Save luistung/2434e339cca53cff82ef97a3bb243fe4 to your computer and use it in GitHub Desktop.
Example: fine-tuning an LLM (GPT-2) for question answering
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import AutoTokenizer | |
from datasets import Dataset | |
import torch | |
import pandas as pd | |
import numpy as np | |
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token; reuse EOS for padding
# Toy question/answer dataset used for the fine-tuning demo.
qa_pairs = [
    {"question": q, "answer": a}
    for q, a in [
        ("What is the capital of France?", "Paris"),
        ("What is the capital of Germany?", "Berlin"),
        ("How are you?", "I am fine"),
    ]
]
def convert_qa_to_input_and_labels(pair):
    """Encode one {"question", "answer"} pair for causal-LM fine-tuning.

    Builds the sequence "Question: ... Answer:" + answer + EOS, and a matching
    label vector in which the prompt positions are set to -100 so the loss is
    computed only on the answer (and the terminating EOS) tokens.

    Args:
        pair: dict with string keys "question" and "answer".

    Returns:
        dict with NumPy int arrays "input_ids" and "labels" of equal length.
    """
    prompt_text = f"Question: {pair['question']} Answer:"
    prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt").squeeze(0)
    target_ids = tokenizer.encode(pair["answer"], return_tensors="pt").squeeze(0)
    # Append EOS so the model learns to stop after the answer.
    target_ids = torch.cat([target_ids, torch.tensor([tokenizer.eos_token_id])])
    input_ids = torch.cat([prompt_ids, target_ids], 0)
    # -100 is ignored by the cross-entropy loss; full_like keeps the id dtype.
    label_ids = torch.cat([torch.full_like(prompt_ids, -100), target_ids], 0)
    # (debug print of the shapes removed)
    return {"input_ids": input_ids.numpy(), "labels": label_ids.numpy()}
# Encode every QA pair, then build an HF Dataset by way of a pandas DataFrame.
encoded_data = list(map(convert_qa_to_input_and_labels, qa_pairs))
df = pd.DataFrame(encoded_data)
dataset = Dataset.from_pandas(df)
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Load base GPT-2 weights and move them to the GPU when one is available.
model = AutoModelForCausalLM.from_pretrained("gpt2")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
# Demo-scale hyper-parameters: the dataset is tiny, so run many epochs,
# log every step, and disable checkpoint saving entirely.
training_args = TrainingArguments(
    output_dir="./qa_model_output",
    overwrite_output_dir=True,
    save_strategy="no",
    num_train_epochs=10,
    per_device_train_batch_size=3,
    logging_steps=1,
)
from torch.nn.utils.rnn import pad_sequence

def data_collator(batch):
    """Collate variable-length examples into a padded batch.

    Args:
        batch: list of dicts, each with integer sequences "input_ids" and
            "labels" (as produced by convert_qa_to_input_and_labels).

    Returns:
        dict of LongTensors: padded "input_ids" (pad = tokenizer.pad_token_id),
        padded "labels" (pad = -100 so padding is excluded from the loss), and
        an "attention_mask" that is 1 over real tokens and 0 over padding.
    """
    input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in batch]
    labels = [torch.tensor(f["labels"], dtype=torch.long) for f in batch]
    # Record true lengths BEFORE padding. The original computed the mask as
    # (input_ids != pad_token_id), but pad_token_id == eos_token_id here and
    # every target ends with a real EOS token — that EOS would be wrongly
    # masked out. A length-based mask keeps real EOS tokens visible.
    lengths = torch.tensor([seq.size(0) for seq in input_ids])
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)
    attention_mask = (
        torch.arange(input_ids.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
    ).long()
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
# data_collator already builds correct padded labels, so no library collator is needed.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)
trainer.train()

model.save_pretrained("gpt2-ft")
# Reload the fine-tuned weights and move them to the same device as the
# inputs used below. The original left the reloaded model on CPU, which
# crashes generation when CUDA is available (inputs are sent to `device`).
model = AutoModelForCausalLM.from_pretrained("gpt2-ft").to(device)
model.eval()
# Greedy decoding sanity check on a training question.
input_question = "How are you?"
input_text = f"Question: {input_question} Answer:"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
# The original passed temperature=0.0: with the default do_sample=False it is
# ignored, and recent transformers versions reject temperature <= 0 outright.
# Explicit do_sample=False requests greedy decoding; pad_token_id is set to
# EOS to avoid the "no pad token" warning during generation.
generated_ids = model.generate(
    input_ids,
    max_length=50,
    num_return_sequences=1,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)
def log_prob(model, input_, output_, device):
    """Total log-probability that `model` assigns to `output_` given `input_`.

    Tokenizes both strings with the module-level `tokenizer`, runs the model
    over the concatenation, and sums the log-probabilities of the output
    tokens only.

    Args:
        model: a causal LM returning `.logits` of shape (1, seq_len, vocab).
        input_: conditioning text (the prompt).
        output_: continuation text to score.
        device: torch device the model lives on.

    Returns:
        float: sum of per-token log-probabilities of `output_`.
    """
    model.eval()
    input_ids = tokenizer.encode(input_, return_tensors="pt").to(device)
    output_ids = tokenizer.encode(output_, return_tensors="pt").to(device)
    input_output_ids = torch.cat([input_ids, output_ids], -1)
    # Scoring only — disable autograd so no graph is built.
    with torch.no_grad():
        logits = model(input_output_ids).logits
    # The logit at position t predicts token t+1, so the logits that score the
    # output tokens start at index len(input) - 1.
    start = input_ids.shape[1] - 1
    shift_logits = logits[:, start:start + output_ids.shape[1]]
    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    token_log_probs = log_probs.gather(dim=-1, index=output_ids.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum().item()
# Sanity check: after fine-tuning, the model should assign a high
# log-probability to a memorized training answer.
input_ = "How are you?"
output_ = "I am fine"
print(log_prob(model, input_, output_, device))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment