OpenCALM-7B を PEFT を使って LoRA でファインチューンする.ipynb (Fine-tuning OpenCALM-7B with LoRA using PEFT)
@niw, created May 21, 2023 03:23
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/niw/b8e8509147e27b943b4d9af01bea91cf/opencalm-7b-peft-lora.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"source": [
"### Check which GPU is in use\n",
"\n",
"Training only works well on an A100, and using an A100 requires a paid plan.\n",
"A T4 is enough for inference only."
],
"metadata": {
"id": "5fegbGR7X0s9"
}
},
{
"cell_type": "code",
"source": [
"!nvidia-smi"
],
"metadata": {
"id": "cCcP0XcWWnZm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "TfBzP8gWzkpv"
},
"source": [
"### Install Transformers, PEFT, and their dependencies\n",
"\n",
"Install the required libraries.\n",
"- bitsandbytes is needed for load_in_8bit.\n",
"- datasets is needed later to fetch the training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "otj46qRbtpnd"
},
"outputs": [],
"source": [
"!pip install -q bitsandbytes datasets accelerate loralib\n",
"!pip install -q git+https://github.com/huggingface/transformers.git@main\n",
"!pip install -q git+https://github.com/huggingface/peft.git"
]
},
{
"cell_type": "markdown",
"source": [
"## Train the model"
],
"metadata": {
"id": "2hsFhshbRjlb"
}
},
{
"cell_type": "markdown",
"source": [
"### Set the name of the base model\n",
"\n",
"We apply LoRA on top of CyberAgent's OpenCALM-7B."
],
"metadata": {
"id": "21yOhJckCbr4"
}
},
{
"cell_type": "code",
"source": [
"MODEL_NAME = \"cyberagent/open-calm-7b\"\n",
"PEFT_MODEL_NAME = \"open-calm-7b-lora\""
],
"metadata": {
"id": "r-d0iWDexJ1y"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "FOtwYRI3zzXI"
},
"source": [
"### Load the base model\n",
"\n",
"This downloads the published model from Hugging Face, so it takes a while."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cg3fiQOvmI3Q"
},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
"import bitsandbytes as bnb\n",
"from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n",
"\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" MODEL_NAME, \n",
" load_in_8bit=True, \n",
" device_map=\"auto\",\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
]
},
{
"cell_type": "markdown",
"source": [
"Run this if you want to inspect the model's state dict."
],
"metadata": {
"id": "jZfjMEQPODp1"
}
},
{
"cell_type": "code",
"source": [
"for item in model.state_dict().items():\n",
" print(item[0])"
],
"metadata": {
"id": "SVMeM_qVxxxi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Try inference\n",
"\n",
"See what the model generates before applying LoRA."
],
"metadata": {
"id": "QIBYuKhcYTNx"
}
},
{
"cell_type": "code",
"source": [
"inputs = tokenizer(\"かわいい子猫ちゃんと\", return_tensors=\"pt\").to(model.device)\n",
"\n",
"with torch.no_grad():\n",
" tokens = model.generate(\n",
" **inputs,\n",
" max_new_tokens=128,\n",
" do_sample=True,\n",
" temperature=0.7,\n",
" top_p=0.75,\n",
" top_k=40,\n",
" repetition_penalty=5.0,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
" )\n",
"\n",
"print(tokenizer.decode(tokens[0], skip_special_tokens=True))"
],
"metadata": {
"id": "DgNYJzcEYQPa"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "9fTSZntA1iUG"
},
"source": [
"### Prepare the model for training\n",
"\n",
"This part is the same as Hugging Face's original example code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T-gy-LxM0yAi"
},
"outputs": [],
"source": [
"for param in model.parameters():\n",
" param.requires_grad = False # freeze the model - train adapters later\n",
" if param.ndim == 1:\n",
" # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
" param.data = param.data.to(torch.float32)\n",
"\n",
"model.gradient_checkpointing_enable() # reduce number of stored activations\n",
"model.enable_input_require_grads()\n",
"\n",
"class CastOutputToFloat(nn.Sequential):\n",
" def forward(self, x):\n",
" return super().forward(x).to(torch.float32)\n",
"model.embed_out = CastOutputToFloat(model.embed_out)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KwOTr7B3NlM3"
},
"source": [
"Apply LoRA using PEFT."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4iwHGzKBN6wk"
},
"outputs": [],
"source": [
"from peft import LoraConfig, get_peft_model, PeftType, TaskType\n",
"\n",
"config = LoraConfig(\n",
" task_type=TaskType.CAUSAL_LM,\n",
" r=8,\n",
" lora_alpha=32,\n",
" lora_dropout=0.1,\n",
" # OpenCALM-7B is based on GPT-NeoX, so the target module is `query_key_value`\n",
" target_modules=[\"query_key_value\"],\n",
")\n",
"\n",
"model = get_peft_model(model, config)\n",
"\n",
"def print_trainable_parameters(model):\n",
" \"\"\"\n",
" Prints the number of trainable parameters in the model.\n",
" \"\"\"\n",
" trainable_params = 0\n",
" all_param = 0\n",
" for _, param in model.named_parameters():\n",
" all_param += param.numel()\n",
" if param.requires_grad:\n",
" trainable_params += param.numel()\n",
" print(\n",
" f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
" )\n",
"\n",
"print_trainable_parameters(model)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QdjWif4CVXR6"
},
"source": [
"### Prepare the training data\n",
"\n",
"Train on the question-and-answer pairs from `kunishou/databricks-dolly-15k-ja`.\n",
"- Convert each record into a question-and-answer prompt format.\n",
"- The model's context length is 2048 tokens, so drop any examples that exceed it."
]
},
{
"cell_type": "code",
"source": [
"import transformers\n",
"from datasets import load_dataset\n",
"\n",
"def convert(item):\n",
" instruction = item['instruction']\n",
" output = item['output']\n",
" return tokenizer(f\"### Instruction:\\n{instruction}\\n\\n### Response:\\n{output}\")\n",
"\n",
"data = load_dataset(\"kunishou/databricks-dolly-15k-ja\")\n",
"data = data \\\n",
" .map(convert, remove_columns=[\"input\", \"instruction\", \"output\", \"category\", \"index\"]) \\\n",
" .filter(lambda item: len(item[\"input_ids\"]) <= 2048)\n"
],
"metadata": {
"id": "Gea8P85GYuPd"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"### Train\n",
"\n",
"Expect about 15 minutes on an A100 (40 GB). Memory usage is close to the limit."
],
"metadata": {
"id": "kAp-hsK23Ngd"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AQ_HCYruWIHU"
},
"outputs": [],
"source": [
"trainer = transformers.Trainer(\n",
" model=model, \n",
" train_dataset=data['train'],\n",
" args=transformers.TrainingArguments(\n",
" per_device_train_batch_size=8, \n",
" gradient_accumulation_steps=4,\n",
" warmup_steps=50,\n",
" max_steps=100,\n",
" learning_rate=2e-4,\n",
" fp16=True,\n",
" logging_steps=1,\n",
" output_dir='outputs'\n",
" ),\n",
" data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
")\n",
"model.config.use_cache = False # silence the warnings. Please re-enable for inference!\n",
"trainer.train()"
]
},
{
"cell_type": "markdown",
"source": [
"### Save the training result to Google Drive\n",
"\n",
"Mount Google Drive if it is not mounted yet, then create the destination directory."
],
"metadata": {
"id": "GoWY2FDIxNkl"
}
},
{
"cell_type": "code",
"source": [
"from datetime import datetime\n",
"now = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
"\n",
"DRIVE_MOUNT_PATH = \"/content/drive\"\n",
"PEFT_MODEL_PATH = f\"{DRIVE_MOUNT_PATH}/MyDrive/Colab Data/{PEFT_MODEL_NAME}-{now}\"\n",
"\n",
"PEFT_MODEL_PATH"
],
"metadata": {
"id": "m-rUmHUJlxh5"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"import os\n",
"\n",
"if not os.path.exists(DRIVE_MOUNT_PATH):\n",
" drive.mount(DRIVE_MOUNT_PATH)\n",
"\n",
"!mkdir -p \"{PEFT_MODEL_PATH}\""
],
"metadata": {
"id": "OqaL1P1nXKbh"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"model.save_pretrained(PEFT_MODEL_PATH)"
],
"metadata": {
"id": "oI7JJ8GpZ3Vj"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Use the training result"
],
"metadata": {
"id": "Apd_TFj2Reen"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "S65GcxNGA9kz"
},
"source": [
"### Load the trained result from Google Drive\n"
]
},
{
"cell_type": "code",
"source": [
"# Use previously trained data if it exists.\n",
"# If you are continuing from the cells above, you can skip this cell.\n",
"\n",
"from google.colab import drive\n",
"import os\n",
"\n",
"if 'DRIVE_MOUNT_PATH' not in locals():\n",
" DRIVE_MOUNT_PATH = \"/content/drive\"\n",
"\n",
"if 'PEFT_MODEL_PATH' not in locals():\n",
" PEFT_MODEL_PATH = f\"{DRIVE_MOUNT_PATH}/MyDrive/Colab Data/open-calm-7b-lora-20230520073602\"\n",
"\n",
"if not os.path.exists(DRIVE_MOUNT_PATH):\n",
" drive.mount(DRIVE_MOUNT_PATH)"
],
"metadata": {
"id": "zqGQRM_URzYZ"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import torch\n",
"from peft import PeftModel, PeftConfig\n",
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"\n",
"config = PeftConfig.from_pretrained(PEFT_MODEL_PATH)\n",
"model = AutoModelForCausalLM.from_pretrained(\n",
" config.base_model_name_or_path,\n",
" load_in_8bit=True,\n",
" device_map='auto'\n",
")\n",
"tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
"\n",
"# Attach the trained LoRA adapter to the original model here\n",
"model = PeftModel.from_pretrained(model, PEFT_MODEL_PATH)"
],
"metadata": {
"id": "4QJFMA92Zsnm"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "MHYljmTjj5wX"
},
"source": [
"### Run inference with LoRA applied\n",
"\n",
"The model now answers questions reasonably well."
]
},
{
"cell_type": "code",
"source": [
"instruction = \"猫の鳴き方を教えて\"\n",
"\n",
"inputs = tokenizer(f\"### Instruction:\\n{instruction}\\n\\n### Response:\\n\", return_tensors='pt').to(model.device)\n",
"\n",
"with torch.no_grad():\n",
" tokens = model.generate(\n",
" **inputs,\n",
" max_new_tokens=128,\n",
" do_sample=True,\n",
" temperature=0.7,\n",
" top_p=0.75,\n",
" top_k=40,\n",
" repetition_penalty=5.0,\n",
" pad_token_id=tokenizer.pad_token_id,\n",
" )\n",
"\n",
"print(tokenizer.decode(tokens[0], skip_special_tokens=True))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WfRhpLrnaaj-",
"outputId": "66eb8abc-c05c-4cb2-a7c0-1bae824d08f5"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"### Instruction:\n",
"猫の鳴き方を教えて\n",
"\n",
"### Response:\n",
"この「ネコ」は、自分の家に住み着いた野良の猫です。 彼らはしばしば鳴くので、「ニャー」、「ミャオ」「ウォウ」、または単に 「ワンワン」(犬のように)と聞こえます 。 これは子守唄であり、\"cock-a\"や \"doggie woofers\"(Duck Down the Lake など)、あるいは単純に 'Cat'とも呼ばれます(Nancy Pelman のアルバムの名前にもなっています)。 この歌には2つの異なるバージョンがあり、(1)「アーニー・ピューーズ」。 ( 2 )この歌で使われている言葉である “I'm called\n"
]
}
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"provenance": [],
"collapsed_sections": [
"2hsFhshbRjlb"
],
"gpuType": "A100",
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
}
},
"nbformat": 4,
"nbformat_minor": 0
}