# Continued-pretraining example using Hugging Face Transformers.
#
# Originally published as a GitHub Gist by @luistung, created April 28, 2024 09:45
# (gist 5a6a625d6600cd7176d82b7551506d1c).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from datasets import Dataset
# Choose the model. Any checkpoint that AutoModelForCausalLM can load as a
# causal (left-to-right) language model can be used here, e.g. "gpt2".
# NOTE(review): encoder-only checkpoints such as "bert-base-uncased" are not
# trained as causal LMs and are a poor fit for this continued-pretraining script.
model_name = "gpt2"
device = torch.device("cpu")
# Load the model and tokenizer.
# NOTE(review): Trainer may move the model to another device (e.g. a GPU) once
# training starts, so code that runs after training should use `model.device`
# rather than `device`.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2's tokenizer has no padding token, so reuse EOS for padding. Only do this
# when no pad token is defined, so checkpoints that ship a real one keep it.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# A few sample training texts.
texts = [
    "Hello, my name is Alice and I like to teach math.",
    "Hello, my name is Bob and I enjoy writing code.",
    "Hi there, my name is Carol and I love reading books.",
]
# Tokenize the texts. Pad only to the longest text rather than to a fixed
# 512 tokens: these sentences are only a dozen or so tokens long, so fixed
# 512-token padding would make every training step process dozens of times more
# positions, almost all of them padding. Truncation at 512 still guards
# against long inputs.
encodings = tokenizer(texts, padding=True, truncation=True, max_length=512)
# Wrap the tokenizer's output columns (e.g. input_ids, attention_mask) in a
# datasets.Dataset. Plain lists are stored directly; no tensor conversion needed.
dataset = Dataset.from_dict(dict(encodings))
# Collate tokenized examples into model-ready tensors for training.
def data_collator(features):
    """Stack tokenized examples into a batch for causal-LM training.

    Args:
        features: list of examples, each a mapping from column name to a
            list of token ids; all examples must already have equal length.

    Returns:
        dict of ``torch.Tensor``: every input column stacked along a new
        leading batch dimension, plus ``"labels"``. The labels are a copy of
        ``input_ids`` in which padding positions (``attention_mask == 0``,
        when an attention mask is present) are set to -100 so the loss
        ignores them.
    """
    batch = {key: torch.tensor([f[key] for f in features]) for key in features[0]}
    # Clone so that editing the labels never touches input_ids.
    labels = batch["input_ids"].clone()
    # The pad token is EOS in this script. Without masking, every padded
    # position would train the model to predict EOS. -100 is the ignore_index
    # of the cross-entropy loss that Hugging Face causal LMs compute from
    # `labels` (they shift the labels internally, so no manual shift is needed).
    attention_mask = batch.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = -100
    batch["labels"] = labels
    return batch
# Training configuration.
training_args = TrainingArguments(
    # Output directory; existing contents may be overwritten.
    output_dir="./model_output",
    overwrite_output_dir=True,
    # Do not save checkpoints during training.
    save_strategy="no",
    # Ten passes over the data, with up to 3 examples per device per step.
    num_train_epochs=10,
    per_device_train_batch_size=3,
    # Log training metrics after every update step.
    logging_steps=1,
)
# Assemble the Trainer from the model, config, dataset and collator defined
# above, then run training.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=data_collator,
)
trainer.train()
# Prompt used to inspect what the continued-pretrained model has learned.
input_text = "Hello, my name is Alice"
# Tokenize with an attention mask, and place the tensors on whatever device the
# model is on now: Trainer may have moved it (e.g. to a GPU) during training.
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)
input_ids = inputs["input_ids"]
# Generate text. `max_length` and `num_return_sequences` can be adjusted as needed.
# Decoding is greedy (do_sample=False), which is deterministic. `temperature`
# only applies when sampling, and 0.0 is not a valid sampling temperature,
# so it is not passed.
model.eval()
generated_ids = model.generate(
    input_ids,
    attention_mask=inputs["attention_mask"],
    max_length=50,
    num_return_sequences=1,
    do_sample=False,
    # Pass the pad token explicitly (EOS in this script) so generate() does
    # not have to guess it and warn.
    pad_token_id=tokenizer.pad_token_id,
)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print(generated_text)
# End of example.