-
-
Save kishida/4fa70a66a7bc258a153ffbf178c04198 to your computer and use it in GitHub Desktop.
gozaru_lora.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"authorship_tag": "ABX9TyOgZoWR2mqdfO7E2ZsHtheP", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/kishida/4fa70a66a7bc258a153ffbf178c04198/gozaru_lora.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# ライブラリインストール" | |
], | |
"metadata": { | |
"id": "V7J438GkemOG" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"id": "MWBfO08tSMZV" | |
}, | |
"outputs": [], | |
"source": [ | |
"# NOTE(review): package versions are unpinned; pin versions for reproducible runs.\n", | |
"!python3 -m pip install -U pip\n", | |
"!python3 -m pip install transformers accelerate datasets peft sentencepiece" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 準備" | |
], | |
"metadata": { | |
"id": "OXAqv0UOeyT4" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"import torch\n", | |
"import datasets\n", | |
"from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n", | |
"from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig\n", | |
"\n", | |
"model_name = \"line-corporation/japanese-large-lm-1.7b\"\n", | |
"peft_model_name = \"peft_model\"\n", | |
"\n", | |
"prompt_template_cqa = \"\"\"ユーザー: 次の情報を元に質問に答えてください。{input}\n", | |
"システム: わかりました。\n", | |
"ユーザー: {instruction}\n", | |
"システム: \"\"\"\n", | |
"prompt_template_oqa = \"\"\"ユーザー: {instruction}\n", | |
"システム: \"\"\"\n", | |
"\n", | |
"# Turn one dataset record into token ids for causal-LM fine-tuning.\n", | |
"# The prompt portion of `labels` is masked with -100 so the loss is computed\n", | |
"# only on the target (answer) tokens.\n", | |
"# NOTE(review): relies on the module-level `tokenizer`, which is created in a\n", | |
"# later cell before `dataset.map(encode)` runs — order-sensitive across cells.\n", | |
"def encode(sample):\n", | |
"# Use the context-QA template when the sample has an 'input' passage, else open QA.\n", | |
" if (sample[\"input\"]):\n", | |
" prompt = prompt_template_cqa.format(instruction=sample[\"instruction\"], input=sample[\"input\"])\n", | |
" else:\n", | |
" prompt = prompt_template_oqa.format(instruction=sample[\"instruction\"])\n", | |
"# Append EOS so the model learns where the answer ends.\n", | |
" target = sample[\"output\"] + tokenizer.eos_token\n", | |
" input_ids_prompt, input_ids_target = tokenizer([prompt, target]).input_ids\n", | |
" input_ids = input_ids_prompt + input_ids_target\n", | |
" labels = input_ids.copy()\n", | |
"# -100 is the index ignored by the cross-entropy loss in transformers.\n", | |
" labels[:len(input_ids_prompt)] = [-100] * len(input_ids_prompt)\n", | |
" return {\"input_ids\": input_ids, \"labels\": labels}\n", | |
"\n", | |
"# Build a data collator: truncate every field to `max_length`, pad the batch\n", | |
"# via the tokenizer, then right-pad `labels` manually with -100 (ignored by\n", | |
"# the loss) and convert everything to tensors.\n", | |
"def get_collator(tokenizer, max_length):\n", | |
" def collator(batch):\n", | |
" batch = [{ key: value[:max_length] for key, value in sample.items() } for sample in batch ]\n", | |
" batch = tokenizer.pad(batch, padding=True)\n", | |
"# Labels are padded here to the batch's input_ids length using -100.\n", | |
" batch[\"labels\"] = [ e + [-100] * (len(batch[\"input_ids\"][0]) - len(e)) for e in batch[\"labels\"] ]\n", | |
" batch = { key: torch.tensor(value) for key, value in batch.items() }\n", | |
" return batch\n", | |
"\n", | |
" return collator" | |
], | |
"metadata": { | |
"id": "lvLbrc8bSVlx" | |
}, | |
"execution_count": 2, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# データセットとモデルの準備" | |
], | |
"metadata": { | |
"id": "kHJ8qmXne2P-" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# prepare dataset\n", | |
"# use_fast=False selects the slow (Python) tokenizer implementation.\n", | |
"tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)\n", | |
"\n", | |
"# dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n", | |
"dataset_name = \"bbz662bbz/databricks-dolly-15k-ja-gozaru\"\n", | |
"dataset = datasets.load_dataset(dataset_name)\n", | |
"# Tokenize every record with encode() (adds input_ids / labels columns).\n", | |
"dataset = dataset.map(encode)\n", | |
"# Hold out 20% of the train split for evaluation.\n", | |
"dataset = dataset[\"train\"].train_test_split(0.2)\n", | |
"train_dataset = dataset[\"train\"]\n", | |
"val_dataset = dataset[\"test\"]\n", | |
"\n", | |
"# load model\n", | |
"# Load weights in fp16, placed entirely on GPU 0 (device_map {\"\": 0}).\n", | |
"base_model = AutoModelForCausalLM.from_pretrained(model_name, device_map={\"\": 0}, torch_dtype=torch.float16)\n", | |
"\n", | |
"# LoRA config: rank-16 adapters on the modules named \"c_attn\",\n", | |
"# alpha 32, dropout 0.05, training mode (inference_mode=False).\n", | |
"peft_config = LoraConfig(\n", | |
" task_type=TaskType.CAUSAL_LM,\n", | |
" inference_mode=False,\n", | |
" target_modules=[\"c_attn\"],\n", | |
" r=16,\n", | |
" lora_alpha=32,\n", | |
" lora_dropout=0.05\n", | |
")\n", | |
"\n", | |
"# Wrap the base model with LoRA adapters; report trainable parameter count.\n", | |
"model = get_peft_model(base_model, peft_config)\n", | |
"model.print_trainable_parameters()" | |
], | |
"metadata": { | |
"id": "709R1gNZSyX9" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# LoRAトレーニング" | |
], | |
"metadata": { | |
"id": "Z2_YmkUQfAB2" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Effective train batch size = 4 (per device) x 4 (grad accumulation) = 16.\n", | |
"# Evaluates and saves once per epoch, keeping the checkpoint with the\n", | |
"# lowest eval_loss (load_best_model_at_end + greater_is_better=False).\n", | |
"# NOTE(review): `evaluation_strategy` was renamed `eval_strategy` in newer\n", | |
"# transformers releases — confirm against the installed version.\n", | |
"training_args = TrainingArguments(\n", | |
" output_dir=\"./train_results\",\n", | |
" learning_rate=2e-4,\n", | |
" per_device_train_batch_size=4,\n", | |
" gradient_accumulation_steps=4,\n", | |
" per_device_eval_batch_size=16,\n", | |
" num_train_epochs=1,\n", | |
" logging_strategy='steps',\n", | |
" logging_steps=10,\n", | |
" save_strategy='epoch',\n", | |
" evaluation_strategy='epoch',\n", | |
" load_best_model_at_end=True,\n", | |
" metric_for_best_model=\"eval_loss\",\n", | |
" greater_is_better=False,\n", | |
" save_total_limit=2\n", | |
")\n", | |
"\n", | |
"# Sequences are truncated/padded to at most 512 tokens by the collator.\n", | |
"trainer = Trainer(\n", | |
" model=model,\n", | |
" args=training_args,\n", | |
" train_dataset=train_dataset,\n", | |
" eval_dataset=val_dataset,\n", | |
" data_collator=get_collator(tokenizer, 512)\n", | |
")\n", | |
"\n", | |
"trainer.train()\n", | |
"model = trainer.model\n", | |
"# Save the trained PEFT model (LoRA adapter weights) to `peft_model_name`.\n", | |
"model.save_pretrained(peft_model_name)" | |
], | |
"metadata": { | |
"id": "adLmEfCBS_p1" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 学習済みモデルのロード" | |
], | |
"metadata": { | |
"id": "RLFAZ7JZfFfq" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Attach the saved LoRA adapter from `peft_model_name` to the base model.\n", | |
"# NOTE(review): on a full top-to-bottom run `base_model` was already wrapped\n", | |
"# by get_peft_model above — confirm this cell is meant for resuming from a\n", | |
"# saved adapter without re-running the training section.\n", | |
"model = PeftModel.from_pretrained(base_model, peft_model_name)" | |
], | |
"metadata": { | |
"id": "6s5ELSLJUlvL" | |
}, | |
"execution_count": 5, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 推論の実行" | |
], | |
"metadata": { | |
"id": "txxWOXkgfL2B" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# Run inference with the open-QA template (\"What sports are popular in Japan?\").\n", | |
"prompt = prompt_template_oqa.format(instruction=\"日本で人気のスポーツは? \")\n", | |
"\n", | |
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n", | |
"# Generate up to 128 new tokens with a mild repetition penalty;\n", | |
"# no_grad avoids building the autograd graph during generation.\n", | |
"with torch.no_grad():\n", | |
" tokens = model.generate(\n", | |
" **inputs,\n", | |
" max_new_tokens=128,\n", | |
" repetition_penalty=1.1\n", | |
" )\n", | |
"\n", | |
"# tokens[0] is the full sequence, so the printed output includes the prompt\n", | |
"# followed by the generated continuation; special tokens are stripped.\n", | |
"output = tokenizer.decode(tokens[0], skip_special_tokens=True)\n", | |
"print(output)" | |
], | |
"metadata": { | |
"id": "2dt1gC5-eK80" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [], | |
"metadata": { | |
"id": "zhizz9VleT9h" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment