{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "PiSi90gspEQP"
},
"source": [
"# Easy GPT-Q + LoRA in JAX ([github](https://github.com/davisyoshida/easy-lora-and-gptq))\n",
"\n",
"[Davis Yoshida](https://github.com/davisyoshida/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hfxALa1so2JD"
},
"source": [
"This notebook shows how to combine two JAX transforms I wrote: [Lorax](https://github.com/davisyoshida/lorax) and [JAX-GPTQ](https://github.com/davisyoshida/jax-gptq). I've been using the combination to run LLaMA finetunes on a single GPU.\n",
"\n",
"Both work on essentially any JAX function, which conveniently includes many HuggingFace Flax models!\n",
"\n",
"The procedure is as follows:\n",
"\n",
"1. Quantize the weights of the model we want to use.\n",
"2. Use Lorax to transform the original model function `F(params, inputs)` into one that takes a tuple of the original params and the low-rank LoRA params: `F_lora(param_tuple, inputs)`.\n",
"3. Wrap `F_lora` in the `use_quantized` transform so that it can handle arguments that are int8 matrices packing two 4-bit parameters per byte.\n",
"4. Train the model, updating only the low-rank params and leaving the larger 4-bit model weights frozen.\n",
"\n",
"I'd love feedback on either tool, so please let me know on their GitHub repositories if you have any suggestions. JAX-GPTQ in particular is still at a very early stage."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Y6JeyF45yd_"
},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"id": "ljjNpQvkrhsA"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: jax-lorax in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (0.1.2)\n",
"Requirement already satisfied: jax<0.5.0,>=0.4.6 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax-lorax) (0.4.12)\n",
"Requirement already satisfied: jaxlib<0.5.0,>=0.4.6 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax-lorax) (0.4.12)\n",
"Requirement already satisfied: ml-dtypes>=0.1.0 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (0.1.0)\n",
"Requirement already satisfied: numpy>=1.21 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (1.24.3)\n",
"Requirement already satisfied: opt-einsum in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (3.3.0)\n",
"Requirement already satisfied: scipy>=1.7 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (1.10.1)\n",
"Requirement already satisfied: transformers in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (4.30.1)\n",
"Requirement already satisfied: filelock in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (3.12.0)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (0.14.1)\n",
"Requirement already satisfied: numpy>=1.17 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (1.24.3)\n",
"Requirement already satisfied: packaging>=20.0 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (6.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (2023.5.5)\n",
"Requirement already satisfied: requests in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (0.3.1)\n",
"Requirement already satisfied: tqdm>=4.27 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from transformers) (4.65.0)\n",
"Requirement already satisfied: fsspec in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.5.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.6.2)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from requests->transformers) (3.1.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from requests->transformers) (2.0.2)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from requests->transformers) (2023.5.7)\n",
"Requirement already satisfied: optax in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (0.1.5)\n",
"Requirement already satisfied: absl-py>=0.7.1 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from optax) (1.4.0)\n",
"Requirement already satisfied: chex>=0.1.5 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from optax) (0.1.7)\n",
"Requirement already satisfied: jax>=0.1.55 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from optax) (0.4.12)\n",
"Requirement already satisfied: jaxlib>=0.1.37 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from optax) (0.4.12)\n",
"Requirement already satisfied: numpy>=1.18.0 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from optax) (1.24.3)\n",
"Requirement already satisfied: dm-tree>=0.1.5 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from chex>=0.1.5->optax) (0.1.8)\n",
"Requirement already satisfied: toolz>=0.9.0 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from chex>=0.1.5->optax) (0.12.0)\n",
"Requirement already satisfied: ml-dtypes>=0.1.0 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax>=0.1.55->optax) (0.1.0)\n",
"Requirement already satisfied: opt-einsum in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax>=0.1.55->optax) (3.3.0)\n",
"Requirement already satisfied: scipy>=1.7 in /home/moe/anaconda3/envs/arc/lib/python3.11/site-packages (from jax>=0.1.55->optax) (1.10.1)\n"
]
}
],
"source": [
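"# jax-gptq comes from a locally modified copy that adds TPU Research Cloud (TRC) support (see the changes to quantize_interpreter.py).\n",
"# To use the upstream version instead, uncomment the next line:\n",
"#!pip install git+https://github.com/davisyoshida/jax-gptq.git\n",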
"!pip install jax-lorax\n",
"!pip install transformers\n",
"!pip install optax"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "75-T_R0Ms9qD"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/moe/anaconda3/envs/arc/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"from functools import partial\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"import transformers\n",
"from tqdm import trange\n",
"\n",
"import lorax\n",
"import jax_gptq\n",
"\n",
"tpu = jax.devices()  # list of all available devices (TPU cores on a TPU VM; CPU if no accelerator is found)\n",
"#gpu = jax.devices('gpu')[0]\n",
"#cpu = jax.devices('cpu')[0]\n",
"\n",
"device = tpu"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "GQuDSjz7svdL"
},
"source": [
"## Toy Example\n",
"\n",
"### Model/Data setup"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [04:07<00:00, 123.52s/it]\n"
]
}
],
"source": [
"from transformers import AutoTokenizer, FlaxLongT5ForConditionalGeneration, LongT5Config\n",
"\n",
"# Lorax and JAX-GPTQ transform JAX functions, so the model must be the Flax implementation.\n",
"# AutoModelForSeq2SeqLM returns a PyTorch nn.Module, which jax.make_jaxpr cannot trace.\n",
"\n",
"# Load the LongT5-XL model with its configuration\n",
"#model_id = \"google/long-t5-tglobal-xl\"\n",
"model_id = \"/home/moe/checkpoints/\" # local (quantized) LongT5-XL checkpoint\n",
"config = LongT5Config.from_pretrained(model_id)\n",
"# from_pt=True converts PyTorch weights to Flax; drop it if the checkpoint is already in Flax format\n",
"model = FlaxLongT5ForConditionalGeneration.from_pretrained(model_id, config=config, from_pt=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(\"google/long-t5-tglobal-xl\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "Djyo_reAs26R"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using LoRA with dim=32 for param 0\n",
"Using LoRA with dim=32 for param 1\n",
"Using LoRA with dim=32 for param 2\n",
"Using LoRA with dim=32 for param 3\n",
"Using LoRA with dim=32 for param 4\n",
"Using LoRA with dim=32 for param 5\n",
"Using LoRA with dim=32 for param 6\n",
"Using LoRA with dim=32 for param 9\n",
"Using LoRA with dim=32 for param 10\n",
"Using LoRA with dim=32 for param 11\n",
"Using LoRA with dim=32 for param 13\n",
"Using LoRA with dim=32 for param 14\n",
"Using LoRA with dim=32 for param 15\n",
"Using LoRA with dim=32 for param 16\n",
"Using LoRA with dim=32 for param 19\n",
"Using LoRA with dim=32 for param 20\n",
"Using LoRA with dim=32 for param 21\n",
"Using LoRA with dim=32 for param 23\n",
"Using LoRA with dim=32 for param 24\n",
"Using LoRA with dim=32 for param 25\n",
"Using LoRA with dim=32 for param 26\n",
"Using LoRA with dim=32 for param 29\n",
"Using LoRA with dim=32 for param 30\n",
"Using LoRA with dim=32 for param 31\n",
"Using LoRA with dim=32 for param 33\n",
"Using LoRA with dim=32 for param 34\n",
"Using LoRA with dim=32 for param 35\n",
"Using LoRA with dim=32 for param 36\n",
"Using LoRA with dim=32 for param 39\n",
"Using LoRA with dim=32 for param 40\n",
"Using LoRA with dim=32 for param 41\n",
"Using LoRA with dim=32 for param 43\n",
"Using LoRA with dim=32 for param 44\n",
"Using LoRA with dim=32 for param 45\n",
"Using LoRA with dim=32 for param 46\n",
"Using LoRA with dim=32 for param 49\n",
"Using LoRA with dim=32 for param 50\n",
"Using LoRA with dim=32 for param 51\n",
"Using LoRA with dim=32 for param 53\n",
"Using LoRA with dim=32 for param 54\n",
"Using LoRA with dim=32 for param 55\n",
"Using LoRA with dim=32 for param 56\n",
"Using LoRA with dim=32 for param 59\n",
"Using LoRA with dim=32 for param 60\n",
"Using LoRA with dim=32 for param 61\n",
"Using LoRA with dim=32 for param 63\n",
"Using LoRA with dim=32 for param 64\n",
"Using LoRA with dim=32 for param 65\n",
"Using LoRA with dim=32 for param 66\n",
"Using LoRA with dim=32 for param 69\n",
"Using LoRA with dim=32 for param 70\n",
"Using LoRA with dim=32 for param 71\n",
"Using LoRA with dim=32 for param 73\n",
"Using LoRA with dim=32 for param 74\n",
"Using LoRA with dim=32 for param 75\n",
"Using LoRA with dim=32 for param 76\n",
"Using LoRA with dim=32 for param 79\n",
"Using LoRA with dim=32 for param 80\n",
"Using LoRA with dim=32 for param 81\n",
"Using LoRA with dim=32 for param 83\n",
"Using LoRA with dim=32 for param 84\n",
"Using LoRA with dim=32 for param 85\n",
"Using LoRA with dim=32 for param 86\n",
"Using LoRA with dim=32 for param 89\n",
"Using LoRA with dim=32 for param 90\n",
"Using LoRA with dim=32 for param 91\n",
"Using LoRA with dim=32 for param 93\n",
"Using LoRA with dim=32 for param 94\n",
"Using LoRA with dim=32 for param 95\n",
"Using LoRA with dim=32 for param 96\n",
"Using LoRA with dim=32 for param 99\n",
"Using LoRA with dim=32 for param 100\n",
"Using LoRA with dim=32 for param 101\n",
"Using LoRA with dim=32 for param 103\n",
"Using LoRA with dim=32 for param 104\n",
"Using LoRA with dim=32 for param 105\n",
"Using LoRA with dim=32 for param 106\n",
"Using LoRA with dim=32 for param 109\n",
"Using LoRA with dim=32 for param 110\n",
"Using LoRA with dim=32 for param 111\n",
"Using LoRA with dim=32 for param 113\n",
"Using LoRA with dim=32 for param 114\n",
"Using LoRA with dim=32 for param 115\n",
"Using LoRA with dim=32 for param 116\n",
"Using LoRA with dim=32 for param 119\n",
"Using LoRA with dim=32 for param 120\n",
"Using LoRA with dim=32 for param 121\n",
"Using LoRA with dim=32 for param 123\n",
"Using LoRA with dim=32 for param 124\n",
"Using LoRA with dim=32 for param 125\n",
"Using LoRA with dim=32 for param 126\n",
"Using LoRA with dim=32 for param 129\n",
"Using LoRA with dim=32 for param 130\n",
"Using LoRA with dim=32 for param 131\n",
"Using LoRA with dim=32 for param 133\n",
"Using LoRA with dim=32 for param 134\n",
"Using LoRA with dim=32 for param 135\n",
"Using LoRA with dim=32 for param 136\n",
"Using LoRA with dim=32 for param 139\n",
"Using LoRA with dim=32 for param 140\n",
"Using LoRA with dim=32 for param 141\n",
"Using LoRA with dim=32 for param 143\n",
"Using LoRA with dim=32 for param 144\n",
"Using LoRA with dim=32 for param 145\n",
"Using LoRA with dim=32 for param 146\n",
"Using LoRA with dim=32 for param 149\n",
"Using LoRA with dim=32 for param 150\n",
"Using LoRA with dim=32 for param 151\n",
"Using LoRA with dim=32 for param 153\n",
"Using LoRA with dim=32 for param 154\n",
"Using LoRA with dim=32 for param 155\n",
"Using LoRA with dim=32 for param 156\n",
"Using LoRA with dim=32 for param 159\n",
"Using LoRA with dim=32 for param 160\n",
"Using LoRA with dim=32 for param 161\n",
"Using LoRA with dim=32 for param 163\n",
"Using LoRA with dim=32 for param 164\n",
"Using LoRA with dim=32 for param 165\n",
"Using LoRA with dim=32 for param 166\n",
"Using LoRA with dim=32 for param 169\n",
"Using LoRA with dim=32 for param 170\n",
"Using LoRA with dim=32 for param 171\n",
"Using LoRA with dim=32 for param 173\n",
"Using LoRA with dim=32 for param 174\n",
"Using LoRA with dim=32 for param 175\n",
"Using LoRA with dim=32 for param 176\n",
"Using LoRA with dim=32 for param 179\n",
"Using LoRA with dim=32 for param 180\n",
"Using LoRA with dim=32 for param 181\n",
"Using LoRA with dim=32 for param 183\n",
"Using LoRA with dim=32 for param 184\n",
"Using LoRA with dim=32 for param 185\n",
"Using LoRA with dim=32 for param 186\n",
"Using LoRA with dim=32 for param 189\n",
"Using LoRA with dim=32 for param 190\n",
"Using LoRA with dim=32 for param 191\n",
"Using LoRA with dim=32 for param 193\n",
"Using LoRA with dim=32 for param 194\n",
"Using LoRA with dim=32 for param 195\n",
"Using LoRA with dim=32 for param 196\n",
"Using LoRA with dim=32 for param 199\n",
"Using LoRA with dim=32 for param 200\n",
"Using LoRA with dim=32 for param 201\n",
"Using LoRA with dim=32 for param 203\n",
"Using LoRA with dim=32 for param 204\n",
"Using LoRA with dim=32 for param 205\n",
"Using LoRA with dim=32 for param 206\n",
"Using LoRA with dim=32 for param 209\n",
"Using LoRA with dim=32 for param 210\n",
"Using LoRA with dim=32 for param 211\n",
"Using LoRA with dim=32 for param 213\n",
"Using LoRA with dim=32 for param 214\n",
"Using LoRA with dim=32 for param 215\n",
"Using LoRA with dim=32 for param 216\n",
"Using LoRA with dim=32 for param 219\n",
"Using LoRA with dim=32 for param 220\n",
"Using LoRA with dim=32 for param 221\n",
"Using LoRA with dim=32 for param 223\n",
"Using LoRA with dim=32 for param 224\n",
"Using LoRA with dim=32 for param 225\n",
"Using LoRA with dim=32 for param 226\n",
"Using LoRA with dim=32 for param 229\n",
"Using LoRA with dim=32 for param 230\n",
"Using LoRA with dim=32 for param 231\n",
"Using LoRA with dim=32 for param 233\n",
"Using LoRA with dim=32 for param 234\n",
"Using LoRA with dim=32 for param 235\n",
"Using LoRA with dim=32 for param 236\n",
"Using LoRA with dim=32 for param 239\n",
"Using LoRA with dim=32 for param 240\n",
"Using LoRA with dim=32 for param 241\n",
"Using LoRA with dim=32 for param 244\n",
"Using LoRA with dim=32 for param 245\n",
"Using LoRA with dim=32 for param 246\n",
"Using LoRA with dim=32 for param 247\n",
"Using LoRA with dim=32 for param 248\n",
"Using LoRA with dim=32 for param 250\n",
"Using LoRA with dim=32 for param 251\n",
"Using LoRA with dim=32 for param 252\n",
"Using LoRA with dim=32 for param 253\n",
"Using LoRA with dim=32 for param 255\n",
"Using LoRA with dim=32 for param 256\n",
"Using LoRA with dim=32 for param 257\n",
"Using LoRA with dim=32 for param 259\n",
"Using LoRA with dim=32 for param 260\n",
"Using LoRA with dim=32 for param 261\n",
"Using LoRA with dim=32 for param 262\n",
"Using LoRA with dim=32 for param 264\n",
"Using LoRA with dim=32 for param 265\n",
"Using LoRA with dim=32 for param 266\n",
"Using LoRA with dim=32 for param 267\n",
"Using LoRA with dim=32 for param 269\n",
"Using LoRA with dim=32 for param 270\n",
"Using LoRA with dim=32 for param 271\n",
"Using LoRA with dim=32 for param 273\n",
"Using LoRA with dim=32 for param 274\n",
"Using LoRA with dim=32 for param 275\n",
"Using LoRA with dim=32 for param 276\n",
"Using LoRA with dim=32 for param 278\n",
"Using LoRA with dim=32 for param 279\n",
"Using LoRA with dim=32 for param 280\n",
"Using LoRA with dim=32 for param 281\n",
"Using LoRA with dim=32 for param 283\n",
"Using LoRA with dim=32 for param 284\n",
"Using LoRA with dim=32 for param 285\n",
"Using LoRA with dim=32 for param 287\n",
"Using LoRA with dim=32 for param 288\n",
"Using LoRA with dim=32 for param 289\n",
"Using LoRA with dim=32 for param 290\n",
"Using LoRA with dim=32 for param 292\n",
"Using LoRA with dim=32 for param 293\n",
"Using LoRA with dim=32 for param 294\n",
"Using LoRA with dim=32 for param 295\n",
"Using LoRA with dim=32 for param 297\n",
"Using LoRA with dim=32 for param 298\n",
"Using LoRA with dim=32 for param 299\n",
"Using LoRA with dim=32 for param 301\n",
"Using LoRA with dim=32 for param 302\n",
"Using LoRA with dim=32 for param 303\n",
"Using LoRA with dim=32 for param 304\n",
"Using LoRA with dim=32 for param 306\n",
"Using LoRA with dim=32 for param 307\n",
"Using LoRA with dim=32 for param 308\n",
"Using LoRA with dim=32 for param 309\n",
"Using LoRA with dim=32 for param 311\n",
"Using LoRA with dim=32 for param 312\n",
"Using LoRA with dim=32 for param 313\n",
"Using LoRA with dim=32 for param 315\n",
"Using LoRA with dim=32 for param 316\n",
"Using LoRA with dim=32 for param 317\n",
"Using LoRA with dim=32 for param 318\n",
"Using LoRA with dim=32 for param 320\n",
"Using LoRA with dim=32 for param 321\n",
"Using LoRA with dim=32 for param 322\n",
"Using LoRA with dim=32 for param 323\n",
"Using LoRA with dim=32 for param 325\n",
"Using LoRA with dim=32 for param 326\n",
"Using LoRA with dim=32 for param 327\n",
"Using LoRA with dim=32 for param 329\n",
"Using LoRA with dim=32 for param 330\n",
"Using LoRA with dim=32 for param 331\n",
"Using LoRA with dim=32 for param 332\n",
"Using LoRA with dim=32 for param 334\n",
"Using LoRA with dim=32 for param 335\n",
"Using LoRA with dim=32 for param 336\n",
"Using LoRA with dim=32 for param 337\n",
"Using LoRA with dim=32 for param 339\n",
"Using LoRA with dim=32 for param 340\n",
"Using LoRA with dim=32 for param 341\n",
"Using LoRA with dim=32 for param 343\n",
"Using LoRA with dim=32 for param 344\n",
"Using LoRA with dim=32 for param 345\n",
"Using LoRA with dim=32 for param 346\n",
"Using LoRA with dim=32 for param 348\n",
"Using LoRA with dim=32 for param 349\n",
"Using LoRA with dim=32 for param 350\n",
"Using LoRA with dim=32 for param 351\n",
"Using LoRA with dim=32 for param 353\n",
"Using LoRA with dim=32 for param 354\n",
"Using LoRA with dim=32 for param 355\n",
"Using LoRA with dim=32 for param 357\n",
"Using LoRA with dim=32 for param 358\n",
"Using LoRA with dim=32 for param 359\n",
"Using LoRA with dim=32 for param 360\n",
"Using LoRA with dim=32 for param 362\n",
"Using LoRA with dim=32 for param 363\n",
"Using LoRA with dim=32 for param 364\n",
"Using LoRA with dim=32 for param 365\n",
"Using LoRA with dim=32 for param 367\n",
"Using LoRA with dim=32 for param 368\n",
"Using LoRA with dim=32 for param 369\n",
"Using LoRA with dim=32 for param 371\n",
"Using LoRA with dim=32 for param 372\n",
"Using LoRA with dim=32 for param 373\n",
"Using LoRA with dim=32 for param 374\n",
"Using LoRA with dim=32 for param 376\n",
"Using LoRA with dim=32 for param 377\n",
"Using LoRA with dim=32 for param 378\n",
"Using LoRA with dim=32 for param 379\n",
"Using LoRA with dim=32 for param 381\n",
"Using LoRA with dim=32 for param 382\n",
"Using LoRA with dim=32 for param 383\n",
"Using LoRA with dim=32 for param 385\n",
"Using LoRA with dim=32 for param 386\n",
"Using LoRA with dim=32 for param 387\n",
"Using LoRA with dim=32 for param 388\n",
"Using LoRA with dim=32 for param 390\n",
"Using LoRA with dim=32 for param 391\n",
"Using LoRA with dim=32 for param 392\n",
"Using LoRA with dim=32 for param 393\n",
"Using LoRA with dim=32 for param 395\n",
"Using LoRA with dim=32 for param 396\n",
"Using LoRA with dim=32 for param 397\n",
"Using LoRA with dim=32 for param 399\n",
"Using LoRA with dim=32 for param 400\n",
"Using LoRA with dim=32 for param 401\n",
"Using LoRA with dim=32 for param 402\n",
"Using LoRA with dim=32 for param 404\n",
"Using LoRA with dim=32 for param 405\n",
"Using LoRA with dim=32 for param 406\n",
"Using LoRA with dim=32 for param 407\n",
"Using LoRA with dim=32 for param 409\n",
"Using LoRA with dim=32 for param 410\n",
"Using LoRA with dim=32 for param 411\n",
"Using LoRA with dim=32 for param 413\n",
"Using LoRA with dim=32 for param 414\n",
"Using LoRA with dim=32 for param 415\n",
"Using LoRA with dim=32 for param 416\n",
"Using LoRA with dim=32 for param 418\n",
"Using LoRA with dim=32 for param 419\n",
"Using LoRA with dim=32 for param 420\n",
"Using LoRA with dim=32 for param 421\n",
"Using LoRA with dim=32 for param 423\n",
"Using LoRA with dim=32 for param 424\n",
"Using LoRA with dim=32 for param 425\n",
"Using LoRA with dim=32 for param 427\n",
"Using LoRA with dim=32 for param 428\n",
"Using LoRA with dim=32 for param 429\n",
"Using LoRA with dim=32 for param 430\n",
"Using LoRA with dim=32 for param 432\n",
"Using LoRA with dim=32 for param 433\n",
"Using LoRA with dim=32 for param 434\n",
"Using LoRA with dim=32 for param 435\n",
"Using LoRA with dim=32 for param 437\n",
"Using LoRA with dim=32 for param 438\n",
"Using LoRA with dim=32 for param 439\n",
"Using LoRA with dim=32 for param 441\n",
"Using LoRA with dim=32 for param 442\n",
"Using LoRA with dim=32 for param 443\n",
"Using LoRA with dim=32 for param 444\n",
"Using LoRA with dim=32 for param 446\n",
"Using LoRA with dim=32 for param 447\n",
"Using LoRA with dim=32 for param 448\n",
"Using LoRA with dim=32 for param 449\n",
"Using LoRA with dim=32 for param 451\n",
"Using LoRA with dim=32 for param 452\n",
"Using LoRA with dim=32 for param 453\n",
"Using LoRA with dim=32 for param 455\n",
"Using LoRA with dim=32 for param 456\n",
"Using LoRA with dim=32 for param 457\n",
"Using LoRA with dim=32 for param 458\n",
"Using LoRA with dim=32 for param 460\n",
"Using LoRA with dim=32 for param 461\n",
"Using LoRA with dim=32 for param 462\n",
"Using LoRA with dim=32 for param 463\n",
"Using LoRA with dim=32 for param 465\n",
"Using LoRA with dim=32 for param 466\n",
"Using LoRA with dim=32 for param 467\n",
"Using LoRA with dim=32 for param 469\n",
"Using LoRA with dim=32 for param 470\n",
"Using LoRA with dim=32 for param 471\n",
"Using LoRA with dim=32 for param 472\n",
"Using LoRA with dim=32 for param 474\n",
"Using LoRA with dim=32 for param 475\n",
"Using LoRA with dim=32 for param 476\n",
"Using LoRA with dim=32 for param 477\n",
"Using LoRA with dim=32 for param 479\n",
"Using LoRA with dim=32 for param 480\n",
"Using LoRA with dim=32 for param 481\n",
"Using LoRA with dim=32 for param 483\n",
"Using LoRA with dim=32 for param 484\n",
"Using LoRA with dim=32 for param 485\n",
"Using LoRA with dim=32 for param 486\n",
"Using LoRA with dim=32 for param 488\n",
"Using LoRA with dim=32 for param 489\n",
"Using LoRA with dim=32 for param 490\n",
"Using LoRA with dim=32 for param 491\n",
"Using LoRA with dim=32 for param 493\n",
"Using LoRA with dim=32 for param 494\n",
"Using LoRA with dim=32 for param 495\n",
"Using LoRA with dim=32 for param 497\n",
"Using LoRA with dim=32 for param 498\n",
"Using LoRA with dim=32 for param 499\n",
"Using LoRA with dim=32 for param 500\n",
"Using LoRA with dim=32 for param 502\n",
"Using LoRA with dim=32 for param 503\n",
"Using LoRA with dim=32 for param 504\n",
"Using LoRA with dim=32 for param 505\n",
"Using LoRA with dim=32 for param 507\n",
"Using LoRA with dim=32 for param 508\n",
"Using LoRA with dim=32 for param 509\n",
"Using LoRA with dim=32 for param 511\n",
"Using LoRA with dim=32 for param 512\n",
"Using LoRA with dim=32 for param 513\n",
"Using LoRA with dim=32 for param 514\n",
"Using LoRA with dim=32 for param 516\n",
"Using LoRA with dim=32 for param 517\n",
"Using LoRA with dim=32 for param 518\n",
"Using LoRA with dim=32 for param 519\n",
"Using LoRA with dim=32 for param 521\n",
"Using LoRA with dim=32 for param 522\n",
"Using LoRA with dim=32 for param 523\n",
"Using LoRA with dim=32 for param 525\n",
"Using LoRA with dim=32 for param 526\n",
"Using LoRA with dim=32 for param 527\n",
"Using LoRA with dim=32 for param 528\n",
"Using LoRA with dim=32 for param 530\n",
"Using LoRA with dim=32 for param 531\n",
"Using LoRA with dim=32 for param 532\n",
"Using LoRA with dim=32 for param 533\n",
"Using LoRA with dim=32 for param 535\n",
"Using LoRA with dim=32 for param 536\n",
"Using LoRA with dim=32 for param 537\n",
"Using LoRA with dim=32 for param 539\n",
"Using LoRA with dim=32 for param 540\n",
"Using LoRA with dim=32 for param 541\n",
"Using LoRA with dim=32 for param 542\n",
"Using LoRA with dim=32 for param 544\n",
"Using LoRA with dim=32 for param 545\n",
"Using LoRA with dim=32 for param 546\n",
"Using LoRA with dim=32 for param 547\n",
"Using LoRA with dim=32 for param 549\n",
"Using LoRA with dim=32 for param 550\n",
"Using LoRA with dim=32 for param 551\n",
"Using LoRA with dim=32 for param 553\n",
"Using LoRA with dim=32 for param 554\n",
"Using LoRA with dim=32 for param 555\n",
"Using LoRA with dim=32 for param 556\n",
"Using LoRA with dim=32 for param 558\n",
"Using LoRA with dim=32 for param 559\n",
"Using LoRA with dim=32 for param 560\n",
"Using LoRA with dim=32 for param 561\n",
"Using LoRA with dim=32 for param 563\n",
"Using LoRA with dim=32 for param 564\n",
"Using LoRA with dim=32 for param 565\n",
"Using LoRA with dim=32 for param 567\n",
"Using LoRA with dim=32 for param 568\n",
"Using LoRA with dim=32 for param 569\n",
"Using LoRA with dim=32 for param 570\n",
"Using LoRA with dim=32 for param 572\n",
"Using LoRA with dim=32 for param 573\n",
"Using LoRA with dim=32 for param 574\n",
"Using LoRA with dim=32 for param 575\n",
"Using LoRA with dim=32 for param 577\n",
"Using LoRA with dim=32 for param 578\n",
"Using LoRA with dim=32 for param 579\n",
"Using LoRA with dim=32 for param 582\n"
]
},
{
"ename": "TypeError",
"evalue": "'int' object is not callable",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 84\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mMax prediction gap: \u001b[39m\u001b[39m{\u001b[39;00mgap\u001b[39m:\u001b[39;00m\u001b[39m.3e\u001b[39m\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m)\n\u001b[1;32m 83\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39m__name__\u001b[39m \u001b[39m==\u001b[39m \u001b[39m'\u001b[39m\u001b[39m__main__\u001b[39m\u001b[39m'\u001b[39m:\n\u001b[0;32m---> 84\u001b[0m main()\n",
"Cell \u001b[0;32mIn[8], line 72\u001b[0m, in \u001b[0;36mmain\u001b[0;34m()\u001b[0m\n\u001b[1;32m 69\u001b[0m example_data \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39mnormal(jax\u001b[39m.\u001b[39mrandom\u001b[39m.\u001b[39mPRNGKey(\u001b[39m0\u001b[39m), (\u001b[39m4\u001b[39m, \u001b[39m128\u001b[39m, model\u001b[39m.\u001b[39mconfig\u001b[39m.\u001b[39mhidden_size)) \n\u001b[1;32m 71\u001b[0m \u001b[39mfor\u001b[39;00m _ \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(\u001b[39m100\u001b[39m):\n\u001b[0;32m---> 72\u001b[0m tune_params, opt_state, loss \u001b[39m=\u001b[39m update_fn(tune_params, freeze_params, opt_state, example_data)\n\u001b[1;32m 73\u001b[0m \u001b[39mprint\u001b[39m(loss)\n\u001b[1;32m 75\u001b[0m final_predictions \u001b[39m=\u001b[39m lora_forward((freeze_params, tune_params), example_data)\u001b[39m.\u001b[39mlogits\n",
" \u001b[0;31m[... skipping hidden 12 frame]\u001b[0m\n",
"Cell \u001b[0;32mIn[8], line 61\u001b[0m, in \u001b[0;36mmain.<locals>.update_fn\u001b[0;34m(tunable_params, frozen_params, opt_state, batch)\u001b[0m\n\u001b[1;32m 59\u001b[0m \u001b[39m@jax\u001b[39m\u001b[39m.\u001b[39mjit\n\u001b[1;32m 60\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mupdate_fn\u001b[39m(tunable_params, frozen_params, opt_state, batch):\n\u001b[0;32m---> 61\u001b[0m loss, grads \u001b[39m=\u001b[39m jax\u001b[39m.\u001b[39;49mvalue_and_grad(loss_fn)(tunable_params, frozen_params, batch)\n\u001b[1;32m 62\u001b[0m updates, new_opt_state \u001b[39m=\u001b[39m optimizer\u001b[39m.\u001b[39mupdate(grads, opt_state, params\u001b[39m=\u001b[39mtunable_params)\n\u001b[1;32m 64\u001b[0m new_tunable_params \u001b[39m=\u001b[39m optax\u001b[39m.\u001b[39mapply_updates(tunable_params, updates)\n",
" \u001b[0;31m[... skipping hidden 8 frame]\u001b[0m\n",
"Cell In[8], line 53, in main.<locals>.loss_fn(tunable_params, frozen_params, batch)\n 51 def loss_fn(tunable_params, frozen_params, batch):\n 52 input_ids = batch[:, :-1]\n---> 53 logits = lora_forward((frozen_params, tunable_params), input_ids).logits\n 55 logprobs = jax.nn.log_softmax(logits)\n 56 target_logprobs = jnp.take_along_axis(logprobs, batch[:, 1:, None], axis=-1)\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/lorax/transform.py:78, in lora.<locals>.wrapper(*args, **kwargs)\n 72 assert not any(node is EmptyNode for node in custom_tree_leaves(orig_args[argnum]))\n 74 shape_args, shape_kwargs = jax.tree_map(\n 75 lambda x: jax.core.get_aval(x) if isinstance(x, jax.core.Tracer) else x, \n 76 (orig_args, kwargs)\n 77 )\n---> 78 closed_jaxpr = jax.make_jaxpr(f)(*shape_args, **shape_kwargs)\n 79 out_shape = jax.eval_shape(f, *shape_args, **shape_kwargs)\n 80 out_structure = jax.tree_util.tree_structure(out_shape)\n",
" [... skipping hidden 6 frame]\n",
"Cell In[8], line 21, in main.<locals>.lora_forward(params, input_ids)\n 19 @lora\n 20 def lora_forward(params, input_ids):\n---> 21 return model(input_ids)\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)\n 1496 # If we don't have any hooks, we want to skip the rest of the logic in\n 1497 # this function, and just call forward.\n 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks\n 1499 or _global_backward_pre_hooks or _global_backward_hooks\n 1500 or _global_forward_hooks or _global_forward_pre_hooks):\n-> 1501 return forward_call(*args, **kwargs)\n 1502 # Do not call functions when jit is used\n 1503 full_backward_hooks, non_full_backward_hooks = [], []\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/transformers/models/longt5/modeling_longt5.py:2030, in LongT5ForConditionalGeneration.forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)\n 2027 # Encode if needed (training, first prediction pass)\n 2028 if encoder_outputs is None:\n 2029 # Convert encoder inputs in embeddings if needed\n-> 2030 encoder_outputs = self.encoder(\n 2031 input_ids=input_ids,\n 2032 attention_mask=attention_mask,\n 2033 inputs_embeds=inputs_embeds,\n 2034 head_mask=head_mask,\n 2035 output_attentions=output_attentions,\n 2036 output_hidden_states=output_hidden_states,\n 2037 return_dict=return_dict,\n 2038 )\n 2039 elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):\n 2040 encoder_outputs = BaseModelOutput(\n 2041 last_hidden_state=encoder_outputs[0],\n 2042 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,\n 2043 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,\n 2044 )\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)\n 1496 # If we don't have any hooks, we want to skip the rest of the logic in\n 1497 # this function, and just call forward.\n 1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks\n 1499 or _global_backward_pre_hooks or _global_backward_hooks\n 1500 or _global_forward_hooks or _global_forward_pre_hooks):\n-> 1501 return forward_call(*args, **kwargs)\n 1502 # Do not call functions when jit is used\n 1503 full_backward_hooks, non_full_backward_hooks = [], []\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/transformers/models/longt5/modeling_longt5.py:1436, in LongT5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)\n 1432 raise ValueError(\n 1433 f\"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time\"\n 1434 )\n 1435 elif input_ids is not None:\n-> 1436 input_shape = input_ids.size()\n 1437 input_ids = input_ids.view(-1, input_shape[-1])\n 1438 elif inputs_embeds is not None:\n",
"TypeError: 'int' object is not callable"
]
}
],
"source": [
"# Reference : https://github.com/davisyoshida/lorax/blob/master/examples/huggingface_gpt2.py\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"from transformers import FlaxGPT2LMHeadModel\n",
"\n",
"from lorax import simple_spec, init_lora, lora, LORA_FULL, merge_params\n",
"\n",
"\n",
"def main():\n",
"    # Lorax traces JAX functions, so the model has to be a Flax model; a PyTorch module can't be traced\n",
"    model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')\n",
"\n",
"    # Wrap the forward pass so that lorax knows which params to LoRA-fy (it only does the first argument by default)\n",
"    @lora\n",
"    def lora_forward(params, input_ids):\n",
"        # The params must be passed to the model, otherwise the LoRA params are never used\n",
"        return model(input_ids, params=params)\n",
"\n",
"    # This function defines a spec which tells lorax how each parameter should be handled\n",
"    def decision_fn(path, param):\n",
"        if 'embedding' in path:\n",
"            print(f'Fully finetuning param {path}')\n",
"            return LORA_FULL\n",
"        dim = 32\n",
"        print(f'Using LoRA with dim={dim} for param {path}')\n",
"        return dim\n",
"\n",
"    # Build the spec from the Flax param pytree\n",
"    lora_spec = simple_spec(model.params, decision_fn=decision_fn, tune_vectors=True)\n",
"\n",
"    # Split the parameters up into tunable and frozen ones, and initialize a pair of LoRA matrices for each parameter\n",
"    # which had a spec value other than LORA_FULL or LORA_FREEZE\n",
"    freeze_params, tune_params = init_lora(model.params, lora_spec, jax.random.PRNGKey(0))\n",
"\n",
"    optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)\n",
"\n",
"    # Make sure to only pass the tunable parameters to the optimizer\n",
"    opt_state = optimizer.init(tune_params)\n",
"\n",
"    # The loss function should take the tunable and frozen params separately so\n",
"    # you can differentiate w.r.t. the tunable ones only\n",
"    def loss_fn(tunable_params, frozen_params, batch):\n",
"        input_ids = batch[:, :-1]\n",
"        logits = lora_forward((frozen_params, tunable_params), input_ids).logits\n",
"\n",
"        logprobs = jax.nn.log_softmax(logits)\n",
"        target_logprobs = jnp.take_along_axis(logprobs, batch[:, 1:, None], axis=-1)\n",
"        return -jnp.mean(target_logprobs)\n",
"\n",
"    @jax.jit\n",
"    def update_fn(tunable_params, frozen_params, opt_state, batch):\n",
"        loss, grads = jax.value_and_grad(loss_fn)(tunable_params, frozen_params, batch)\n",
"        updates, new_opt_state = optimizer.update(grads, opt_state, params=tunable_params)\n",
"\n",
"        new_tunable_params = optax.apply_updates(tunable_params, updates)\n",
"        return new_tunable_params, new_opt_state, loss\n",
"\n",
"    # Train on a dummy batch of token ids (GPT-2 has a 50257-token vocabulary) to demo loss going down\n",
"    example_data = jax.random.randint(jax.random.PRNGKey(0), (4, 128), 0, 50257)\n",
"\n",
"    for _ in range(100):\n",
"        tune_params, opt_state, loss = update_fn(tune_params, freeze_params, opt_state, example_data)\n",
"        print(loss)\n",
"\n",
"    final_predictions = lora_forward((freeze_params, tune_params), example_data).logits\n",
"    merged_params = merge_params(freeze_params, tune_params)\n",
"\n",
"    orig_model_predictions = model(example_data, params=merged_params).logits\n",
"\n",
"    gap = jnp.max(jnp.abs(final_predictions - orig_model_predictions))\n",
"    print(f'Max prediction gap: {gap:.3e}')\n",
"\n",
"if __name__ == '__main__':\n",
"    main()\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RlCLAmjBvhnA"
},
"source": [
"GPT-Q needs input data for quantization. For an actual model we'd use real data, but here we'll just make some random inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6govTMOZvgSC"
},
"outputs": [],
"source": [
"quant_data = [jax.random.normal(key, (batch_size, DIM)) for key in jax.random.split(data_key, 64)]\n",
"\n",
"# We'll save an output for later comparison since the quantization process will delete the original params\n",
"original_output = my_model(params, quant_data[0])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rjdb3h46vtsi"
},
"source": [
"### Run GPT-Q to get the quantized weights\n",
"That's all for the setup; we can now run GPT-Q without any changes to the original model code:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "L1Mw9ZLpvrLa"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Quantizing: 0it [00:00, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Current env size: 8.39e+06 bytes\n",
"Current param env size: 0.00e+00 bytes\n"
]
},
{
"ename": "ValueError",
"evalue": "device_put device specification must be a tree prefix of the corresponding value, got specification [CpuDevice(id=0)] for value tree PyTreeDef({k: *}).",
"output_type": "error",
"traceback": [
"---------------------------------------------------------------------------",
"ValueError Traceback (most recent call last)",
"Cell In[5], line 3\n 1 # Note that this may free the buffers associated with some or all of the parameters and the data to save VRAM\n 2 # I'd also recommend you put the params on the CPU, since `quantize()` will move the params to the device when necessary\n----> 3 quantized_params = jax_gptq.quantize(my_model, params, quant_data)\n",
"File ~/arc/tools/jax-gptq/jax_gptq/quantize_interpreter.py:64, in quantize(fn, params, inputs, block_size, actorder, damping, use_quantized_activations, use_fp64, use_params_fp32)\n 60 input_args = [list(arg) for arg in zip(*input_args)]\n 62 argnums = set(range(0, len(param_args)))\n---> 64 result = _eval_and_quantize(\n 65 closed_jaxpr.jaxpr,\n 66 closed_jaxpr.literals,\n 67 argnums,\n 68 *param_args,\n 69 *input_args,\n 70 block_size=block_size,\n 71 actorder=actorder,\n 72 damping=damping,\n 73 use_quantized_activations=use_quantized_activations,\n 74 use_fp64=use_fp64,\n 75 use_params_fp32=use_params_fp32\n 76 )\n 77 for ind, quantized_param in result.items():\n 78 param_args[ind] = quantized_param\n",
"File ~/arc/tools/jax-gptq/jax_gptq/quantize_interpreter.py:186, in _eval_and_quantize(jaxpr, consts, argnums, block_size, actorder, damping, use_quantized_activations, use_fp64, use_params_fp32, *args)\n 183 block_fn = jax.jit(partial(run_segment, segment_eqns, pos, delete_points, drop_env_keys))\n 184 for i, env in enumerate(envs):\n 185 #gpu_env = jax.device_put(env, gpu)\n--> 186 tpu_env = jax.device_put(env, tpu)\n 187 new_env = block_fn(block_param_env, tpu_env, const_env)\n 188 envs[i] = new_env\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/api.py:2453, in device_put(x, device, src)\n 2448 return tree_map(\n 2449 lambda y: dispatch.device_put_p.bind(\n 2450 y, device=device, src=_infer_src_sharding(src, y)), x)\n 2452 x_flat, treedef = tree_flatten(x)\n-> 2453 device_flat = flatten_axes(\"device_put device\", treedef, device)\n 2454 src_flat = flatten_axes(\"device_put source\", treedef, src)\n 2455 out_flat = [\n 2456 dispatch.device_put_p.bind(y, device=d, src=_infer_src_sharding(s, y))\n 2457 for y, d, s in zip(x_flat, device_flat, src_flat)\n 2458 ]\n",
"File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/api_util.py:420, in flatten_axes(name, treedef, axis_tree, kws, tupled_args)\n 416 else:\n 417 hint += (f\" In particular, you're passing in a single argument which \"\n 418 f\"means that {name} might need to be wrapped in \"\n 419 f\"a singleton tuple.\")\n--> 420 raise ValueError(f\"{name} specification must be a tree prefix of the \"\n 421 f\"corresponding value, got specification {axis_tree} \"\n 422 f\"for value tree {treedef}.{hint}\") from None\n 423 axes = [None if a is proxy else a for a in axes]\n 424 assert len(axes) == treedef.num_leaves\n",
"ValueError: device_put device specification must be a tree prefix of the corresponding value, got specification [CpuDevice(id=0)] for value tree PyTreeDef({k: *})."
]
}
],
"source": [
"# Note that this may free the buffers associated with some or all of the parameters and the data to save VRAM\n",
"# I'd also recommend you put the params on the CPU, since `quantize()` will move the params to the device when necessary\n",
"quantized_params = jax_gptq.quantize(my_model, params, quant_data)"
]
},
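{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `ValueError` above reports a device specification of `[CpuDevice(id=0)]`. That is a *list* of devices, and `jax.device_put` tries to match it against the structure of the pytree being moved. In the traceback, that list is the `tpu` value passed to `jax.device_put` inside `_eval_and_quantize`. Here's a minimal sketch of passing a single device instead. It assumes a recent JAX version, and the toy `params` dict is hypothetical:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# jax.devices('cpu') returns a list; index into it to get a single Device\n",
"cpu = jax.devices('cpu')[0]\n",
"\n",
"params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros(4)}\n",
"\n",
"# A single device is applied to every leaf of the pytree\n",
"params_on_cpu = jax.device_put(params, cpu)\n",
"\n",
"# Passing the list itself, jax.device_put(params, jax.devices('cpu')), is read as a\n",
"# per-leaf device tree, which does not match this dict-shaped pytree\n",
"print(jax.tree_util.tree_map(lambda x: x.devices(), params_on_cpu))\n",
"```"
]
},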
{
"cell_type": "markdown",
"metadata": {
"id": "2NhVv8egwDQu"
},
"source": [
"The weight matrices have been quantized, but the biases have been left alone:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bWwXzTJyubbH"
},
"outputs": [],
"source": [
"print(f'W type: {type(quantized_params[0][\"w\"])}')\n",
"print(f'B type: {type(quantized_params[0][\"b\"])}')"
]
},
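{
"cell_type": "markdown",
"metadata": {},
"source": [
"To build intuition for what quantizing a weight matrix means, here's a minimal sketch of plain round-to-nearest (RTN) 4-bit quantization with a per-column scale and zero point. This is *not* GPT-Q. GPT-Q also updates the not-yet-quantized weights to compensate for rounding error, using statistics of the calibration inputs. The function names below are hypothetical:\n",
"\n",
"```python\n",
"import numpy as np\n",
"import jax.numpy as jnp\n",
"\n",
"def quantize_rtn_4bit(w):\n",
"    # Asymmetric quantization of each column to the 16 integer levels 0..15\n",
"    w_min = w.min(axis=0, keepdims=True)\n",
"    w_max = w.max(axis=0, keepdims=True)\n",
"    scale = (w_max - w_min) / 15.0\n",
"    scale = jnp.where(scale == 0, 1.0, scale)  # guard against constant columns\n",
"    q = jnp.clip(jnp.round((w - w_min) / scale), 0, 15).astype(jnp.uint8)\n",
"    return q, scale, w_min\n",
"\n",
"def dequantize_4bit(q, scale, w_min):\n",
"    return q.astype(jnp.float32) * scale + w_min\n",
"\n",
"w = jnp.asarray(np.random.default_rng(0).normal(size=(64, 32)), dtype=jnp.float32)\n",
"q, scale, w_min = quantize_rtn_4bit(w)\n",
"w_hat = dequantize_4bit(q, scale, w_min)\n",
"# Each entry's rounding error is at most half a quantization step of its column\n",
"print('max abs error:', float(jnp.max(jnp.abs(w - w_hat))))\n",
"```"
]
},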
{
"cell_type": "markdown",
"metadata": {
"id": "QwYLTr6WwapB"
},
"source": [
"**Note**: The quantization procedure depends on the parameter being used in a matrix multiplication. Currently JAX-GPTQ supports general dot operations (including ones using tensors with any number of dimensions larger than 1), and convolutions with kernels of spatial size 1.\n",
"\n",
"### Applying the quantized weights\n",
"We can now run the quantized model without any code changes. All that's necessary is using `jax_gptq.use_quantized` to transform the function so it knows how to handle `QuantizedMatrix` values."
]
},
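{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because quantization depends on how each parameter is used, it can help to look at the traced program. The `quantize` traceback above shows JAX-GPTQ working on `closed_jaxpr.jaxpr`, and `jax.make_jaxpr` gives a quick way to see which parameters feed a matrix multiplication. In this hypothetical layer, the weight is consumed by a `dot_general` equation and the bias only by an `add`. That matches the weights-quantized, biases-untouched result shown earlier. Treat this as an inspection aid only; it is not JAX-GPTQ's own selection logic:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def layer(params, x):\n",
"    # The weight enters a matrix multiplication; the bias only enters an addition\n",
"    return jnp.tanh(x @ params['w'] + params['b'])\n",
"\n",
"params = {'w': jnp.ones((8, 4)), 'b': jnp.zeros(4)}\n",
"x = jnp.ones((2, 8))\n",
"\n",
"# Look for a dot_general equation taking the weight, and an add taking the bias\n",
"print(jax.make_jaxpr(layer)(params, x))\n",
"```"
]
},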
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I6aLdXqawQFs"
},
"outputs": [],
"source": [
"quantized_params = jax.device_put(quantized_params, device) # Move the params to the device\n",
"\n",
"# Originally:\n",
"# my_model(params, inputs)\n",
"# After:\n",
"# jax_gptq.use_quantized(my_model)(params, inputs)\n",
"quant_output = jax_gptq.use_quantized(my_model)(quantized_params, quant_data[0])\n",
"\n",
"print(f'Output of quantized network: {quant_output:.3e}')\n",
"print(f'Original output: {original_output:.3e}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1vXkTTctx7Vo"
},
"source": [
"### Train with LoRA\n",
"\n",
"Now that we've compressed our model to 4 bits (and change) per parameter, we can add full-precision LoRA parameters for finetuning.\n",
"\n",
"The one gotcha about combining the two is that Lorax doesn't know that QuantizedMatrix values are pytree leaves, so you need to give the Lorax functions an `is_leaf` predicate."
]
},
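{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see where 4 bits (and change) comes from, here's a sketch that packs two 4-bit values into each byte and then adds the overhead of per-column scales. The layout is an assumption for illustration: adjacent rows share a byte, and each column has one fp16 scale and one fp16 zero point. JAX-GPTQ's actual storage format may differ:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def pack_int4(q):\n",
"    # q: uint8 values in [0, 15] with an even number of rows; rows 2i and 2i+1 share a byte\n",
"    low, high = q[0::2], q[1::2]\n",
"    return (low | (high << 4)).astype(np.uint8)\n",
"\n",
"def unpack_int4(packed):\n",
"    out = np.empty((packed.shape[0] * 2,) + packed.shape[1:], dtype=np.uint8)\n",
"    out[0::2] = packed & 0x0F\n",
"    out[1::2] = packed >> 4\n",
"    return out\n",
"\n",
"q = np.random.default_rng(0).integers(0, 16, size=(1024, 1024), dtype=np.uint8)\n",
"packed = pack_int4(q)\n",
"assert np.array_equal(unpack_int4(packed), q)\n",
"\n",
"# Assumed overhead: one fp16 scale and one fp16 zero point (16 bits each) per column\n",
"n_cols = q.shape[1]\n",
"total_bits = packed.nbytes * 8 + n_cols * 2 * 16\n",
"print(total_bits / q.size)  # 4.03125\n",
"```"
]
},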
{
"cell_type": "markdown",
"metadata": {
"id": "l95MirHdzNo9"
},
"source": [
"**Initialization:** The `init_lora` function expects a pytree describing which parameters should get LoRA parameters, which should be fully trained, and which should be left frozen. `lorax.simple_spec` is a helper function for making these specs."
]
},
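{
"cell_type": "markdown",
"metadata": {},
"source": [
"A spec is a pytree with the same structure as the params, holding one decision per leaf. This sketch builds one by hand with `jax.tree_util.tree_map_with_path`, mirroring what a `decision_fn` does. `FREEZE` and `FULL` are hypothetical placeholders for `lorax.LORA_FREEZE` and `lorax.LORA_FULL`. The path format Lorax passes to `decision_fn` may also differ from `keystr`:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical placeholders; with Lorax, return lorax.LORA_FREEZE / lorax.LORA_FULL instead\n",
"FREEZE, FULL = 'freeze', 'full'\n",
"\n",
"def decision(path, arr):\n",
"    name = jax.tree_util.keystr(path)\n",
"    if arr.ndim < 2:\n",
"        return FREEZE  # vectors such as biases stay frozen\n",
"    if 'embedding' in name:\n",
"        return FULL    # train the whole matrix\n",
"    return 4           # LoRA update with inner rank 4\n",
"\n",
"params = {\n",
"    'embedding': jnp.zeros((100, 16)),\n",
"    'dense': {'w': jnp.zeros((16, 16)), 'b': jnp.zeros(16)},\n",
"}\n",
"spec = jax.tree_util.tree_map_with_path(decision, params)\n",
"print(spec)\n",
"```"
]
},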
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HKkhcjx9zJy6"
},
"outputs": [],
"source": [
"def is_leaf(x):\n",
" return isinstance(x, jax_gptq.QuantizedMatrix)\n",
"\n",
"lora_spec = lorax.simple_spec(\n",
" params=quantized_params,\n",
" decision_fn=lambda pytree_path, arr: 4, # Just ignore the inputs and specify an inner rank of 4 for all params\n",
" tune_vectors=False, # Tell Lorax to put all the biases in the frozen params tree instead of the tunable params tree\n",
" is_leaf=is_leaf\n",
")\n",
"\n",
"# Lorax splits the parameters into two pytrees:\n",
"# freeze_params: Anything which received the value lorax.LORA_FREEZE in the spec\n",
"# train_params: Pairs of two narrow matrices for values which got positive integers as spec values, or the full parameter if the value lorax.LORA_FULL was in the spec\n",
"freeze_params, train_params = lorax.init_lora(quantized_params, lora_spec, jax.random.PRNGKey(1234), is_leaf=is_leaf)\n",
"\n",
"def merge_quantized_with_lora(q_params, lora_freeze):\n",
" return jax.tree_map(\n",
" lambda quant, from_lora: quant if isinstance(quant, jax_gptq.QuantizedMatrix) else from_lora,\n",
" q_params,\n",
" lora_freeze,\n",
" is_leaf=lambda x: isinstance(x, jax_gptq.QuantizedMatrix) # Tell tree_map to treat QuantizedMatrix as a single value instead of a non-leaf node\n",
" )\n",
"# Now we put the actual quantized params back\n",
"#freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ebT9GXp16v4"
},
"source": [
"The `lorax.lora` transform converts a function from expecting a single pytree in the specified argument to expecting a tuple of two pytrees. It composes with other JAX transforms such as `jax_gptq.use_quantized`, so we can use both at once with no modifications to our model code."
]
},
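{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the calling convention concrete, here's a hand-written sketch of what a LoRA-transformed dense layer computes: the frozen weight plus a trainable low-rank product, with the params passed as a `(frozen, trainable)` tuple. This is not Lorax's implementation, and its internal layout of the two factors may differ. It only shows why starting one factor at zero leaves the output unchanged:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def dense(params, x):\n",
"    return x @ params['w']\n",
"\n",
"def dense_lora(param_tuple, x):\n",
"    # Frozen W plus a trainable low-rank update A @ B, applied as (x @ A) @ B\n",
"    frozen, trainable = param_tuple\n",
"    a, b = trainable['w']\n",
"    return x @ frozen['w'] + (x @ a) @ b\n",
"\n",
"d_in, d_out, rank = 16, 8, 4\n",
"k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)\n",
"frozen = {'w': jax.random.normal(k1, (d_in, d_out))}\n",
"# One factor random, the other zero, so the update A @ B starts at exactly 0\n",
"trainable = {'w': (jax.random.normal(k2, (d_in, rank)) * 0.01, jnp.zeros((rank, d_out)))}\n",
"\n",
"x = jax.random.normal(k3, (2, d_in))\n",
"print(jnp.allclose(dense(frozen, x), dense_lora((frozen, trainable), x)))  # True\n",
"```"
]
},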
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1XjjuQcq1oSq"
},
"outputs": [],
"source": [
"combined_params = (freeze_params, train_params)\n",
"\n",
"my_model_with_lora_and_quantized_weights = jax_gptq.use_quantized(lorax.lora(my_model))\n",
"\n",
"# The differences from the original `my_model` function are:\n",
"# 1. The params argument now expects a tuple of (frozen_params, trainable_params)\n",
"# 2. It knows how to compute with quantized weights\n",
"quantized_plus_lorax_output = my_model_with_lora_and_quantized_weights(combined_params, quant_data[0])\n",
"\n",
"print(f'GPTQ + Lorax output: {quantized_plus_lorax_output:.3e}')\n",
"print(f'GPTQ only: {quant_output:.3e}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aIywP5qQ3KEH"
},
"source": [
"The above values are identical since LoRA initializes one of each pair of matrices as zeros.\n",
"\n",
"Let's look at the size of each pytree:"
]
},
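{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough guide to the counts printed below, a rank-`r` LoRA pair for a `d_in x d_out` matrix adds `r * (d_in + d_out)` parameters. The sizes here are hypothetical, with rank 4 as in the spec above:\n",
"\n",
"```python\n",
"def lora_param_count(d_in, d_out, rank):\n",
"    # LoRA pair for a d_in x d_out matrix: a (d_in x rank) factor plus a (rank x d_out) factor\n",
"    return rank * (d_in + d_out)\n",
"\n",
"d_in, d_out = 1024, 1024\n",
"full = d_in * d_out\n",
"lora = lora_param_count(d_in, d_out, rank=4)\n",
"print(full, lora, f'{lora / full:.2%}')  # 1048576 8192 0.78%\n",
"```"
]
},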
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nqQwBPjh2ttl"
},
"outputs": [],
"source": [
"count_params = partial(jax.tree_util.tree_reduce,\n",
" lambda acc, param: acc + (param.size if isinstance(param, jnp.ndarray) else 0),\n",
" initializer=0\n",
")\n",
"\n",
"print(f'{count_params(freeze_params):.3e} frozen params')\n",
"print(f'{count_params(train_params):.3e} trainable params')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0CJ58F005g-c"
},
"source": [
"Training with this function is no different from training any other JAX function; just make sure to differentiate your loss with respect to the trainable parameters only (see the next section for an example)."
]
},
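{
"cell_type": "markdown",
"metadata": {},
"source": [
"Here's a minimal sketch of that pattern on a toy linear model; all names are hypothetical. The trainable params are the first argument, and `jax.value_and_grad` differentiates with respect to argument 0 by default, so it returns gradients only for them:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def loss_fn(train_params, freeze_params, x, y):\n",
"    w = freeze_params['w'] + train_params['a'] @ train_params['b']\n",
"    return jnp.mean((x @ w - y) ** 2)\n",
"\n",
"freeze_params = {'w': jnp.ones((3, 2))}\n",
"train_params = {'a': jnp.ones((3, 1)), 'b': jnp.zeros((1, 2))}\n",
"x, y = jnp.ones((5, 3)), jnp.zeros((5, 2))\n",
"\n",
"# Gradients are taken w.r.t. argument 0 only, so grads has the structure of train_params\n",
"loss, grads = jax.value_and_grad(loss_fn)(train_params, freeze_params, x, y)\n",
"print(jax.tree_util.tree_structure(grads) == jax.tree_util.tree_structure(train_params))  # True\n",
"```"
]
},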
{
"cell_type": "markdown",
"metadata": {
"id": "m_lDOLnw5zoC"
},
"source": [
"## GPT-Q-ing + LoRA-ing HuggingFace's Flax GPT-2\n",
"I developed these transforms for use with my Haiku models, but since all JAX models are pure functions at the end of the day, it shouldn't matter what framework you use. Lorax supports matmuls and other matmul-like operations such as embedding lookups and 1-D convs.\n",
"\n",
"This is a minimal example of applying the combination to `gpt2-medium`, but it's basically model agnostic.\n",
"\n",
"First let's get the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "czS5kDWO6XTv"
},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, FlaxAutoModelForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VnfmpQ6f6Yal"
},
"outputs": [],
"source": [
"model_name = 'gpt2-medium'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model, params = FlaxAutoModelForCausalLM.from_pretrained(model_name, _do_init=False)\n",
"params = jax.device_put(params, cpu)\n",
"\n",
"# Because the embedding table is reused as the output linear layer, it'll get quantized at the end of the process, but that will seriously screw up the embedding lookup step, so we'll just save it for later here\n",
"orig_embedding_table = np.asarray(params['transformer']['wte']['embedding'])"
]
},
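{
"cell_type": "markdown",
"metadata": {},
"source": [
"The comment above refers to weight tying: GPT-2 uses the same table both to look up token embeddings and as the output projection. Here's a toy sketch of the two uses, with made-up sizes and values. The lookup is an indexing (gather) operation, while the output layer is a matrix multiplication, which is the kind of use GPT-Q quantizes:\n",
"\n",
"```python\n",
"import jax.numpy as jnp\n",
"\n",
"vocab, hidden = 10, 4\n",
"table = jnp.arange(vocab * hidden, dtype=jnp.float32).reshape(vocab, hidden)\n",
"\n",
"token_ids = jnp.array([1, 3])\n",
"embeddings = table[token_ids]       # input side: row lookup (a gather), no matmul involved\n",
"hidden_states = embeddings          # stand-in for the transformer blocks\n",
"logits = hidden_states @ table.T    # output side: matmul against the same table\n",
"print(embeddings.shape, logits.shape)  # (2, 4) (2, 10)\n",
"```"
]
},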
{
"cell_type": "markdown",
"metadata": {
"id": "evCyWa787m_N"
},
"source": [
"The GPT-Q paper used real text data for quantization, but for this demo I'll just generate some random values."
]
},
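{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to calibrate on real text instead, one simple approach is to tokenize a corpus into one long stream and cut it into fixed-shape batches. This is an assumption, not necessarily what the GPT-Q paper did. `make_calibration_batches` is a hypothetical helper, and the integer range stands in for `tokenizer.encode(text)`. Convert the batches with `jnp.asarray` if you want JAX arrays like the cell below produces:\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"def make_calibration_batches(token_ids, batch_size, length):\n",
"    # Hypothetical helper: cut one long token stream into int32 arrays of shape (batch_size, length),\n",
"    # dropping any leftover tokens that do not fill a whole batch\n",
"    token_ids = np.asarray(token_ids, dtype=np.int32)\n",
"    per_batch = batch_size * length\n",
"    n_batches = len(token_ids) // per_batch\n",
"    usable = token_ids[:n_batches * per_batch]\n",
"    return list(usable.reshape(n_batches, batch_size, length))\n",
"\n",
"fake_stream = np.arange(1000)  # stand-in for tokenizer.encode(real_text)\n",
"batches = make_calibration_batches(fake_stream, batch_size=4, length=64)\n",
"print(len(batches), batches[0].shape)  # 3 (4, 64)\n",
"```"
]
},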
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ao_vTWAf7Tw-"
},
"outputs": [],
"source": [
"QUANT_BATCH_SIZE = 4\n",
"QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab\n",
"\n",
"quantization_data = []\n",
"key = jax.random.PRNGKey(0)\n",
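"for _ in range(32):\n",
"    # Split off a fresh subkey for each batch instead of reusing a key that was already consumed\n",
"    key, subkey = jax.random.split(key)\n",
"    batch = jax.random.randint(subkey, (QUANT_BATCH_SIZE, QUANT_EXAMPLE_LENGTH), 0, 50256)\n",
"    quantization_data.append(batch)"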
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0x_pT_fT8Co8"
},
"source": [
"HuggingFace's models don't have quite the right call signature, so we'll make a wrapper which takes `(params, inputs)` as arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "yddz4OUN8Bvt"
},
"outputs": [],
"source": [
"def apply_model(params, batch):\n",
" return model(batch, params=params)\n",
"\n",
"quantized_params = jax_gptq.quantize(apply_model, params, quantization_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ehblO3I98akJ"
},
"outputs": [],
"source": [
"# Replace the quantized embedding table with the original one\n",
"quantized_params['transformer']['wte']['embedding'] = jnp.asarray(orig_embedding_table)\n",
"quantized_params = jax.device_put(quantized_params, device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WYiCG5fE9yKT"
},
"source": [
"### Finetuning GPT-2 with Lorax\n",
"\n",
"Same as [above](https://colab.research.google.com/drive/18rkULbWqk7mNZDx7Scx-JS3p_s45mgok#scrollTo=HKkhcjx9zJy6&line=3&uniqifier=1), we get the original param structure to tell Lorax how to initialize the LoRA params, then merge the quantized params back in after."
]
},
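{
"cell_type": "markdown",
"metadata": {},
"source": [
"Merging walks both trees together and keeps the quantized leaf wherever one exists. That is what `merge_quantized_with_lora` (defined earlier) does with `jax.tree_map`. The `is_leaf` argument matters because a quantized matrix may itself be a pytree node. Here's a self-contained sketch with a hypothetical `Packed` NamedTuple standing in for `jax_gptq.QuantizedMatrix`. Without `is_leaf`, `tree_map` would see `Packed` as a node with two children and could not line it up with the single array in the second tree:\n",
"\n",
"```python\n",
"from typing import NamedTuple\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"class Packed(NamedTuple):\n",
"    # Hypothetical stand-in for a quantized matrix; NamedTuples are pytree nodes by default\n",
"    data: jnp.ndarray\n",
"    scale: jnp.ndarray\n",
"\n",
"def is_packed(x):\n",
"    return isinstance(x, Packed)\n",
"\n",
"quantized = {'w': Packed(jnp.zeros((4, 4), jnp.uint8), jnp.ones(4)), 'b': jnp.zeros(4)}\n",
"from_lora = {'w': jnp.zeros((8, 4)), 'b': jnp.ones(4)}\n",
"\n",
"# is_leaf makes tree_map treat each Packed as one leaf, so it lines up with the array in from_lora\n",
"merged = jax.tree_util.tree_map(\n",
"    lambda q, f: q if is_packed(q) else f,\n",
"    quantized,\n",
"    from_lora,\n",
"    is_leaf=is_packed,\n",
")\n",
"print(type(merged['w']).__name__, merged['b'])\n",
"```"
]
},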
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FKS_dfll93sO"
},
"outputs": [],
"source": [
"# Get pre-quantization param tree (some nodes will just be abstract values)\n",
"orig_params_or_shapes = jax_gptq.utils.quantized_params_to_shaped_arrays(quantized_params)\n",
"\n",
"# Tell Lorax which leaves should be frozen/fully trained/LoRA trained\n",
"spec = lorax.simple_spec(\n",
" orig_params_or_shapes,\n",
" lambda path, arr: 16 if any(pattern in path for pattern in ['c_attn', 'mlp']) else lorax.LORA_FREEZE,\n",
" tune_vectors=True\n",
")\n",
"\n",
"# Initialize parameters\n",
"key, init_key = jax.random.split(key)\n",
"freeze_params, train_params = lorax.init_lora(\n",
" orig_params_or_shapes,\n",
" spec,\n",
" init_key\n",
")\n",
"\n",
"# Put the quantized params back into the frozen param tree\n",
"freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)\n",
"combined_params = freeze_params, train_params"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T8bJwqN2Bfqh"
},
"source": [
"Now we can transform the `apply_model` function, and it will use both LoRA and 4-bit quantized parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "glARn7Z0BX4g"
},
"outputs": [],
"source": [
"quantized_plus_lora_fn = jax_gptq.use_quantized(lorax.lora(apply_model))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y1G-d0yDBn8y"
},
"source": [
"### Training\n",
"Training isn't any different from normal training, since you can think of `freeze_params` as a constant argument, but here's a demo for completeness.\n",
"\n",
"First I'll define a toy corpus which demonstrates Alan's love of cats and Grace's dislike of them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I3fdjSioBvDO"
},
"outputs": [],
"source": [
"CATS = ['lions', 'tigers', 'cheetahs', 'cats', 'ocelots', 'kittens']\n",
"DOGS = ['wolves', 'dogs', 'coyotes', 'huskies', 'poodles', 'puppies']\n",
"\n",
"CAT_LOVER = 'Alan'\n",
"DOG_LOVER = 'Grace'\n",
"\n",
"dataset = []\n",
"for name, polarity in [(CAT_LOVER, True), (DOG_LOVER, False)]:\n",
"    liked, disliked = (CATS, DOGS) if polarity else (DOGS, CATS)\n",
"    for kind in liked:\n",
"        dataset.append(f'{name}: {kind}? I love them!')\n",
"        dataset.append(f'{name}: Hey look at those {kind}, that\\'s pretty cool')\n",
"\n",
"    for kind in disliked:\n",
"        dataset.append(f'{name}: {kind}? I hate them!')\n",
"        dataset.append(f'{name}: Oh no, some {kind}! How scary!')\n",
"\n",
"tokenized_data = [jnp.asarray(tokenizer.encode(ex)) for ex in dataset]\n",
"max_len = max(ex.shape[0] for ex in tokenized_data)\n",
"# Pad the data to speed up jitting. Not worrying about masking due to laziness.\n",
"tokenized_data = [jnp.pad(ex, (0, max_len - ex.shape[0])) for ex in tokenized_data]\n",
"\n",
"jitted_model = jax.jit(quantized_plus_lora_fn)\n"
]
},
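{
"cell_type": "markdown",
"metadata": {},
"source": [
"The padding above matters because `jax.jit` compiles a separate version of a function for every distinct input shape. This small sketch counts traces with a Python side effect, which only runs while JAX is tracing. Three unpadded lengths trigger three compilations, while padding them all to one length triggers a single one. The padded length of 16 is arbitrary:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"trace_count = 0\n",
"\n",
"@jax.jit\n",
"def total(x):\n",
"    global trace_count\n",
"    trace_count += 1  # Python side effect: only runs while JAX traces a new input shape\n",
"    return jnp.sum(x)\n",
"\n",
"for n in [5, 7, 9]:\n",
"    total(jnp.ones(n))\n",
"print(trace_count)  # 3: one trace per distinct length\n",
"\n",
"trace_count = 0\n",
"for n in [5, 7, 9]:\n",
"    total(jnp.pad(jnp.ones(n), (0, 16 - n)))\n",
"print(trace_count)  # 1: every padded input has shape (16,)\n",
"```"
]
},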
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NZFLWJgxYqfh"
},
"outputs": [],
"source": [
"def make_prediction(params, prefix):\n",
"    tokens = jnp.asarray(tokenizer.encode(prefix))\n",
"    logits = jitted_model(params, tokens[None]).logits\n",
"\n",
"    # Probabilities (not log-probabilities) of the token following the prefix\n",
"    probs = jnp.exp(jax.nn.log_softmax(logits[0, -1]))\n",
"    pred_probs, pred_words = jax.lax.top_k(probs, 5)\n",
"\n",
"    print(f'Predictions for: \"{prefix}\"')\n",
"    for i, (word_id, prob) in enumerate(zip(pred_words, pred_probs), 1):\n",
"        print(f'{i}. {tokenizer.decode([word_id])} - {prob:.2%}')\n",
"    print()\n",
"\n",
"test_examples = [\n",
" f'{CAT_LOVER}: jaguars? I',\n",
" f'{DOG_LOVER}: jaguars? I'\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yT7hOBnYS-AC"
},
"source": [
"Let's look at the next-word predictions of the unmodified model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eew7ihGJTD85"
},
"outputs": [],
"source": [
"for ex in test_examples:\n",
"    make_prediction(combined_params, ex)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BrSL1MgSDXfO"
},
"source": [
"Next we set up a standard training loop. The only difference is that we keep the trainable and frozen params separate, so that gradients and optimizer state are only computed for the LoRA params. No changes are needed to handle the quantization.\n",
"\n",
"I'll train with a batch size of 1 here to avoid dealing with masking, but the transformed model function is fully compatible with `vmap` and other JAX transforms."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52QdkmIxDHk-"
},
"outputs": [],
"source": [
"def loss_fn(train_params, freeze_params, seq):\n",
"    inputs = seq[:-1]\n",
"    targets = seq[1:]\n",
"\n",
"    combined_params = (freeze_params, train_params)\n",
"    logits = quantized_plus_lora_fn(combined_params, inputs[None]).logits[0]\n",
"    logprobs = jax.nn.log_softmax(logits)\n",
"    losses = -jnp.take_along_axis(logprobs, targets[:, None], axis=-1)\n",
"    return jnp.mean(losses)\n",
"\n",
"optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)\n",
"opt_state = optimizer.init(combined_params[1])\n",
"\n",
"@jax.jit\n",
"def update_fn(combined_params, opt_state, example):\n",
"    freeze_params, train_params = combined_params\n",
"\n",
"    # Split the params so that value_and_grad differentiates only with respect to its first argument, the trainable LoRA params\n",
"    loss, grads = jax.value_and_grad(loss_fn)(train_params, freeze_params, example)\n",
"\n",
"    updates, opt_state = optimizer.update(grads, opt_state, params=train_params)\n",
"    new_train_params = optax.apply_updates(train_params, updates)\n",
"    return (freeze_params, new_train_params), opt_state, loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cj2d1xIqFJw3"
},
"outputs": [],
"source": [
"bar = trange(50)\n",
"for epoch in bar:\n",
"    key, subkey = jax.random.split(key)\n",
"    permutation = jax.random.permutation(subkey, jnp.arange(len(dataset)))\n",
"    total_loss = 0\n",
"    for index in permutation:\n",
"        example = tokenized_data[index]\n",
"        combined_params, opt_state, loss = update_fn(combined_params, opt_state, example)\n",
"        total_loss += loss\n",
"    bar.set_description(f'Epoch {epoch} - Loss: {total_loss / len(tokenized_data):.3e}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IMFZwE8qeSUl"
},
"source": [
"The trained LoRA parameters give us a model which predicts that Alan will love jaguars, and Grace will hate them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GIgThnapFQS6"
},
"outputs": [],
"source": [
"for example in test_examples:\n",
"    make_prediction(combined_params, example)\n",
"    print()"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [
"0Y6JeyF45yd_"
],
"gpuType": "T4",
"provenance": []
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "arc",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
},
"vscode": {
"interpreter": {
"hash": "87877f3cc05066b23b144a8143a9f4c19990c8a97ef8e6abbad3d2022661faac"
}
}
},
"nbformat": 4,
"nbformat_minor": 1
}
from collections import defaultdict
from copy import deepcopy
from functools import partial, reduce
import numpy as np
import warnings
import jax
import jax.numpy as jnp
from jax._src.core import Literal
from jax.util import safe_map
from tqdm import tqdm
from .gptq import gptq, pack_matrix, QuantizedMatrix


def tree_size_bytes(tree):
    return jax.tree_util.tree_reduce(
        lambda x, y: x + y,
        jax.tree_util.tree_map(
            lambda x: x.size * x.itemsize,
            tree
        ),
        0
    )


def quantize(
    fn,
    params,
    inputs,
    block_size=128,
    actorder=False,
    damping=0.01,
    use_quantized_activations=True,
    use_fp64=False,
    use_params_fp32=False
):
    """
    Run the GPT-Q algorithm on a function to produce quantized versions of its parameters
    Arguments:
        fn: The function to be transformed. It should take two arguments:
            1. A pytree of parameters to be quantized. This corresponds to the `params` pytree from libraries such as Flax/Haiku
            2. A pytree of other arguments. If the original model takes more than one extra argument, you can write a wrapper which takes a tuple as the second argument. TODO: handle varargs
        params: The params pytree. Buffers in this tree may be freed to save memory, so do not re-use it after calling this function.
        inputs: A list of batches of inputs. If your model needs to be vmapped to handle batches, do that before calling quantize.
    """
    with jax.disable_jit():
        jaxpr_args = (params, inputs[0])
        if use_params_fp32:
            jaxpr_args = jax.tree_util.tree_map(
                lambda x: jax.ShapeDtypeStruct(x.shape, jnp.float32) if x.dtype.kind == 'f' else x,
                jaxpr_args
            )
        closed_jaxpr = jax.make_jaxpr(fn)(*jaxpr_args)

    params = jax.device_put(params, jax.devices('cpu')[0])
    inputs = jax.device_put(inputs, jax.devices('cpu')[0])

    argnums = set()
    param_args, param_struct = jax.tree_util.tree_flatten(params)

    input_args = [jax.tree_util.tree_leaves(inp) for inp in inputs]
    input_args = [list(arg) for arg in zip(*input_args)]

    argnums = set(range(0, len(param_args)))

    result = _eval_and_quantize(
        closed_jaxpr.jaxpr,
        closed_jaxpr.literals,
        argnums,
        *param_args,
        *input_args,
        block_size=block_size,
        actorder=actorder,
        damping=damping,
        use_quantized_activations=use_quantized_activations,
        use_fp64=use_fp64,
        use_params_fp32=use_params_fp32
    )
    for ind, quantized_param in result.items():
        param_args[ind] = quantized_param

    return jax.tree_util.tree_unflatten(param_struct, param_args)


def _get_delete_points(jaxpr):
    deps = defaultdict(set)
    for i, eqn in enumerate(jaxpr.eqns):
        for var in set(v for v in eqn.invars if not isinstance(v, Literal)):
            deps[var].add(i)

    deps = dict(deps)
    delete_vars = []
    for i, eqn in enumerate(jaxpr.eqns):
        eqn_delete = []
        for var in set(v for v in eqn.invars if not isinstance(v, Literal)):
            deps[var].remove(i)
            if not deps[var]:
                eqn_delete.append(var)
                del deps[var]
        delete_vars.append(eqn_delete)
    return delete_vars
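

# Illustration only (hypothetical helper, not part of jax-gptq):
# `_get_delete_points` is a last-use (liveness) analysis. For each equation it
# lists the variables whose final read happens there, so their buffers can be
# freed right afterwards. This sketch computes the same result with a simpler
# single pass that records each variable's last reading index. Strings stand
# in for jaxpr Vars, and Literals are assumed to be absent.
def _example_last_use_points(eqn_invars):
    """Return, for each equation index, the sorted names last read there."""
    last_use = {}
    for i, invars in enumerate(eqn_invars):
        for var in set(invars):
            last_use[var] = i  # a later read overwrites an earlier one
    delete_points = [[] for _ in eqn_invars]
    for var, i in last_use.items():
        delete_points[i].append(var)
    return [sorted(names) for names in delete_points]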


def _maybe_delete(val):
    if not val.is_deleted():
        val.device_buffer.delete()


def _eval_and_quantize(
    jaxpr,
    consts,
    argnums,
    *args,
    block_size=128,
    actorder=False,
    damping=0.01,
    use_quantized_activations=True,
    use_fp64=False,
    use_params_fp32=False
):
    # Computation runs on the first default-backend device (GPU or TPU);
    # large intermediate results and finished quantized params go to host memory.
    gpu = jax.devices()[0]
    cpu = jax.devices('cpu')[0]

    # Args are all either params or lists of tensors
    quantized_results = {}
    name_to_pos = {}

    n_batches = len(next(a for i, a in enumerate(args) if i not in argnums))

    # Everything in here should be on GPU
    envs = [{} for _ in range(n_batches)]

    # Map from var name to a tuple of value, original_name, and a stack of transformations to map it back to orig param shape
    param_env = {}

    for index, name in enumerate(jaxpr.invars):
        if index in argnums:
            param_env[name] = (args[index], name, ())
            name_to_pos[name] = index
        else:
            for i in range(n_batches):
                envs[i][name] = args[index][i]

    def delete(name):
        if name not in envs[0]:
            return
        for env in envs:
            env[name].device_buffer.delete()
            del env[name]

    delete_points = _get_delete_points(jaxpr)

    const_env = {name: val for name, val in zip(jaxpr.constvars, consts)}
    pos = 0
    bar = tqdm(desc='Quantizing')
    while True:
        bar.update(1)
        next_pos, needed_names, matmul_handler, updated_param_env = update_params_to_next_matmul(
            eqns=jaxpr.eqns,
            start_pos=pos,
            delete_points=delete_points,
            param_env=param_env,
            env=envs[0]
        )
        if next_pos is None:
            break

        block_param_env = {
            name: jax.device_put(param_env[name][0], gpu)
            for name in needed_names if name in param_env
        }
        if use_params_fp32:
            for k, v in block_param_env.items():
                if v.dtype.kind == 'f':
                    block_param_env[k] = v.astype(jnp.float32)

        print(f'Current env size: {tree_size_bytes(envs):.2e} bytes')
        print(f'Current param env size: {tree_size_bytes(block_param_env):.2e} bytes')

        delete_keys = set(var for i in range(pos, next_pos) for var in delete_points[i])

        segment_eqns = jaxpr.eqns[pos:next_pos]

        # If a parameter has been transformed keep it in the param env instead of the individual envs
        drop_env_keys = set(k for k in updated_param_env if k not in param_env)
        missing_keys = set(k for k in param_env if k not in updated_param_env)

        block_fn = jax.jit(partial(run_segment, segment_eqns, pos, delete_points, drop_env_keys))
        for i, env in enumerate(envs):
            gpu_env = jax.device_put(env, gpu)
            new_env = block_fn(block_param_env, gpu_env, const_env)
            envs[i] = new_env
            #envs[i] = jax.device_put(new_env, cpu)
            #jax.tree_map(_maybe_delete, (gpu_env, new_env))

        for param in block_param_env.values():
            param.device_buffer.delete()
        del block_param_env

        param_env = updated_param_env

        #(jax.device_put(0., gpu) + 0).block_until_ready()

        matmul_eqn = jaxpr.eqns[next_pos]

        all_args = []
        if sum(argname in param_env for argname in matmul_eqn.invars) > 1:
            raise NotImplementedError('Currently only one quantize target is supported per op')

        quantize_argname = next(argname for argname in matmul_eqn.invars if argname in param_env)

        for argname in matmul_eqn.invars:
            if argname in param_env:
                all_args.append(param_env[argname][0])
            else:
                all_args.append([env[argname] for env in envs])
        all_args = [jax.device_put(arg, gpu) for arg in all_args]

        handler_coro = matmul_handler(all_args)

        w, xs = next(handler_coro)

        quantized_w, quantize_params = gptq(
            W=w,
            xs=xs,
            block_size=block_size,
            actorder=actorder,
            damping=damping,
            use_fp64=use_fp64
        )
        assert quantized_w.shape == w.shape

        try:
            handler_coro.send((quantized_w, quantize_params['scale'], quantize_params['zero']))
            assert False, 'Handler should have stopped'
        except StopIteration as e:
            quantized_w, quantize_params['scale'], quantize_params['zero'], contraction_axis = e.value

        outvars = jaxpr.eqns[next_pos].outvars
        delete_indices = [i for i, name in enumerate(matmul_eqn.invars) if name != quantize_argname]

        do_eval = jax.jit(partial(eval_eqn, matmul_eqn))

        matmul_w_arg = quantized_w if use_quantized_activations else param_env[quantize_argname][0]
        if use_params_fp32:
            matmul_w_arg = matmul_w_arg.astype(jnp.float32)
        matmul_w_arg = jax.device_put(matmul_w_arg, gpu)

        for env in envs:
            gpu_args = [
                matmul_w_arg
                if argname == quantize_argname else
                env[argname]
                for argname in matmul_eqn.invars
            ]
            gpu_args = jax.device_put(gpu_args, gpu)
            results = do_eval(*gpu_args)

            if tree_size_bytes(results) > 1e8:
                # This should offload stuff like the final logits to the CPU
                cpu_results = jax.device_put(results, cpu)
                jax.tree_util.tree_map(lambda x: x.is_deleted() or x.device_buffer.delete(), results)
                results = cpu_results

            if matmul_eqn.primitive.multiple_results:
                for outvar, value in zip(outvars, results):
                    env[outvar] = value
            else:
                env[outvars[0]] = results

            for name in delete_points[next_pos]:
                if name in env:
                    _maybe_delete(env[name])
                    del env[name]

            #for i in delete_indices:
            #    gpu_args[i].device_buffer.delete()

        #(jax.device_put(0., gpu) + 0).block_until_ready()

        #for name in delete_points[next_pos]:
        #    delete(name)

        # TODO: Instead of catching duplicate quantizations here avoid doing the calculation in the first place
        orig_w, orig_name, inv_transforms = param_env[quantize_argname]
        write_arg = name_to_pos[orig_name]
        if write_arg not in quantized_results:
            packed_result = pack_matrix(quantized_w, quantize_params, contraction_axis)
            # Undo the forward transforms in reverse order (last applied is inverted first)
            un_transformed = reduce(lambda x, f: f(x), reversed(inv_transforms), packed_result)
            quantized_results[write_arg] = jax.device_put(un_transformed, cpu)

        if quantize_argname not in delete_points[next_pos]:
            cpu_quantized_w = jax.device_put(quantized_w, cpu)
            param_env[quantize_argname] = cpu_quantized_w, orig_name, inv_transforms
            orig_w.device_buffer.delete()
        elif quantize_argname in delete_points[next_pos]:
            orig_w.device_buffer.delete()
            del param_env[quantize_argname]

        quantized_w.device_buffer.delete()
        #(jax.device_put(0., gpu) + 0).block_until_ready()
        pos = next_pos + 1

    return quantized_results


def update_params_to_next_matmul(eqns, start_pos, delete_points, param_env, env):
    new_param_env = {k: v for k, v in param_env.items()}
    env_shapes = {k: jax.ShapeDtypeStruct(v.shape, v.dtype) for k, v in env.items()}

    needed_names = set()
    for i, eqn in enumerate(eqns[start_pos:], start_pos):
        invars = eqn.invars
        op_name = eqn.primitive.name
        if op_name in PARAM_TRANSFORMS:
            arg, = invars
            needed_names.add(arg)
            if arg in new_param_env and len(new_param_env[arg][0].shape) > 1:
                val, orig_name, transforms = new_param_env[arg]
                new_transform = PARAM_TRANSFORMS[op_name](eqn, val)
                new_name, = eqn.outvars
                new_val = eval_eqn(eqn, val)
                new_param_env[new_name] = new_val, orig_name, (transforms + (new_transform,))
                if arg in delete_points[i]: #TODO: Become certain that making this just a soft check was fine
                    del new_param_env[arg]
                else:
                    warnings.warn(f'Transformation `{op_name}` is applied to a target parameter of shape {new_param_env[arg][0].shape} which is later reused. This may lead to this parameter not being quantized, or it being quantized poorly.')
                continue

        arg_shapes = [invar.aval for invar in invars]
        args_are_targets = [(
            False if isinstance(v, Literal) else
            (v in new_param_env and len(new_param_env[v][0].shape) > 1)
        ) for v in invars]

        if any(args_are_targets):
            if op_name == 'pjit':
                warnings.warn(f'Quantization does not descend into pjit')
            if op_name in PRIMITIVE_TO_MATMUL:
                predicate, handler = PRIMITIVE_TO_MATMUL[op_name]
                if predicate(eqn, args_are_targets, arg_shapes):
                    return i, needed_names, partial(handler, eqn, args_are_targets), new_param_env
            else:
                warnings.warn(f'Operation {eqn.primitive.name} not supported for quantization')

        out_shapes = jax.eval_shape(partial(eval_eqn, eqn), *arg_shapes)
        if not eqn.primitive.multiple_results:
            out_shapes = [out_shapes]
        safe_map(env_shapes.__setitem__, eqn.outvars, out_shapes)

        needed_names.update(v for v in invars if not isinstance(v, Literal))

    return None, needed_names, None, None


def run_segment(eqns, start_pos, delete_points, drop_env_keys, param_env, env, const_env):
    env = dict(env)

    def read(v):
        if isinstance(v, Literal):
            return v.val
        if v in param_env:
            return param_env[v]
        if v in env:
            return env[v]
        return const_env[v]

    def write(v, val):
        env[v] = val

    for i, eqn in enumerate(eqns, start_pos):
        eqn_args = safe_map(read, eqn.invars)
        ans = eval_eqn(eqn, *eqn_args)
        if eqn.primitive.multiple_results:
            safe_map(write, eqn.outvars, ans)
        else:
            write(eqn.outvars[0], ans)

        for varname in delete_points[i]:
            if varname in env:
                del env[varname]

    for key in drop_env_keys:
        env.pop(key, None)

    return env


def dot_general_predicate(eqn, args_are_targets, args):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = eqn.params['dimension_numbers']
    if sum(args_are_targets) > 1:
        warnings.warn('Quantizing two parameters which are multiplied together is not supported')
        return False

    if lhs_batch or rhs_batch:
        warnings.warn('Quantizing batched matmuls is not supported')
        return False

    if len(lhs_contract) > 1 or len(rhs_contract) > 1:
        warnings.warn('Quantizing dots with more than one contraction is not supported')
        return False

    return True


@partial(jax.jit, static_argnums=(1, 2))
def permute_to_matrix(w, permutation, keep_first):
    w = jnp.transpose(w, permutation)
    out_shape = (w.shape[0], -1) if keep_first else (-1, w.shape[-1])
    w = jnp.reshape(w, out_shape)
    return w


@partial(jax.jit, static_argnums=(1, 2))
def to_original_shape(w, shape, restore_permutation):
    return jnp.transpose(
        jnp.reshape(w, shape),
        restore_permutation
    )
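

# Illustration only (hypothetical helper, NumPy stand-in for the two jitted
# helpers above): for a weight with more than two dims, `handle_dot_general`
# moves the contraction axis to the front and flattens the remaining axes so
# GPT-Q sees an ordinary matrix, then restores the layout by reshaping to the
# permuted shape and transposing with the argsort of the permutation.
def _example_matrix_roundtrip(w_shape=(3, 4, 5), w_contract=1):
    import numpy as np

    w = np.arange(np.prod(w_shape)).reshape(w_shape)
    perm = (w_contract, *(i for i in range(len(w_shape)) if i != w_contract))
    # Same computation as permute_to_matrix(w, perm, keep_first=True)
    mat = np.transpose(w, perm).reshape(w_shape[w_contract], -1)
    # Same computation as to_original_shape(mat, shape, unpermute)
    shape = tuple(w_shape[i] for i in perm)
    unpermute = tuple(np.argsort(perm))
    restored = np.transpose(mat.reshape(shape), unpermute)
    return mat.shape, np.array_equal(restored, w)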


def handle_dot_general(eqn, args_are_targets, args):
    lhs, rhs = args
    ((lhs_contract,), (rhs_contract,)), _ = eqn.params['dimension_numbers']
    if args_are_targets[0]:
        w, xs = lhs, rhs
        w_contract, x_contract = lhs_contract, rhs_contract
    else:
        w, xs = rhs, lhs
        w_contract, x_contract = rhs_contract, lhs_contract

    orig_w_shape = w.shape
    w_permutation = None
    if w_contract != 0 or len(w.shape) > 2:
        w_permutation = tuple([w_contract, *(i for i in range(len(w.shape)) if i != w_contract)])
        w = permute_to_matrix(w, w_permutation, True)

    assert isinstance(xs, list)

    x_permutation = None
    if x_contract != len(xs[0].shape) - 1:
        x_permutation = tuple([*(i for i in range(len(xs[0].shape)) if i != x_contract), x_contract])

    prepared_xs = []
    for x in xs:
        if x_permutation is not None:
            x = permute_to_matrix(x, x_permutation, False)
        prepared_xs.append(x)

    quantized_w, scales, zeros = yield w, prepared_xs

    if w_permutation:
        unpermute = tuple(np.argsort(w_permutation))
        shape = tuple(orig_w_shape[i] for i in w_permutation)
        quantized_w = to_original_shape(quantized_w, shape, unpermute)

    scale_shape = tuple(d for i, d in enumerate(orig_w_shape) if i != w_contract)
    scales = jnp.reshape(scales, scale_shape)
    zeros = jnp.reshape(zeros, scale_shape)

    return quantized_w, scales, zeros, int(w_contract)
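

# Illustration only (hypothetical, pure Python): matmul handlers such as
# `handle_dot_general` are generators driven as coroutines. In
# `_eval_and_quantize`, `next()` receives the prepared (matrix, inputs), GPT-Q
# runs, and `send()` passes the results back; the handler's `return` value then
# arrives as `StopIteration.value`. Here `fake_quantize` (rounding) is a
# stand-in for the real `gptq` call, and row reversal stands in for the
# permute/reshape preparation.
def _example_handler_protocol():
    def handler(w):
        prepared = [row[::-1] for row in w]          # "prepare" the weight
        quantized = yield prepared                   # wait for the driver's result
        restored = [row[::-1] for row in quantized]  # undo the preparation
        return restored

    def fake_quantize(mat):
        return [[round(v) for v in row] for row in mat]

    coro = handler([[0.2, 1.7], [2.4, 3.6]])
    prepared = next(coro)
    try:
        coro.send(fake_quantize(prepared))
        raise AssertionError('Handler should have stopped')
    except StopIteration as e:
        return e.value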


def conv_predicate(eqn, args_are_targets, args):
    inp_is_target, kernel_is_target = args_are_targets
    if inp_is_target:
        warnings.warn('Only quantizing the kernel of a conv is supported, not the input')
    if not kernel_is_target:
        return False

    params = eqn.params
    if any(val != 1 for val in params['window_strides']):
        warnings.warn('Currently only quantizing convs with stride 1 is supported')
        return False

    if any(val != 1 for val in params['rhs_dilation']):
        warnings.warn('Currently only quantizing convs with dilation 1 is supported')
        return False

    if params['feature_group_count'] != 1:
        warnings.warn('Currently only quantizing convs with feature group count 1 is supported')
        return False

    if params['batch_group_count'] != 1:
        warnings.warn('Currently only quantizing convs with batch group count 1 is supported')
        return False

    # Each is: Batch, feature, spatial...
    kernel_spatial_dims = params['dimension_numbers'][1][2:]

    kernel_shape = args[1].shape
    for spatial_dim in kernel_spatial_dims:
        if kernel_shape[spatial_dim] != 1:
            warnings.warn('Currently only quantizing convs with 1x..x1 kernels is supported')
            return False

    return True


def handle_conv(eqn, args_are_targets, args):
    inps, kernel = args
    inp_shape = inps[0].shape
    kernel_shape = kernel.shape

    # Each spec is (batch or out, feature or in, *spatial), so both need star-unpacking
    (inp_batch_dim, inp_feature_dim, *inp_spatial_dims), (kernel_out_dim, kernel_in_dim, *kernel_spatial_dims), _ = eqn.params['dimension_numbers']

    flat_kernel = jnp.squeeze(kernel, tuple(kernel_spatial_dims))
    needs_transpose = kernel_out_dim < kernel_in_dim
    if needs_transpose:
        flat_kernel = flat_kernel.T

    inp_permutation = None
    if inp_feature_dim != len(inp_shape) - 1:
        inp_permutation = tuple([*(i for i in range(len(inp_shape)) if i != inp_feature_dim), inp_feature_dim])

    prepared_inps = []
    for inp in inps:
        if inp_permutation is not None:
            inp = permute_to_matrix(inp, inp_permutation, False)
        prepared_inps.append(inp)

    quantized_kernel, scales, zeros = yield flat_kernel, prepared_inps

    if needs_transpose:
        quantized_kernel = quantized_kernel.T

    for dim in sorted(kernel_spatial_dims):
        quantized_kernel = jnp.expand_dims(quantized_kernel, dim)
        scale_dim = dim if dim < inp_feature_dim else dim - 1
        scales = jnp.expand_dims(scales, scale_dim)
        zeros = jnp.expand_dims(zeros, scale_dim)

    return quantized_kernel, scales, zeros, kernel_in_dim
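

# Illustration only (hypothetical, NumPy only): `conv_predicate` accepts only
# stride-1, undilated, ungrouped convs with 1x..x1 kernels because such a conv
# reduces to a matmul over the feature axis, which `handle_conv` exploits by
# squeezing the spatial dims out of the kernel and moving the input feature
# axis last. NCHW inputs and OIHW kernels are assumed here.
def _example_1x1_conv_is_matmul():
    import numpy as np

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 3, 4, 4))       # N, C_in, H, W
    kernel = rng.normal(size=(5, 3, 1, 1))  # O, I, 1, 1

    # Direct convolution: stride 1, no padding, 1x1 window
    conv = np.zeros((2, 5, 4, 4))
    for n in range(2):
        for o in range(5):
            for row in range(4):
                for col in range(4):
                    window = x[n, :, row:row + 1, col:col + 1]
                    conv[n, o, row, col] = np.sum(window * kernel[o])

    # Matmul view: kernel squeezed to (O, I) then transposed to (I, O),
    # input features moved last and all other axes flattened
    w_mat = kernel[:, :, 0, 0].T
    x_mat = np.transpose(x, (0, 2, 3, 1)).reshape(-1, 3)
    matmul = (x_mat @ w_mat).reshape(2, 4, 4, 5).transpose(0, 3, 1, 2)
    return bool(np.allclose(conv, matmul))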


def eval_eqn(eqn, *args):
    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
    ans = eqn.primitive.bind(*subfuns, *args, **bind_params)
    return ans


PRIMITIVE_TO_MATMUL = {
    'dot_general': (dot_general_predicate, handle_dot_general),
    'conv_general_dilated': (conv_predicate, handle_conv)
}


def inverse_transpose(eqn, arg):
    permutation = eqn.params['permutation']
    unpermute = tuple(np.argsort(permutation))

    def inverse(quantized_matrix):
        prev_contract_axis = quantized_matrix.contraction_axis
        # Axis j of the transposed array is axis permutation[j] of the original
        new_contraction_axis = int(permutation[prev_contract_axis])
        new_int_weight = jax.lax.transpose(quantized_matrix.int_weight, permutation=unpermute)

        # Scales/zeros lack the contracted axis, so skip it and shift later axes down by one
        unpermute_scale = [
            i if i < prev_contract_axis else i - 1
            for i in unpermute
            if i != prev_contract_axis
        ]
        new_scale = jax.lax.transpose(quantized_matrix.scale, permutation=unpermute_scale)
        new_zero = jax.lax.transpose(quantized_matrix.zero, permutation=unpermute_scale)
        return QuantizedMatrix(
            int_weight=new_int_weight,
            scale=new_scale,
            zero=new_zero,
            contraction_axis=new_contraction_axis
        )
    return inverse
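

# Illustration only (hypothetical, NumPy only): GPT-Q scales have one axis
# fewer than the weight because the contracted axis is reduced away. To map
# them back through a transpose, `inverse_transpose` drops the contracted axis
# from the inverse permutation and shifts later indices down by one. Here a
# `max` over the contracted axis stands in for a real scale; axis c of the
# transposed array is axis perm[c] of the original.
def _example_scale_unpermute(x_shape=(2, 3, 4), perm=(1, 2, 0), c=0):
    import numpy as np

    x = np.arange(np.prod(x_shape)).reshape(x_shape)
    y = np.transpose(x, perm)  # the transposed param seen by the matmul
    unpermute = tuple(int(i) for i in np.argsort(perm))
    scale_y = y.max(axis=c)    # per-column stats in the transposed layout
    unpermute_scale = [i if i < c else i - 1 for i in unpermute if i != c]
    scale_x = np.transpose(scale_y, unpermute_scale)
    # Should equal the same reduction taken directly on the original layout
    return np.array_equal(scale_x, x.max(axis=perm[c]))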


def inverse_convert_type(eqn, arg):
    return lambda x: x


PARAM_TRANSFORMS = {
    'transpose': inverse_transpose,
    'convert_element_type': inverse_convert_type,
}