Fine-tuning OpenCALM-7B with LoRA using PEFT.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "view-in-github",
    "colab_type": "text"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/gist/niw/b8e8509147e27b943b4d9af01bea91cf/opencalm-7b-peft-lora.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Check which GPU you have\n",
    "\n",
    "Training only works well on an A100, and using an A100 requires a paid Colab plan.\n",
    "For inference only, a T4 is fine."
   ],
   "metadata": {
    "id": "5fegbGR7X0s9"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "!nvidia-smi"
   ],
   "metadata": {
    "id": "cCcP0XcWWnZm"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TfBzP8gWzkpv"
   },
   "source": [
    "### Install Transformers, PEFT, and their dependencies\n",
    "\n",
    "Install the required libraries.\n",
    "- bitsandbytes is needed for `load_in_8bit`.\n",
    "- datasets is needed later to fetch the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "otj46qRbtpnd"
   },
   "outputs": [],
   "source": [
    "!pip install -q bitsandbytes datasets accelerate loralib\n",
    "!pip install -q git+https://github.com/huggingface/transformers.git@main\n",
    "!pip install -q git+https://github.com/huggingface/peft.git"
   ]
  },
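  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, record which versions were actually installed. Since transformers and peft come from their `main` branches, their APIs move quickly, and noting the exact versions makes a working setup easier to reproduce. This is just a quick check using the standard library."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the installed version of each library used in this notebook.\n",
    "from importlib.metadata import version\n",
    "\n",
    "for package in (\"transformers\", \"peft\", \"bitsandbytes\", \"datasets\", \"accelerate\"):\n",
    "    print(package, version(package))"
   ]
  },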
  {
   "cell_type": "markdown",
   "source": [
    "## Training"
   ],
   "metadata": {
    "id": "2hsFhshbRjlb"
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Set the name of the base model to use\n",
    "\n",
    "We apply LoRA on top of CyberAgent's OpenCALM-7B."
   ],
   "metadata": {
    "id": "21yOhJckCbr4"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "MODEL_NAME = \"cyberagent/open-calm-7b\"\n",
    "PEFT_MODEL_NAME = \"open-calm-7b-lora\""
   ],
   "metadata": {
    "id": "r-d0iWDexJ1y"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FOtwYRI3zzXI"
   },
   "source": [
    "### Load the base model\n",
    "\n",
    "This downloads the published model from Hugging Face, so it takes a while."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cg3fiQOvmI3Q"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import bitsandbytes as bnb\n",
    "from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    load_in_8bit=True,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)"
   ]
  },
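  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick optional sanity check on the 8-bit load, Transformers models can report their weight memory footprint via `get_memory_footprint()`; for a 7B-parameter model in 8-bit this should be on the order of 7 GB, roughly a quarter of what fp32 would take."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expect ~1 byte per parameter when loaded in 8-bit.\n",
    "print(f\"{model.get_memory_footprint() / 1e9:.1f} GB\")"
   ]
  },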
  {
   "cell_type": "markdown",
   "source": [
    "In case you want to inspect the model's state dict:"
   ],
   "metadata": {
    "id": "jZfjMEQPODp1"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# List the names of all weight tensors in the model.\n",
    "for name in model.state_dict():\n",
    "    print(name)"
   ],
   "metadata": {
    "id": "SVMeM_qVxxxi"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Try inference\n",
    "\n",
    "See what the model generates before applying LoRA."
   ],
   "metadata": {
    "id": "QIBYuKhcYTNx"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# The prompt means roughly 'with a cute little kitten...'.\n",
    "inputs = tokenizer(\"かわいい子猫ちゃんと\", return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    tokens = model.generate(\n",
    "        **inputs,\n",
    "        max_new_tokens=128,\n",
    "        do_sample=True,\n",
    "        temperature=0.7,\n",
    "        top_p=0.75,\n",
    "        top_k=40,\n",
    "        repetition_penalty=5.0,\n",
    "        pad_token_id=tokenizer.pad_token_id,\n",
    "    )\n",
    "\n",
    "print(tokenizer.decode(tokens[0], skip_special_tokens=True))"
   ],
   "metadata": {
    "id": "DgNYJzcEYQPa"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9fTSZntA1iUG"
   },
   "source": [
    "### Prepare the model for training\n",
    "\n",
    "This part follows Hugging Face's original example code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "T-gy-LxM0yAi"
   },
   "outputs": [],
   "source": [
    "for param in model.parameters():\n",
    "    param.requires_grad = False  # freeze the model - train adapters later\n",
    "    if param.ndim == 1:\n",
    "        # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
    "        param.data = param.data.to(torch.float32)\n",
    "\n",
    "model.gradient_checkpointing_enable()  # reduce number of stored activations\n",
    "model.enable_input_require_grads()\n",
    "\n",
    "class CastOutputToFloat(nn.Sequential):\n",
    "    def forward(self, x):\n",
    "        return super().forward(x).to(torch.float32)\n",
    "\n",
    "# GPT-NeoX names its LM head `embed_out`; cast its output to fp32.\n",
    "model.embed_out = CastOutputToFloat(model.embed_out)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KwOTr7B3NlM3"
   },
   "source": [
    "Apply PEFT's LoRA."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4iwHGzKBN6wk"
   },
   "outputs": [],
   "source": [
    "from peft import LoraConfig, get_peft_model, PeftType, TaskType\n",
    "\n",
    "config = LoraConfig(\n",
    "    task_type=TaskType.CAUSAL_LM,\n",
    "    r=8,\n",
    "    lora_alpha=32,\n",
    "    lora_dropout=0.1,\n",
    "    # OpenCALM-7B is GPT-NeoX based, so the target module is `query_key_value`\n",
    "    target_modules=[\"query_key_value\"],\n",
    ")\n",
    "\n",
    "model = get_peft_model(model, config)\n",
    "\n",
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Prints the number of trainable parameters in the model.\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    print(\n",
    "        f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n",
    "    )\n",
    "\n",
    "print_trainable_parameters(model)"
   ]
  },
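  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recent PEFT versions also ship an equivalent helper on the wrapped model, so the function above can likely be replaced with a one-liner:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PeftModel provides its own trainable-parameter summary.\n",
    "model.print_trainable_parameters()"
   ]
  },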
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QdjWif4CVXR6"
   },
   "source": [
    "### Prepare the training data\n",
    "\n",
    "Train on the question-and-answer pairs from `kunishou/databricks-dolly-15k-ja`.\n",
    "- Convert each record into an instruction/response prompt format.\n",
    "- The model's context length is 2048 tokens at most, so drop any example that exceeds it."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "import transformers\n",
    "from datasets import load_dataset\n",
    "\n",
    "def convert(item):\n",
    "    instruction = item['instruction']\n",
    "    output = item['output']\n",
    "    return tokenizer(f\"### Instruction:\\n{instruction}\\n\\n### Response:\\n{output}\")\n",
    "\n",
    "data = load_dataset(\"kunishou/databricks-dolly-15k-ja\")\n",
    "data = data \\\n",
    "    .map(convert, remove_columns=[\"input\", \"instruction\", \"output\", \"category\", \"index\"]) \\\n",
    "    .filter(lambda item: len(item[\"input_ids\"]) <= 2048)\n"
   ],
   "metadata": {
    "id": "Gea8P85GYuPd"
   },
   "execution_count": null,
   "outputs": []
  },
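  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It's worth spot-checking the converted data: print how many examples survived the length filter, and decode one record back to text to confirm the prompt template looks right."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset size after filtering, plus one decoded example prompt.\n",
    "print(len(data[\"train\"]), \"examples\")\n",
    "print(tokenizer.decode(data[\"train\"][0][\"input_ids\"]))"
   ]
  },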
  {
   "cell_type": "markdown",
   "source": [
    "### Train\n",
    "\n",
    "Expect this to take about 15 minutes on an A100 (40 GB). GPU memory is just barely enough."
   ],
   "metadata": {
    "id": "kAp-hsK23Ngd"
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AQ_HCYruWIHU"
   },
   "outputs": [],
   "source": [
    "trainer = transformers.Trainer(\n",
    "    model=model,\n",
    "    train_dataset=data['train'],\n",
    "    args=transformers.TrainingArguments(\n",
    "        per_device_train_batch_size=8,\n",
    "        gradient_accumulation_steps=4,\n",
    "        warmup_steps=50,\n",
    "        max_steps=100,\n",
    "        learning_rate=2e-4,\n",
    "        fp16=True,\n",
    "        logging_steps=1,\n",
    "        output_dir='outputs'\n",
    "    ),\n",
    "    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)\n",
    ")\n",
    "model.config.use_cache = False  # silence the warnings. Please re-enable for inference!\n",
    "trainer.train()"
   ]
  },
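  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`use_cache` was turned off above to silence the gradient-checkpointing warning; re-enable it before generating, otherwise inference is noticeably slower."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Restore the KV cache that was disabled for training, and\n",
    "# switch to eval mode so the LoRA dropout is inactive.\n",
    "model.config.use_cache = True\n",
    "model.eval();"
   ]
  },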
  {
   "cell_type": "markdown",
   "source": [
    "### Save the trained adapter to Google Drive\n",
    "\n",
    "Mount Google Drive if it isn't mounted yet, then create the destination directory."
   ],
   "metadata": {
    "id": "GoWY2FDIxNkl"
   }
  },
  {
   "cell_type": "code",
   "source": [
    "from datetime import datetime\n",
    "now = datetime.now().strftime(\"%Y%m%d%H%M%S\")\n",
    "\n",
    "DRIVE_MOUNT_PATH = \"/content/drive\"\n",
    "PEFT_MODEL_PATH = f\"{DRIVE_MOUNT_PATH}/MyDrive/Colab Data/{PEFT_MODEL_NAME}-{now}\"\n",
    "\n",
    "PEFT_MODEL_PATH"
   ],
   "metadata": {
    "id": "m-rUmHUJlxh5"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "from google.colab import drive\n",
    "import os\n",
    "\n",
    "if not os.path.exists(DRIVE_MOUNT_PATH):\n",
    "    drive.mount(DRIVE_MOUNT_PATH)\n",
    "\n",
    "!mkdir -p \"{PEFT_MODEL_PATH}\""
   ],
   "metadata": {
    "id": "OqaL1P1nXKbh"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "model.save_pretrained(PEFT_MODEL_PATH)"
   ],
   "metadata": {
    "id": "oI7JJ8GpZ3Vj"
   },
   "execution_count": null,
   "outputs": []
  },
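  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`save_pretrained` on a PEFT model writes only the adapter weights (tens of megabytes), so anyone loading them still needs the base model. As a sketch, PEFT's `merge_and_unload()` can bake the LoRA weights into a standalone model instead; the merge can't be applied to the 8-bit-quantized weights used for training, so the base model would have to be reloaded in fp16 first, which needs considerably more memory. The `-merged` output path below is just an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: produce a standalone merged model. Assumes enough memory for a\n",
    "# full fp16 copy of the 7B model; the 8-bit training model can't be merged.\n",
    "from peft import PeftModel\n",
    "\n",
    "base = AutoModelForCausalLM.from_pretrained(\n",
    "    MODEL_NAME,\n",
    "    torch_dtype=torch.float16,\n",
    "    device_map=\"auto\",\n",
    ")\n",
    "merged = PeftModel.from_pretrained(base, PEFT_MODEL_PATH).merge_and_unload()\n",
    "merged.save_pretrained(f\"{PEFT_MODEL_PATH}-merged\")  # example output path"
   ]
  },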
  {
   "cell_type": "markdown",
   "source": [
    "# Using the trained adapter"
   ],
   "metadata": {
    "id": "Apd_TFj2Reen"
   }
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "S65GcxNGA9kz"
   },
   "source": [
    "### Load the trained adapter from Google Drive\n"
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# Use previously trained weights if they exist.\n",
    "# If you are continuing from the cells above, you can skip this.\n",
    "\n",
    "from google.colab import drive\n",
    "import os\n",
    "\n",
    "if 'DRIVE_MOUNT_PATH' not in locals():\n",
    "    DRIVE_MOUNT_PATH = \"/content/drive\"\n",
    "\n",
    "if 'PEFT_MODEL_PATH' not in locals():\n",
    "    PEFT_MODEL_PATH = f\"{DRIVE_MOUNT_PATH}/MyDrive/Colab Data/open-calm-7b-lora-20230520073602\"\n",
    "\n",
    "if not os.path.exists(DRIVE_MOUNT_PATH):\n",
    "    drive.mount(DRIVE_MOUNT_PATH)"
   ],
   "metadata": {
    "id": "zqGQRM_URzYZ"
   },
   "execution_count": null,
   "outputs": []
  },
  {
   "cell_type": "code",
   "source": [
    "import torch\n",
    "from peft import PeftModel, PeftConfig\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "\n",
    "config = PeftConfig.from_pretrained(PEFT_MODEL_PATH)\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "    config.base_model_name_or_path,\n",
    "    load_in_8bit=True,\n",
    "    device_map='auto'\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)\n",
    "\n",
    "# Attach the trained LoRA adapter to the base model here.\n",
    "model = PeftModel.from_pretrained(model, PEFT_MODEL_PATH)"
   ],
   "metadata": {
    "id": "4QJFMA92Zsnm"
   },
   "execution_count": null,
   "outputs": []
  },
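  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the adapter was trained with `lora_dropout`, it may help to put the model in evaluation mode before generating, so dropout is inactive."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Disable dropout in the LoRA layers for inference.\n",
    "model.eval();"
   ]
  },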
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "MHYljmTjj5wX"
   },
   "source": [
    "### Run inference with LoRA\n",
    "\n",
    "The model now answers questions reasonably well."
   ]
  },
  {
   "cell_type": "code",
   "source": [
    "# The instruction asks, roughly: 'Tell me how cats meow.'\n",
    "instruction = \"猫の鳴き方を教えて\"\n",
    "\n",
    "inputs = tokenizer(f\"### Instruction:\\n{instruction}\\n\\n### Response:\\n\", return_tensors='pt').to(model.device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    tokens = model.generate(\n",
    "        **inputs,\n",
    "        max_new_tokens=128,\n",
    "        do_sample=True,\n",
    "        temperature=0.7,\n",
    "        top_p=0.75,\n",
    "        top_k=40,\n",
    "        repetition_penalty=5.0,\n",
    "        pad_token_id=tokenizer.pad_token_id,\n",
    "    )\n",
    "\n",
    "print(tokenizer.decode(tokens[0], skip_special_tokens=True))"
   ],
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "WfRhpLrnaaj-",
    "outputId": "66eb8abc-c05c-4cb2-a7c0-1bae824d08f5"
   },
   "execution_count": 7,
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "### Instruction:\n",
      "猫の鳴き方を教えて\n",
      "\n",
      "### Response:\n",
      "この「ネコ」は、自分の家に住み着いた野良の猫です。 彼らはしばしば鳴くので、「ニャー」、「ミャオ」「ウォウ」、または単に 「ワンワン」(犬のように)と聞こえます 。 これは子守唄であり、\"cock-a\"や \"doggie woofers\"(Duck Down the Lake など)、あるいは単純に 'Cat'とも呼ばれます(Nancy Pelman のアルバムの名前にもなっています)。 この歌には2つの異なるバージョンがあり、(1)「アーニー・ピューーズ」。 ( 2 )この歌で使われている言葉である “I'm called\n"
     ]
    }
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "machine_shape": "hm",
   "provenance": [],
   "collapsed_sections": [
    "2hsFhshbRjlb"
   ],
   "gpuType": "A100",
   "include_colab_link": true
  },
  "gpuClass": "standard",
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
} |