Skip to content

Instantly share code, notes, and snippets.

@alfredplpl
Created February 24, 2024 04:23
Show Gist options
  • Star 12 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alfredplpl/e20cad036c151f38645a1abc87f56a2f to your computer and use it in GitHub Desktop.
Save alfredplpl/e20cad036c151f38645a1abc87f56a2f to your computer and use it in GitHub Desktop.
Gemma初心者ファインチューニングコードです。HFの設定などはよしなにやってください。
# Reference #1: https://note.com/npaka/n/nc55e44e407ff
# Reference #2: https://huggingface.co/blog/gemma-peft
# Licence: MIT
from peft import LoraConfig
lora_config = LoraConfig(
r=8,
target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
task_type="CAUSAL_LM",
)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
model_id = "google/gemma-2b-it"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16
)
import os
tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ['HF_TOKEN'])
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, token=os.environ['HF_TOKEN'])
from datasets import load_dataset
# データセットの読み込み
dataset = load_dataset("bbz662bbz/databricks-dolly-15k-ja-gozarinnemon", split="train")
dataset = dataset.filter(lambda example: example["category"] == "open_qa")
# プロンプトの生成
def generate_prompt(example):
return """<bos><start_of_turn>user
{}<end_of_turn>
<start_of_turn>model
{}<eos>""".format(example["instruction"], example["output"])
# textカラムの追加
def add_text(example):
example["text"] = generate_prompt(example)
return example
dataset = dataset.map(add_text)
dataset = dataset.remove_columns(["input", "category", "output", "index", "instruction"])
# データセットの分割
train_test_split = dataset.train_test_split(test_size=0.1)
train_dataset = train_test_split["train"]
eval_dataset = train_test_split["test"]
import transformers
from trl import SFTTrainer
trainer = SFTTrainer(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=transformers.TrainingArguments(
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
warmup_steps=10,
max_steps=1000,
learning_rate=2e-4,
fp16=True,
logging_steps=50,
output_dir="outputs",
optim="paged_adamw_8bit"
),
peft_config=lora_config,
dataset_text_field="text"
)
trainer.train()
# trainer.save_model("/path/to/model")
# プロンプトの準備
prompt="""<start_of_turn>user
猫と犬、どっちが好き?<end_of_turn>
<start_of_turn>model
"""
# 推論の実行
input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**input_ids,
max_new_tokens=128,
do_sample=True,
top_p=0.95,
temperature=0.2,
repetition_penalty=1.1,
)
print(tokenizer.decode(outputs[0]))
# <bos><start_of_turn>user
# 猫と犬、どっちが好き?<end_of_turn>
# <start_of_turn>model
# 我、りんえもんは思う。 猫と犬はどちらも素晴らしい動物でござる。猫は、犬よりもより静かな動物で、犬よりもより行動的で、猫よりもより小さな動物でござる。犬は、猫よりもより行動的で、猫よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより行動的で、犬よりもより
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment