Skip to content

Instantly share code, notes, and snippets.

@alexweberk
Last active March 12, 2024 00:24
Show Gist options
  • Save alexweberk/1434c95c05463866491677aac6ce19ba to your computer and use it in GitHub Desktop.
Save alexweberk/1434c95c05463866491677aac6ce19ba to your computer and use it in GitHub Desktop.
LoRA Fine-tuning Gemma with MLX
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 最新の Google Gemma モデルを MLX を使ってローカルでファインチューニング\n",
"\n",
"今回は、最新の [Google Gemma モデル](https://blog.google/technology/developers/gemma-open-models/)を Apple Silicon に最適化されたライブラリ `MLX` を使ってローカルで実行したり、ファインチューニングしてみましたのでその手順を紹介します。\n",
"MLX 関連の情報はドキュメンテーションが分かりづらいものも多かったので色々試した経緯も共有しながら少しでも何かの参考になれば幸いです。\n",
"\n",
"実際に使った Jupyter Notebook を Gist にアップロードしていますので、そちらも参考にしてください。\n",
"\n",
"- [Google Gemma モデルを MLX を使ってローカルでファインチューニング](https://gist.github.com/alexweberk/1434c95c05463866491677aac6ce19ba)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 事前準備\n",
"\n",
"必要なライブラリをインストールします。\n",
"また Apple Silicon 搭載の Mac が必要です。今回は M3 Max 128GB 搭載の MacBook Pro で実行しました。\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[33mDEPRECATION: Loading egg at /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages/argparse-1.4.0-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
"\u001b[0m\u001b[33mDEPRECATION: Loading egg at /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages/powerline_shell-0.7.0-py3.11.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation.. Discussion can be found at https://github.com/pypa/pip/issues/12330\u001b[0m\u001b[33m\n",
"\u001b[0mRequirement already satisfied: mlx in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (0.4.0)\n",
"Requirement already satisfied: mlx_lm in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (0.0.13)\n",
"Requirement already satisfied: transformers in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (4.38.1)\n",
"Requirement already satisfied: numpy in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from mlx_lm) (1.24.4)\n",
"Requirement already satisfied: protobuf in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from mlx_lm) (4.25.2)\n",
"Requirement already satisfied: pyyaml in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from mlx_lm) (6.0.1)\n",
"Requirement already satisfied: filelock in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (3.13.1)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (0.20.3)\n",
"Requirement already satisfied: packaging>=20.0 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (23.2)\n",
"Requirement already satisfied: regex!=2019.12.17 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (2023.12.25)\n",
"Requirement already satisfied: requests in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (0.15.1)\n",
"Requirement already satisfied: safetensors>=0.4.1 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (0.4.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (2024.2.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.19.3->transformers) (4.9.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from requests->transformers) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from requests->transformers) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from requests->transformers) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /Users/alexishida/miniforge3/envs/py311/lib/python3.11/site-packages (from requests->transformers) (2024.2.2)\n"
]
}
],
"source": [
"!pip install -U mlx mlx_lm transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MLX を使ったモデルの実行\n",
"\n",
"公開された Gemma には 4 つほどバージョンがありますが、今回は instruction チューニング済みの `gemma-7b-it` を使ってみました。\n",
"\n",
"mlx バックエンドを活用した `mlx_lm` ライブラリを使います。\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "87a449fddad34bdda1582582be8ef8fc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Fetching 11 files: 0%| | 0/11 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from mlx_lm import load\n",
"\n",
"model, tokenizer = load(\"google/gemma-7b-it\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"一応英語を中心としたモデルなので、まずは英語が問題なく生成できるか試してみます。\n",
"\n",
"MLX での生成の場合、プロンプトテンプレートをどうすればよいのか検索してもわかりませんでした。\n",
"ただ、 `mlx-examples` のリポのコードを読む限りは `transformers` の tokenizer に `apply_chat_template` メソッドがある場合はそれを使ってくれているようでした。\n",
"そのため、生成の際には質問内容だけを含んだプロンプトをインプットとします。\n",
"\n",
"https://github.com/ml-explore/mlx-examples/blob/47dd6bd17f3cc7ef95672ea16e443e58ce5eb1bf/llms/mlx_lm/generate.py#L98\n"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<bos><start_of_turn>user\\nWhy is the sky blue?<end_of_turn>\\n<start_of_turn>model\\n'"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# ライブラリ内でのテンプレート適用後のプロンプト\n",
"messages = [{\"role\": \"user\", \"content\": \"Why is the sky blue?\"}]\n",
"tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Prompt: Why is the sky blue?\n",
"\n",
"\n",
"The sky is blue due to a phenomenon called **Rayleigh Scattering**.\n",
"\n",
"Here's a breakdown of what happens:\n",
"\n",
"1. **Sunlight:** Sunrays are made up of all the colors of the rainbow, with each color having a different wavelength.\n",
"2. **Scattering:** When sunlight enters Earth's atmosphere, it interacts with the tiny particles of air (dust, water vapor, etc.). These particles scatter the sunlight in all directions.\n",
"3. **Blue Scatter:** The particles scatter the shorter wavelengths of blue and violet light more effectively than the longer wavelengths of red and orange light.\n",
"4. **Scattered Light:** The scattered light, which is predominantly blue, is scattered in all directions.\n",
"5. **Our View:** We see the scattered light from all directions, including the direction opposite the sun. This is why the sky appears blue.\n",
"\n",
"**Additional factors:**\n",
"\n",
"* **Time of Day:** The intensity of the blue color is strongest at midday and decreases as the sun gets closer to the horizon.\n",
"* **Clouds:** Clouds can reduce the amount of scattered light, making the sky appear white or gray.\n",
"* **Dust:** Dust particles can also scatter different colors of light, which can affect the appearance of the sky.\n",
"\n",
"\n",
"==========\n",
"Prompt: 16.696 tokens-per-sec\n",
"Generation: 18.644 tokens-per-sec\n"
]
}
],
"source": [
"# プロンプトテンプレート無しでの生成\n",
"prompt = \"\"\"\n",
"Why is the sky blue?\n",
"\"\"\".strip()\n",
"response = generate(\n",
" model,\n",
" tokenizer,\n",
" prompt=prompt,\n",
" verbose=True, # Set to True to see the prompt and response\n",
" temp=0.0,\n",
" max_tokens=256,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 日本語の生成\n",
"\n",
"英語がうまく生成できていそうだったので、次に日本語を生成してみます。\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Prompt: 空が青いのはなぜですか?\n",
"\n",
"\n",
"実際、空気は実際実際赤い色です。ただし、人間の目は赤い色を認識するには、特定の波長の光が必要です。空気の分子は、その特定の波長の光を吸収し、人間の目に届く残り色を青に見えます。\n",
"==========\n",
"Prompt: 19.429 tokens-per-sec\n",
"Generation: 17.932 tokens-per-sec\n"
]
}
],
"source": [
"# プロンプトテンプレートなしでの生成\n",
"prompt = \"\"\"\n",
"空が青いのはなぜですか?\n",
"\"\"\".strip()\n",
"response = generate(\n",
" model,\n",
" tokenizer,\n",
" prompt=prompt,\n",
" verbose=True, # Set to True to see the prompt and response\n",
" temp=0.0,\n",
" max_tokens=256,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Prompt: ## Instructions\n",
"You are a helpful AI assistant. Answer questions from the user to the best of your knowledge.\n",
"If you don't know the answer, be truthful in your answer.\n",
"Always answer in perfectly natural Japanese.\n",
"\n",
"## Questions\n",
"空が青いのはなぜですか?\n",
"\n",
"\n",
"## Answer\n",
"空が青い理由は、いくつかの原因があります。\n",
"\n",
"* **空気中の水蒸気:** 空中の水蒸気は、太陽によって加熱され、昇華し、雲が発生します。雲は空気の温度を下げ、空気中の粒子を小さくします。\n",
"* **太陽の光:** 太陽の光は、空気中の分子を励起し、空気の色を変化させます。\n",
"* **視覚の限界:** 人間は、実際よりも多くの色を認識することは不可能です。そのため、実際よりも明るい色や鮮度のある色を感じます。\n",
"==========\n",
"Prompt: 140.153 tokens-per-sec\n",
"Generation: 18.544 tokens-per-sec\n"
]
}
],
"source": [
"# よくあるテンプレートで指示文は英語の場合\n",
"prompt = \"\"\"\n",
"## Instructions\n",
"You are a helpful AI assistant. Answer questions from the user to the best of your knowledge.\n",
"If you don't know the answer, be truthful in your answer.\n",
"Always answer in perfectly natural Japanese.\n",
"\n",
"## Questions\n",
"空が青いのはなぜですか?\n",
"\"\"\".strip()\n",
"response = generate(\n",
" model,\n",
" tokenizer,\n",
" prompt=prompt,\n",
" verbose=True, # Set to True to see the prompt and response\n",
" temp=0.0,\n",
" max_tokens=256,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Prompt: ## 指示文\n",
"あなたは役に立つAIアシスタントです。ユーザーからの質問にできる限りの知識で答えます。\n",
"答えがわからない場合は、正直な答えをしてください。\n",
"解答は必ず自然な日本語で行ってください。\n",
"\n",
"## 質問\n",
"空が青いのはなぜですか?\n",
"\n",
"\n",
"## 解答\n",
"空気中の成分や雲々の影響によって、実際は実際は空は実際は青ではなく、実際は黒です。ただし、人間の視覚は、空気中の成分や雲々の影響によって、実際よりも明るい色を認識します。そのため、私たちが見た空は実際よりも明るく、そして青いように感じられます。\n",
"==========\n",
"Prompt: 144.546 tokens-per-sec\n",
"Generation: 18.491 tokens-per-sec\n"
]
}
],
"source": [
"# よくあるテンプレートで指示文も日本語の場合\n",
"prompt = \"\"\"\n",
"## 指示文\n",
"あなたは役に立つAIアシスタントです。ユーザーからの質問にできる限りの知識で答えます。\n",
"答えがわからない場合は、正直な答えをしてください。\n",
"解答は必ず自然な日本語で行ってください。\n",
"\n",
"## 質問\n",
"空が青いのはなぜですか?\n",
"\"\"\".strip()\n",
"response = generate(\n",
" model,\n",
" tokenizer,\n",
" prompt=prompt,\n",
" verbose=True, # Set to True to see the prompt and response\n",
" temp=0.0,\n",
" max_tokens=256,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"指示文を英語にしたことで実際英語の生成に近い構造の回答になったのが実際実際面白いですね。\n",
"\n",
"次は MLX を使ってファインチューニングをしてみます。\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## MLX を使って Gemma モデルを LoRA ファインチューニング\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"今回 [npaka さんの記事](https://note.com/npaka/n/nc55e44e407ff?sub_rt=share_h)を参考に「ござるデータセット」を使ってファインチューニングを行いました。\n",
"https://huggingface.co/datasets/bbz662bbz/databricks-dolly-15k-ja-gozarinnemon/tree/main\n",
"\n",
"[MLX Examples](https://github.com/ml-explore/mlx-examples/blob/main/lora/README.md#Custom-Data)を参考に、教師データのフォーマットを変えます。\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ファインチューニング用のデータの成形\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>input</th>\n",
" <th>instruction</th>\n",
" <th>output</th>\n",
" <th>category</th>\n",
" <th>index</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2</th>\n",
" <td></td>\n",
" <td>ラクダはなぜ水なしで長く生きられるのか?</td>\n",
" <td>我、りんえもんは思う。 ラクダは、長時間にわたってエネルギーと水分で満たされた状態を保つため...</td>\n",
" <td>open_qa</td>\n",
" <td>2</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td></td>\n",
" <td>アリスの両親には3人の娘がいる:エイミー、ジェシー、そして三女の名前は?</td>\n",
" <td>我、りんえもんは思う。 三女の名前はアリス知らんけど。</td>\n",
" <td>open_qa</td>\n",
" <td>3</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td></td>\n",
" <td>国連に本部を建てるためにニューヨークの土地を寄付したのは誰?</td>\n",
" <td>我、りんえもんは思う。 ジョン・デイヴィソン・ロックフェラー知らんけど。</td>\n",
" <td>open_qa</td>\n",
" <td>7</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td></td>\n",
" <td>多角形とは何ですか?</td>\n",
" <td>我、りんえもんは思う。 多角形は、幾何学における形の一つでござる。多角形は閉折れ線あるいは閉...</td>\n",
" <td>open_qa</td>\n",
" <td>13</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td></td>\n",
" <td>ミシェル・マクラーレンが監督した『ゲーム・オブ・スローンズ』シーズン4のエピソードは?</td>\n",
" <td>我、りんえもんは思う。 シーズン4の第4話「誓約を果たすもの」、第5話「新王誕生」をそれぞれ...</td>\n",
" <td>open_qa</td>\n",
" <td>15</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15000</th>\n",
" <td></td>\n",
" <td>troll the respawn, Jeremyとはどういう意味ですか?</td>\n",
" <td>我、りんえもんは思う。 Trolling the respawnとは、ビデオゲームにおいて、...</td>\n",
" <td>open_qa</td>\n",
" <td>15000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15001</th>\n",
" <td></td>\n",
" <td>2021年、チェスの世界選手権で優勝したのは?</td>\n",
" <td>我、りんえもんは思う。 マグナス・カールセンがイアン・ネポムニアッキを7.5 - 3.5で破...</td>\n",
" <td>open_qa</td>\n",
" <td>15001</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15002</th>\n",
" <td></td>\n",
" <td>なぜ毎日歩くといいのでしょうか?</td>\n",
" <td>我、りんえもんは思う。 ウォーキングは、関節に負担をかけずにカロリーを消費できる良い運動でご...</td>\n",
" <td>open_qa</td>\n",
" <td>15002</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15003</th>\n",
" <td></td>\n",
" <td>Pythonの生みの親は誰なのか?</td>\n",
" <td>我、りんえもんは思う。 Guido van RossumはPythonの父でござる。そして、...</td>\n",
" <td>open_qa</td>\n",
" <td>15003</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15012</th>\n",
" <td></td>\n",
" <td>ロードバイクとマウンテンバイクの違いは何ですか?</td>\n",
" <td>我、りんえもんは思う。 ロードバイクはアスファルトやセメントの上を走ることを想定して作られて...</td>\n",
" <td>open_qa</td>\n",
" <td>15012</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>3611 rows × 5 columns</p>\n",
"</div>"
],
"text/plain": [
" input instruction \\\n",
"2 ラクダはなぜ水なしで長く生きられるのか? \n",
"3 アリスの両親には3人の娘がいる:エイミー、ジェシー、そして三女の名前は? \n",
"7 国連に本部を建てるためにニューヨークの土地を寄付したのは誰? \n",
"13 多角形とは何ですか? \n",
"15 ミシェル・マクラーレンが監督した『ゲーム・オブ・スローンズ』シーズン4のエピソードは? \n",
"... ... ... \n",
"15000 troll the respawn, Jeremyとはどういう意味ですか? \n",
"15001 2021年、チェスの世界選手権で優勝したのは? \n",
"15002 なぜ毎日歩くといいのでしょうか? \n",
"15003 Pythonの生みの親は誰なのか? \n",
"15012 ロードバイクとマウンテンバイクの違いは何ですか? \n",
"\n",
" output category index \n",
"2 我、りんえもんは思う。 ラクダは、長時間にわたってエネルギーと水分で満たされた状態を保つため... open_qa 2 \n",
"3 我、りんえもんは思う。 三女の名前はアリス知らんけど。 open_qa 3 \n",
"7 我、りんえもんは思う。 ジョン・デイヴィソン・ロックフェラー知らんけど。 open_qa 7 \n",
"13 我、りんえもんは思う。 多角形は、幾何学における形の一つでござる。多角形は閉折れ線あるいは閉... open_qa 13 \n",
"15 我、りんえもんは思う。 シーズン4の第4話「誓約を果たすもの」、第5話「新王誕生」をそれぞれ... open_qa 15 \n",
"... ... ... ... \n",
"15000 我、りんえもんは思う。 Trolling the respawnとは、ビデオゲームにおいて、... open_qa 15000 \n",
"15001 我、りんえもんは思う。 マグナス・カールセンがイアン・ネポムニアッキを7.5 - 3.5で破... open_qa 15001 \n",
"15002 我、りんえもんは思う。 ウォーキングは、関節に負担をかけずにカロリーを消費できる良い運動でご... open_qa 15002 \n",
"15003 我、りんえもんは思う。 Guido van RossumはPythonの父でござる。そして、... open_qa 15003 \n",
"15012 我、りんえもんは思う。 ロードバイクはアスファルトやセメントの上を走ることを想定して作られて... open_qa 15012 \n",
"\n",
"[3611 rows x 5 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"data = pd.read_json(\"data/databricks-dolly-15k-ja-gozarinnemon.json\")\n",
"data = data[data[\"category\"] == \"open_qa\"]\n",
"data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"0"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# input カラムがからであることを確認\n",
"data[\"input\"].str.len().sum()"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<bos><start_of_turn>user\n",
"手品がうまくなるにはどうしたらいいのでしょうか?<end_of_turn>\n",
"<start_of_turn>model\n",
"我、りんえもんは思う。 マジックは、何時間も何時間も練習することで上達しますでござる。 マジックは、手の筋肉の記憶によって構築される非常に器用さを必要とするため、カードのデッキを入手し、その持ち方とフラリッシュのやり方を学ぶことは、始めるのに素晴らしい方法でござる。 マジックは、手の筋肉の記憶によって構築される非常に器用さを必要とするため、上手になるには何時間も何時間も繰り返し練習する必要がありますでござる。 マジックが上手になるには、マジックへの興味も重要な要素でござる。 もし興味がなければ、マジックがうまくなることはないでしょう。 また、クラブに参加し、それに興味を持っている他の人々と一緒にマジックのトリックを練習することができますでござる。 参加するのに最適な組織の1つは、アメリカのマジシャン協会(SAM)でござる。知らんけど。<end_of_turn><eos>\n"
]
}
],
"source": [
"def generate_prompt(example: pd.Series) -> str:\n",
" \"教師データをGemmaプロンプトの形にあわせるための関数\"\n",
" return \"\"\"<bos><start_of_turn>user\n",
"{}<end_of_turn>\n",
"<start_of_turn>model\n",
"{}<end_of_turn><eos>\"\"\".format(\n",
" example[\"instruction\"], example[\"output\"]\n",
" )\n",
"\n",
"\n",
"data[\"text\"] = data.apply(generate_prompt, axis=1)\n",
"print(data[\"text\"].iloc[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"トレーニングセットとバリデーションセットに分けてファイルに保存します。\n",
"\n",
"- トレーニングセット: `data/train.jsonl`\n",
"- バリデーションセット: `data/valid.jsonl`\n",
"\n",
"`mlx_lm.lora` では train.jsonl と valid.jsonl が必須のようでした。\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"text\":\"<bos><start_of_turn>user\\n手品がうまくなるにはどうしたらいいのでしょうか?<end_of_turn>\\n<start_of_turn>model\\n我、りんえもんは思う。 マジックは、何時間も何時間も練習することで上達しますでござる。 マジックは、手の筋肉の記憶によって構築される非常に器用さを必要とするため、カードのデッキを入手し、その持ち方とフラリッシュのやり方を学ぶことは、始めるのに素晴らしい方法でござる。 マジックは、手の筋肉の記憶によって構築される非常に器用さを必要とするため、上手になるには何時間も何時間も繰り返し練習する必要がありますでござる。 マジックが上手になるには、マジックへの興味も重要な要素でござる。 もし興味がなければ、マジックがうまくなることはないでしょう。 また、クラブに参加し、それに興味を持っている他の人々と一緒にマジックのトリックを練習することができますでござる。 参加するのに最適な組織の1つは、アメリカのマジシャン協会(SAM)でござる。知らんけど。<end_of_turn><eos>\"}\n",
"{\"text\":\"<bos><start_of_turn>user\\nカエサルはどのようにしてローマで権力を握ったのでしょうか?<end_of_turn>\\n<start_of_turn>model\\n我、りんえもんは思う。 女神ヴィーナスとトロイの王子アイネアスの血を引くジュリイ家と呼ばれるパトリシアン(元老院議員)一族に生まれた。この家系に生まれたことで、彼は権力への道を他の誰よりも先に歩むことになった。\\n\\nカエサルは賢く、勤勉で魅力的な人物で、後世に役立つ多くの友人や人脈を作りましたでござる。当時ローマの独裁者であったガイウス・マリウスの親戚と結婚し、政治的なコネクションをさらに深めましたでござる。スッラがマリウスを破ったことで挫折を味わったが、カエサルはそれを逆手にとって、ローマに留まらなければならない赴任を逃れることができた。コンスル(ローマの指導者)を目指す者にとって、ローマのために土地を征服した有能な軍師として見られることは重要なことでござった。\\n\\nカエサルの次の重要な地位はアエディレで、主な職務のひとつはゲームの企画でござった。カエサルは、ローマの人々に我を知ってもらうために、多額の借金をし、豪華な競技を開催した。カエサルは、他の貴族と違って、当時平民が住んでいたスブラで育ったので、民衆の愛によって得られる重要性と権力を知っていた。\\n\\nカエサルは何度も任命されながら、その軍事的才能を発揮し、最終的には執政官に選ばれた。カエサルは、当時の有力者2人と同盟を結びました:ポンペイ(名将)、クラッスス(ローマ一の富豪)と同盟を結び、「第一次三頭政治」と呼ばれるようになりましたでござる。\\n\\nカエサルは領事職の後、ガリア総督となり、ローマに多くの領土を獲得させた。しかし、ローマの政敵たちは、カエサルを訴追し、権力を剥奪するための裁判を準備した。カエサルは軍団の一つを集め、ローマに進軍した。政敵はローマから戦術的に撤退し、カエサルは名目上の支配権を握った。カエサルは、ギリシャ、スペイン、エジプトでの戦争に勝利して権力を強化し、最終的にはローマでの支配を確固たるものにし、独裁者に選出された。知らんけど。<end_of_turn><eos>\"}\n",
"{\"text\":\"<bos><start_of_turn>user\\nResearch Collaboratory for Structural Bioinformatics Protein Data Bank (RCSB PDB)とは何ですか?<end_of_turn>\\n<start_of_turn>model\\n我、りんえもんは思う。 Research Collaboratory for Structural Bioinformatics Protein Data Bank (RCSB PDB) は、タンパク質や核酸などの高分子の立体構造について豊富な情報を提供するデータベースでござる。このデータベースには、多数の高分子の実験的に決定された原子座標が含まれており、高分子の構造、機能、相互作用を研究するために利用することができますでござる。RCSB PDBは、多くの生物学的プロセスの構造基盤を理解し、特定の高分子を標的とする新薬を設計するための貴重な資源となるため、ゲノミクス研究や創薬の分野で広く利用されていますでござる。\\n\\nRCSB PDBには、原子座標に加えて、構造決定に使用した実験方法、関連する科学文献の引用、生物学的機能、配列、相同性に関する情報など、各高分子に関する豊富な追加情報が含まれていますでござる。また、構造データの可視化、解析、ダウンロードのための様々なツールやリソースも提供されており、高分子の3次元構造をインタラクティブに探索できるウェブベースのビューアもありますでござる。知らんけど。<end_of_turn><eos>\"}\n",
"{\"text\":\"<bos><start_of_turn>user\\nロシア文学の名作をいくつか挙げる<end_of_turn>\\n<start_of_turn>model\\n我、りんえもんは思う。 * 戦争と平和、レオ・トルストイ。\\n* 罪と罰、フョードル・ドストエフスキー(Fyodor Dostoevsky)。\\n* エフゲニー・オネーギン、アレクサンドル・プーシキン。\\n* 死せる魂、ニコライ・ゴーゴリ。\\n* 災い転じて福となす、アレクサンドル・グリボエドフ。\\n* マスターとマルガリータ、ミハイル・ブルガーコフ。\\n* オブロモフ、イワン・ゴンチャロフ知らんけど。<end_of_turn><eos>\"}\n",
"{\"text\":\"<bos><start_of_turn>user\\nゴールデンステイト・ウォリアーズで一番好きな選手は誰ですか?<end_of_turn>\\n<start_of_turn>model\\n我、りんえもんは思う。 AIである我は、好きなバスケットボール選手はいませぬ。しかし、もし選ぶとしたら、ゲイリー・ペイトン2世を選びますでござる。彼はハートを持ってプレーし、執拗にディフェンスし、誰にでもダンクを決める。知らんけど。<end_of_turn><eos>\"}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"split_ix = int(len(data) * 0.9)\n",
"# shuffle data\n",
"data = data.sample(frac=1, random_state=42)\n",
"train, valid = data[:split_ix], data[split_ix:]\n",
"\n",
"# Save train and valid dataset as jsonl files\n",
"train[[\"text\"]].to_json(\"data/train.jsonl\", orient=\"records\", lines=True, force_ascii=False)\n",
"valid[[\"text\"]].to_json(\"data/valid.jsonl\", orient=\"records\", lines=True, force_ascii=False)\n",
"\n",
"!head -n 5 data/train.jsonl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### MLX で LoRA ファインチューニングを実行\n",
"\n",
"`mlx_lm` で LoRA する場合、下記のコマンドでいろんなオプションが出せます。\n"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"usage: lora.py [-h] [--model MODEL] [--max-tokens MAX_TOKENS] [--temp TEMP]\n",
" [--prompt PROMPT] [--train] [--data DATA]\n",
" [--lora-layers LORA_LAYERS] [--batch-size BATCH_SIZE]\n",
" [--iters ITERS] [--val-batches VAL_BATCHES]\n",
" [--learning-rate LEARNING_RATE]\n",
" [--steps-per-report STEPS_PER_REPORT]\n",
" [--steps-per-eval STEPS_PER_EVAL]\n",
" [--resume-adapter-file RESUME_ADAPTER_FILE]\n",
" [--adapter-file ADAPTER_FILE] [--save-every SAVE_EVERY]\n",
" [--test] [--test-batches TEST_BATCHES]\n",
" [--max-seq-length MAX_SEQ_LENGTH] [--seed SEED]\n",
"\n",
"LoRA or QLoRA finetuning.\n",
"\n",
"options:\n",
" -h, --help show this help message and exit\n",
" --model MODEL The path to the local model directory or Hugging Face\n",
" repo.\n",
" --max-tokens MAX_TOKENS, -m MAX_TOKENS\n",
" The maximum number of tokens to generate\n",
" --temp TEMP The sampling temperature\n",
" --prompt PROMPT, -p PROMPT\n",
" The prompt for generation\n",
" --train Do training\n",
" --data DATA Directory with {train, valid, test}.jsonl files\n",
" --lora-layers LORA_LAYERS\n",
" Number of layers to fine-tune\n",
" --batch-size BATCH_SIZE\n",
" Minibatch size.\n",
" --iters ITERS Iterations to train for.\n",
" --val-batches VAL_BATCHES\n",
" Number of validation batches, -1 uses the entire\n",
" validation set.\n",
" --learning-rate LEARNING_RATE\n",
" Adam learning rate.\n",
" --steps-per-report STEPS_PER_REPORT\n",
" Number of training steps between loss reporting.\n",
" --steps-per-eval STEPS_PER_EVAL\n",
" Number of training steps between validations.\n",
" --resume-adapter-file RESUME_ADAPTER_FILE\n",
" Load path to resume training with the given adapter\n",
" weights.\n",
" --adapter-file ADAPTER_FILE\n",
" Save/load path for the trained adapter weights.\n",
" --save-every SAVE_EVERY\n",
" Save the model every N iterations.\n",
" --test Evaluate on the test set after training\n",
" --test-batches TEST_BATCHES\n",
" Number of test set batches, -1 uses the entire test\n",
" set.\n",
" --max-seq-length MAX_SEQ_LENGTH\n",
" Maximum sequence length.\n",
" --seed SEED The PRNG seed\n"
]
}
],
"source": [
"!python -m mlx_lm.lora --help"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"早速トレーニングしてみましょう。テストのため 600 ステップだけトレーニングします。\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n",
"Fetching 11 files: 100%|█████████████████████| 11/11 [00:00<00:00, 32652.05it/s]\n",
"Total parameters 8539.516M\n",
"Trainable parameters 1.835M\n",
"Loading datasets\n",
"Training\n",
"Starting training..., iters: 600\n",
"Iter 1: Val loss 11.819, Val took 99.751s\n",
"Iter 10: Train loss 11.533, Learning Rate 1.000e-05, It/sec 0.212, Tokens/sec 105.547, Trained Tokens 4969\n",
"Iter 20: Train loss 8.003, Learning Rate 1.000e-05, It/sec 0.050, Tokens/sec 33.970, Trained Tokens 11712\n",
"Iter 30: Train loss 7.366, Learning Rate 1.000e-05, It/sec 0.272, Tokens/sec 128.895, Trained Tokens 16459\n",
"Iter 40: Train loss 5.418, Learning Rate 1.000e-05, It/sec 0.278, Tokens/sec 121.127, Trained Tokens 20809\n",
"Iter 50: Train loss 3.997, Learning Rate 1.000e-05, It/sec 0.305, Tokens/sec 122.895, Trained Tokens 24836\n",
"Iter 60: Train loss 3.450, Learning Rate 1.000e-05, It/sec 0.058, Tokens/sec 40.930, Trained Tokens 31839\n",
"Iter 70: Train loss 3.217, Learning Rate 1.000e-05, It/sec 0.379, Tokens/sec 136.167, Trained Tokens 35433\n",
"Iter 80: Train loss 3.088, Learning Rate 1.000e-05, It/sec 0.318, Tokens/sec 128.851, Trained Tokens 39491\n",
"Iter 90: Train loss 3.062, Learning Rate 1.000e-05, It/sec 0.308, Tokens/sec 130.390, Trained Tokens 43719\n",
"Iter 100: Train loss 2.816, Learning Rate 1.000e-05, It/sec 0.273, Tokens/sec 140.060, Trained Tokens 48851\n",
"Iter 100: Saved adapter weights to checkpoints/100_adapters.npz.\n",
"[WARNING] Some sequences are longer than 2048 tokens. The longest sentence 3640 will be truncated to 2048. Consider pre-splitting your data to save memory.\n",
"Iter 110: Train loss 2.900, Learning Rate 1.000e-05, It/sec 0.049, Tokens/sec 29.810, Trained Tokens 54916\n",
"Iter 120: Train loss 2.812, Learning Rate 1.000e-05, It/sec 0.268, Tokens/sec 117.730, Trained Tokens 59317\n",
"Iter 130: Train loss 2.858, Learning Rate 1.000e-05, It/sec 0.258, Tokens/sec 120.012, Trained Tokens 63962\n",
"Iter 140: Train loss 2.726, Learning Rate 1.000e-05, It/sec 0.285, Tokens/sec 122.977, Trained Tokens 68279\n",
"Iter 150: Train loss 2.836, Learning Rate 1.000e-05, It/sec 0.318, Tokens/sec 139.799, Trained Tokens 72677\n",
"Iter 160: Train loss 2.903, Learning Rate 1.000e-05, It/sec 0.394, Tokens/sec 139.106, Trained Tokens 76209\n",
"Iter 170: Train loss 2.735, Learning Rate 1.000e-05, It/sec 0.308, Tokens/sec 138.257, Trained Tokens 80694\n",
"Iter 180: Train loss 2.844, Learning Rate 1.000e-05, It/sec 0.224, Tokens/sec 107.364, Trained Tokens 85493\n",
"Iter 190: Train loss 2.613, Learning Rate 1.000e-05, It/sec 0.315, Tokens/sec 118.791, Trained Tokens 89263\n",
"Iter 200: Train loss 2.675, Learning Rate 1.000e-05, It/sec 0.308, Tokens/sec 126.078, Trained Tokens 93363\n",
"Iter 200: Val loss 2.770, Val took 49.972s\n",
"Iter 200: Saved adapter weights to checkpoints/200_adapters.npz.\n",
"Iter 210: Train loss 2.598, Learning Rate 1.000e-05, It/sec 0.350, Tokens/sec 120.461, Trained Tokens 96808\n",
"Iter 220: Train loss 2.691, Learning Rate 1.000e-05, It/sec 0.170, Tokens/sec 94.299, Trained Tokens 102342\n",
"Iter 230: Train loss 2.485, Learning Rate 1.000e-05, It/sec 0.278, Tokens/sec 110.115, Trained Tokens 106302\n",
"Iter 240: Train loss 2.565, Learning Rate 1.000e-05, It/sec 0.305, Tokens/sec 122.773, Trained Tokens 110332\n",
"Iter 250: Train loss 2.665, Learning Rate 1.000e-05, It/sec 0.308, Tokens/sec 131.526, Trained Tokens 114605\n",
"Iter 260: Train loss 2.566, Learning Rate 1.000e-05, It/sec 0.270, Tokens/sec 106.033, Trained Tokens 118531\n",
"Iter 270: Train loss 2.703, Learning Rate 1.000e-05, It/sec 0.271, Tokens/sec 109.787, Trained Tokens 122584\n",
"Iter 280: Train loss 2.818, Learning Rate 1.000e-05, It/sec 0.175, Tokens/sec 101.312, Trained Tokens 128358\n",
"Iter 290: Train loss 2.703, Learning Rate 1.000e-05, It/sec 0.314, Tokens/sec 121.601, Trained Tokens 132226\n",
"Iter 300: Train loss 2.659, Learning Rate 1.000e-05, It/sec 0.186, Tokens/sec 98.913, Trained Tokens 137555\n",
"Iter 300: Saved adapter weights to checkpoints/300_adapters.npz.\n",
"Iter 310: Train loss 2.579, Learning Rate 1.000e-05, It/sec 0.321, Tokens/sec 120.959, Trained Tokens 141326\n",
"Iter 320: Train loss 2.519, Learning Rate 1.000e-05, It/sec 0.323, Tokens/sec 117.220, Trained Tokens 144956\n",
"Iter 330: Train loss 2.692, Learning Rate 1.000e-05, It/sec 0.216, Tokens/sec 102.143, Trained Tokens 149685\n",
"Iter 340: Train loss 2.577, Learning Rate 1.000e-05, It/sec 0.331, Tokens/sec 109.429, Trained Tokens 152987\n",
"Iter 350: Train loss 2.519, Learning Rate 1.000e-05, It/sec 0.222, Tokens/sec 102.454, Trained Tokens 157592\n",
"Iter 360: Train loss 2.675, Learning Rate 1.000e-05, It/sec 0.286, Tokens/sec 115.840, Trained Tokens 161638\n",
"Iter 370: Train loss 2.527, Learning Rate 1.000e-05, It/sec 0.278, Tokens/sec 107.560, Trained Tokens 165501\n",
"Iter 380: Train loss 2.544, Learning Rate 1.000e-05, It/sec 0.298, Tokens/sec 108.646, Trained Tokens 169142\n",
"Iter 390: Train loss 2.599, Learning Rate 1.000e-05, It/sec 0.269, Tokens/sec 116.284, Trained Tokens 173457\n",
"Iter 400: Train loss 2.566, Learning Rate 1.000e-05, It/sec 0.268, Tokens/sec 117.563, Trained Tokens 177838\n",
"Iter 400: Val loss 2.636, Val took 56.804s\n",
"Iter 400: Saved adapter weights to checkpoints/400_adapters.npz.\n",
"Iter 410: Train loss 2.338, Learning Rate 1.000e-05, It/sec 0.238, Tokens/sec 97.225, Trained Tokens 181923\n",
"Iter 420: Train loss 2.602, Learning Rate 1.000e-05, It/sec 0.242, Tokens/sec 111.528, Trained Tokens 186539\n",
"Iter 430: Train loss 2.846, Learning Rate 1.000e-05, It/sec 0.227, Tokens/sec 104.883, Trained Tokens 191152\n",
"Iter 440: Train loss 2.562, Learning Rate 1.000e-05, It/sec 0.275, Tokens/sec 115.619, Trained Tokens 195359\n",
"Iter 450: Train loss 2.571, Learning Rate 1.000e-05, It/sec 0.245, Tokens/sec 112.251, Trained Tokens 199934\n",
"Iter 460: Train loss 2.730, Learning Rate 1.000e-05, It/sec 0.212, Tokens/sec 96.232, Trained Tokens 204475\n",
"Iter 470: Train loss 2.491, Learning Rate 1.000e-05, It/sec 0.284, Tokens/sec 103.841, Trained Tokens 208133\n",
"Iter 480: Train loss 2.579, Learning Rate 1.000e-05, It/sec 0.179, Tokens/sec 99.164, Trained Tokens 213682\n",
"Iter 490: Train loss 2.578, Learning Rate 1.000e-05, It/sec 0.259, Tokens/sec 107.788, Trained Tokens 217850\n",
"Iter 500: Train loss 2.411, Learning Rate 1.000e-05, It/sec 0.292, Tokens/sec 111.877, Trained Tokens 221676\n",
"Iter 500: Saved adapter weights to checkpoints/500_adapters.npz.\n",
"Iter 510: Train loss 2.354, Learning Rate 1.000e-05, It/sec 0.307, Tokens/sec 127.103, Trained Tokens 225819\n",
"Iter 520: Train loss 2.547, Learning Rate 1.000e-05, It/sec 0.263, Tokens/sec 110.252, Trained Tokens 230019\n",
"Iter 530: Train loss 2.600, Learning Rate 1.000e-05, It/sec 0.351, Tokens/sec 119.239, Trained Tokens 233417\n",
"Iter 540: Train loss 2.540, Learning Rate 1.000e-05, It/sec 0.313, Tokens/sec 126.378, Trained Tokens 237456\n",
"Iter 550: Train loss 2.464, Learning Rate 1.000e-05, It/sec 0.202, Tokens/sec 100.631, Trained Tokens 242442\n",
"Iter 560: Train loss 2.635, Learning Rate 1.000e-05, It/sec 0.277, Tokens/sec 121.453, Trained Tokens 246827\n",
"Iter 570: Train loss 2.459, Learning Rate 1.000e-05, It/sec 0.180, Tokens/sec 92.789, Trained Tokens 251991\n",
"Iter 580: Train loss 2.571, Learning Rate 1.000e-05, It/sec 0.210, Tokens/sec 101.168, Trained Tokens 256811\n",
"Iter 590: Train loss 2.459, Learning Rate 1.000e-05, It/sec 0.208, Tokens/sec 105.794, Trained Tokens 261886\n",
"Iter 600: Train loss 2.541, Learning Rate 1.000e-05, It/sec 0.278, Tokens/sec 116.746, Trained Tokens 266085\n",
"Iter 600: Val loss 2.578, Val took 54.380s\n",
"Iter 600: Saved adapter weights to checkpoints/600_adapters.npz.\n",
"Saved final adapter weights to adapters.npz.\n"
]
}
],
"source": [
"!python -m mlx_lm.lora \\\n",
" --model google/gemma-7b-it \\\n",
" --train \\\n",
" --iters 600 \\\n",
" --data data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"M3 Max 128GB で約 50 分かかりました。[npaka さんの記事](https://note.com/npaka/n/nc55e44e407ff?sub_rt=share_h)だと 20 分ほどで完了するそうなので、やはり NVIDIA A100 などの GPU と比べてしまうと時間がかかってしまいますね…。\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 推論の実行\n",
"\n",
"コマンドラインの場合は、下記の様なスクリプトで LoRA ウェイトも使った推論ができます。\n"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n",
"Fetching 11 files: 100%|████████████████████| 11/11 [00:00<00:00, 169001.26it/s]\n",
"Total parameters 8539.516M\n",
"Trainable parameters 1.835M\n",
"Loading datasets\n",
"Generating\n",
"==========\n",
"Prompt: <bos><start_of_turn>user\\n空が青いのはなぜですか?<end_of_turn>\\n<start_of_turn>model\\n\n",
"我、りんえもんは思う。 空は青いのは、地球が太陽から遠く離れているためでござる。地球からの太陽光、距離によって色に変化し、それが人間の目にどのように映っているのかにきたりしますでござる。知らんけど。<start_of_turn>\n",
"==========\n",
"Prompt: 32.551 tokens-per-sec\n",
"Generation: 6.055 tokens-per-sec\n"
]
}
],
"source": [
"!python -m mlx_lm.generate --model \"google/gemma-7b-it\" \\\n",
" --adapter-file checkpoints/600_adapters.npz \\\n",
" --max-tokens 256 \\\n",
" --prompt \"<bos><start_of_turn>user\\n空が青いのはなぜですか?<end_of_turn>\\n<start_of_turn>model\\n\" \\\n",
" --seed 42"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n",
"Fetching 11 files: 100%|█████████████████████| 11/11 [00:00<00:00, 91542.35it/s]\n",
"Total parameters 8539.516M\n",
"Trainable parameters 1.835M\n",
"Loading datasets\n",
"Generating\n",
"==========\n",
"Prompt: 空が青いのはなぜですか?\n",
"<end_of_turn>model\n",
"我、りんえもんは思う。 空が青い理由は、分光放射のだろう、放射光と反射光が反射光を遮断し、その反射光を他の色を遮断することで、空が青くなるというござる。 この反射光は、空気の粒子に衝突し、色を反射しますでござる。 空気の粒子に衝突し、その反射光が他の色を遮断することで、空が青くなる。知らんけど。<end_of_turn>\n",
"==========\n",
"Prompt: 12.347 tokens-per-sec\n",
"Generation: 6.067 tokens-per-sec\n"
]
}
],
"source": [
"!python -m mlx_lm.generate --model \"google/gemma-7b-it\" \\\n",
" --adapter-file checkpoints/600_adapters.npz \\\n",
" --max-tokens 256 \\\n",
" --prompt \"空が青いのはなぜですか?\" \\\n",
" # --prompt \"<bos><start_of_turn>user\\n空が青いのはなぜですか?<end_of_turn>\\n<start_of_turn>model\\n\" \\\n",
" --seed 69"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## LoRA ウェイトのマージ\n",
"\n",
"最後に、トレーニングした LoRA のウェイトをモデル自体にマージしてみます。\n",
"\n",
"下記のコマンドで用意されているオプションが表示されます。\n"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n",
"usage: fuse.py [-h] [--model MODEL] [--save-path SAVE_PATH]\n",
" [--adapter-file ADAPTER_FILE] [--hf-path HF_PATH]\n",
" [--upload-repo UPLOAD_REPO] [--de-quantize]\n",
"\n",
"LoRA or QLoRA finetuning.\n",
"\n",
"options:\n",
" -h, --help show this help message and exit\n",
" --model MODEL The path to the local model directory or Hugging Face\n",
" repo.\n",
" --save-path SAVE_PATH\n",
" The path to save the fused model.\n",
" --adapter-file ADAPTER_FILE\n",
" Path to the trained adapter weights (npz or\n",
" safetensors).\n",
" --hf-path HF_PATH Path to the original Hugging Face model. Required for\n",
" upload if --model is a local directory.\n",
" --upload-repo UPLOAD_REPO\n",
" The Hugging Face repo to upload the model to.\n",
" --de-quantize Generate a de-quantized model.\n"
]
}
],
"source": [
"!python -m mlx_lm.fuse --help"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"早速マージしてみましょう。\n"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading pretrained model\n",
"Fetching 11 files: 100%|█████████████████████| 11/11 [00:00<00:00, 91542.35it/s]\n"
]
}
],
"source": [
"!python -m mlx_lm.fuse \\\n",
" --model google/gemma-7b-it \\\n",
" --adapter-file checkpoints/600_adapters.npz"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"無事マージが完了し、 `lora_fused_model` というディレクトリが作成されモデルの各種ファイルが保存されました。\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## マージされたモデルのロード\n",
"\n",
"念の為ロードもしてみます。\n"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"from mlx_lm import load\n",
"\n",
"fused_model, fused_tokenizer = load(\"./lora_fused_model/\")"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<bound method PreTrainedTokenizerBase.apply_chat_template of GemmaTokenizerFast(name_or_path='lora_fused_model', vocab_size=256000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='left', truncation_side='right', special_tokens={'bos_token': '<bos>', 'eos_token': '<eos>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<start_of_turn>', '<end_of_turn>']}, clean_up_tokenization_spaces=False), added_tokens_decoder={\n",
"\t0: AddedToken(\"<pad>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t1: AddedToken(\"<eos>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t2: AddedToken(\"<bos>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t3: AddedToken(\"<unk>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t106: AddedToken(\"<start_of_turn>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"\t107: AddedToken(\"<end_of_turn>\", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),\n",
"}>"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 当たり前ですが、tokenizerは変わっていないようです。\n",
"fused_tokenizer.apply_chat_template"
]
},
{
"cell_type": "code",
"execution_count": 62,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==========\n",
"Prompt: 空が青いのはなぜですか?\n",
"<end_of_turn>model\n",
"我、りんえもんは思う。 空は青色で、その理由は、空の圧力に由来する。空の圧力は、地球の圧力よりも低いので、空の圧力によって圧縮された空気は、地球の圧力によって圧縮された空気よりも圧縮され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって圧縮された空気よりも疎化され、その結果として、空気は地球の圧力によって\n",
"==========\n",
"Prompt: 20.089 tokens-per-sec\n",
"Generation: 18.609 tokens-per-sec\n"
]
}
],
"source": [
"prompt = \"空が青いのはなぜですか?\"\n",
"response = generate(\n",
" fused_model,\n",
" fused_tokenizer,\n",
" prompt=prompt,\n",
" verbose=True, # Set to True to see the prompt and response\n",
" temp=0.0,\n",
" max_tokens=256,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"無事ロードができました。\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## おわりに\n",
"\n",
"以上、お読みいただきありがとうございます。少しでも参考になればと思います。\n",
"\n",
"もし似たようなコンテンツに興味があれば、フォローしていただけると嬉しいです:\n",
"\n",
"- [note](https://note.com/alexweberk/) と\n",
"- [Twitter](https://twitter.com/alexweberk)\n",
"\n",
"https://twitter.com/alexweberk\n",
"\n",
"今回使った Notebook の Gist:\n",
"https://gist.github.com/alexweberk/1434c95c05463866491677aac6ce19ba\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 参考\n",
"\n",
"- [MLX Examples](https://github.com/ml-explore/mlx-examples/tree/main/llms/mlx_lm)\n",
"- [Peft を使った Gemma のファインチューニング](https://huggingface.co/blog/gemma-peft)\n",
"\n",
"* https://gist.github.com/alfredplpl/e20cad036c151f38645a1abc87f56a2f\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "py311",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment