Fine-tune LLM example: GPT-2 on toy question-answer pairs
from transformers import AutoTokenizer
from datasets import Dataset
import torch
import pandas as pd
import numpy as np

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; reuse EOS for padding
qa_pairs = [
    {"question": "What is the capital of France?", "answer": "Paris"},
    {"question": "What is the capital of Germany?", "answer": "Berlin"},
    {"question": "How are you?", "answer": "I am fine"},
]
def convert_qa_to_input_and_labels(pair):
    question = pair["question"]
    answer = pair["answer"]
    prompt_text = f"Question: {question} Answer:"
    target_text = answer

    # Encode the prompt and the target separately
    prompt_ids = tokenizer.encode(prompt_text, return_tensors="pt").squeeze(0)
    target_ids = tokenizer.encode(target_text, return_tensors="pt").squeeze(0)
    # Append EOS so the model learns where the answer ends
    target_ids = torch.concat([target_ids, torch.tensor([tokenizer.eos_token_id])])

    # Mask the prompt positions with -100 so the loss is computed only on the answer
    input_ids = torch.concat([prompt_ids, target_ids], 0)
    label_ids = torch.concat([torch.full(prompt_ids.shape, -100), target_ids], 0)

    # Convert to NumPy arrays for the Dataset
    input_ids_np = input_ids.numpy()
    label_ids_np = label_ids.numpy()
    return {"input_ids": input_ids_np, "labels": label_ids_np}
# Encode all QA pairs and build a pandas DataFrame
encoded_data = [convert_qa_to_input_and_labels(pair) for pair in qa_pairs]
df = pd.DataFrame(encoded_data)

# Create a Dataset object
dataset = Dataset.from_pandas(df)
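# Optional sanity check (a minimal sketch, not part of the original gist):
# decode the first example and confirm that only the answer tokens carry labels
sample = dataset[0]
print(tokenizer.decode(sample["input_ids"]))
print(sum(l != -100 for l in sample["labels"]), "label positions contribute to the loss")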
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
model = AutoModelForCausalLM.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
training_args = TrainingArguments(
    output_dir="./qa_model_output",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=3,
    logging_steps=1,
    save_strategy="no",
)
from torch.nn.utils.rnn import pad_sequence

def data_collator(batch):
    # Each element of batch is one sample taken from the Dataset, i.e. a dict
    input_ids = [torch.tensor(f["input_ids"], dtype=torch.long) for f in batch]
    labels = [torch.tensor(f["labels"], dtype=torch.long) for f in batch]

    # Pad input_ids and labels with pad_sequence; batch_first=True because we
    # want the first dimension of each tensor to be the batch size
    input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    labels = pad_sequence(labels, batch_first=True, padding_value=-100)  # pad labels with -100 so no loss is computed there

    # attention_mask marks real tokens (1) versus padding (0); note that since
    # pad_token == eos_token, the final EOS of each sample is masked too, which
    # is harmless here because only padding positions follow it
    attention_mask = torch.where(input_ids != tokenizer.pad_token_id, 1, 0)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}
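# Collator sanity check (an optional sketch): batch two samples of different
# lengths and confirm they are padded to a common shape
demo_batch = data_collator([dataset[0], dataset[2]])
print(demo_batch["input_ids"].shape, demo_batch["labels"].shape, demo_batch["attention_mask"].shape)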
# Use the simple custom data collator above, since the labels are already prepared
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)
trainer.train()
model.save_pretrained("gpt2-ft")
model = AutoModelForCausalLM.from_pretrained("gpt2-ft")
model.eval()
input_question = "How are you?"
input_text = f"Question: {input_question} Answer:"
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
generated_ids = model.generate(
    input_ids,
    max_length=50,
    num_return_sequences=1,
    do_sample=False,  # greedy decoding; temperature=0.0 is invalid and unneeded when not sampling
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,  # avoids the missing-pad-token warning
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)
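# Optional: sampled generation instead of greedy decoding, for varied answers
# (a sketch; the temperature value is an illustrative assumption, not from the gist)
sampled_ids = model.generate(
    input_ids,
    max_length=50,
    do_sample=True,
    temperature=0.7,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(sampled_ids[0], skip_special_tokens=True))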
def log_prob(model, input_, output_, device):
    # Sum of log-probabilities the model assigns to output_ given input_
    model.eval()
    input_ids = tokenizer.encode(input_, return_tensors="pt").to(device)
    output_ids = tokenizer.encode(output_, return_tensors="pt").to(device)
    input_output_ids = torch.concat([input_ids, output_ids], -1)
    with torch.no_grad():
        logits = model(input_output_ids).logits
    # The logit at position t-1 predicts the token at position t, so take the
    # slice aligned one step before each output token
    start = input_ids.shape[1] - 1
    shift_logits = logits[:, start : start + output_ids.shape[1], :]
    log_probs = torch.nn.functional.log_softmax(shift_logits, dim=-1)
    log_probs = log_probs.gather(dim=-1, index=output_ids.unsqueeze(-1)).squeeze(-1)
    total_log_prob = log_probs.sum()
    return total_log_prob.item()
input_ = "How are you?"
output_ = "I am fine"
print(log_prob(model, input_, output_, device))
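# Comparing candidates (a sketch): a higher, i.e. less negative, score means the
# model considers that continuation more likely; "I am Paris" is an arbitrary foil
for candidate in ["I am fine", "I am Paris"]:
    print(candidate, log_prob(model, "How are you?", candidate, device))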