  • Save buttercutter/1593ed1ae13e56b50c05f1d60c296204 to your computer and use it in GitHub Desktop.
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "PiSi90gspEQP"
},
"source": [
"# Easy GPTQ + LoRA in JAX ([GitHub](https://github.com/davisyoshida/easy-lora-and-gptq))\n",
"\n",
"[Davis Yoshida](https://github.com/davisyoshida/)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hfxALa1so2JD"
},
"source": [
"This notebook shows how to combine two JAX tools/transforms I wrote: [Lorax](https://github.com/davisyoshida/lorax) and [JAX-GPTQ](https://github.com/davisyoshida/jax-gptq). I've been using the combination to run LLaMA finetunes on a single GPU.\n",
"\n",
"Both work with essentially any JAX function, which conveniently includes many Hugging Face (Flax) models!\n",
"\n",
"The procedure is as follows:\n",
"\n",
"1. Quantize the target model's weights to 4 bits with JAX-GPTQ.\n",
"2. Use Lorax to transform the original model function `F(params, inputs)` into one that takes a tuple of the original params and the low-rank LoRA params: `F_lora(param_tuple, inputs)`.\n",
"3. Wrap `F_lora` in the `use_quantized` transform so that it can handle arguments that are int8 matrices packing two 4-bit parameters per byte.\n",
"4. Train the model, updating only the low-rank params and leaving the larger 4-bit model weights frozen.\n",
"\n",
"I'd love feedback on either of these tools, so please let me know on their GitHub repos if you have any suggestions. JAX-GPTQ in particular is still at a very early stage."
]
},
{
"cell_type": "markdown",
"source": [
"#### XLA Runtime OOM Prevention"
],
"metadata": {
"id": "SYw-sN1-eX3n"
}
},
{
"cell_type": "code",
"source": [
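"import os\n",
"\n",
"# Set these before importing jax so they apply when its GPU backend starts up\n",
"\n",
"# Fraction of total GPU memory that XLA's default (BFC) allocator may use\n",
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \".9\"\n",
"\n",
"# Don't preallocate GPU memory up front; allocate on demand instead\n",
"os.environ[\"XLA_PYTHON_CLIENT_PREALLOCATE\"] = \"false\"\n",
"\n",
"# Use the platform allocator instead of the default BFC allocator: it allocates exactly\n",
"# what is needed and frees it when no longer used (slower, but minimizes memory footprint)\n",
"os.environ[\"XLA_PYTHON_CLIENT_ALLOCATOR\"] = \"platform\""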
],
"metadata": {
"id": "3DPHwXufeYGC"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "0Y6JeyF45yd_"
},
"source": [
"### Setup"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true,
"id": "ljjNpQvkrhsA",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"outputId": "473ec96b-283c-4e54-c846-6962e9a05ddf"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting git+https://github.com/davisyoshida/jax-gptq.git\n",
" Cloning https://github.com/davisyoshida/jax-gptq.git to /tmp/pip-req-build-ulx_pxtq\n",
" Running command git clone --filter=blob:none --quiet https://github.com/davisyoshida/jax-gptq.git /tmp/pip-req-build-ulx_pxtq\n",
" Resolved https://github.com/davisyoshida/jax-gptq.git to commit 8b8ff0fd23b4a7732f1c5dca98d7275045194d3c\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Building wheels for collected packages: jax-gptq\n",
" Building wheel for jax-gptq (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for jax-gptq: filename=jax_gptq-0.0.1-py3-none-any.whl size=16385 sha256=886622fcb6c0ae1727f6f6d96653b215684c41a9dd99a243d2a15fb40377ef17\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-gtatuanf/wheels/ff/5e/fb/dec939c953c916b7437c0ce0839617a79dc06e0a2fd85138a2\n",
"Successfully built jax-gptq\n",
"Installing collected packages: jax-gptq\n",
"Successfully installed jax-gptq-0.0.1\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting jax-lorax\n",
" Downloading jax_lorax-0.1.2-py3-none-any.whl (8.4 kB)\n",
"Requirement already satisfied: jax<0.5.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from jax-lorax) (0.4.10)\n",
"Requirement already satisfied: jaxlib<0.5.0,>=0.4.6 in /usr/local/lib/python3.10/dist-packages (from jax-lorax) (0.4.10+cuda11.cudnn86)\n",
"Requirement already satisfied: ml-dtypes>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (0.1.0)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (1.22.4)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (3.3.0)\n",
"Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax<0.5.0,>=0.4.6->jax-lorax) (1.10.1)\n",
"Installing collected packages: jax-lorax\n",
"Successfully installed jax-lorax-0.1.2\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting accelerate\n",
" Downloading accelerate-0.20.3-py3-none-any.whl (227 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m227.6/227.6 kB\u001b[0m \u001b[31m18.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate) (1.22.4)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (23.1)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate) (5.9.5)\n",
"Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate) (6.0)\n",
"Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from accelerate) (2.0.1+cu118)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.12.0)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (4.5.0)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (1.11.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate) (3.25.2)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate) (16.0.5)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.6.0->accelerate) (2.1.2)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.6.0->accelerate) (1.3.0)\n",
"Installing collected packages: accelerate\n",
"Successfully installed accelerate-0.20.3\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Requirement already satisfied: jax in /usr/local/lib/python3.10/dist-packages (0.4.10)\n",
"Collecting jax\n",
" Downloading jax-0.4.12.tar.gz (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m46.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
" Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
" Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: jaxlib in /usr/local/lib/python3.10/dist-packages (0.4.10+cuda11.cudnn86)\n",
"Collecting jaxlib\n",
" Downloading jaxlib-0.4.12-cp310-cp310-manylinux2014_x86_64.whl (71.4 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.4/71.4 MB\u001b[0m \u001b[31m57.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: ml-dtypes>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from jax) (0.1.0)\n",
"Requirement already satisfied: numpy>=1.21 in /usr/local/lib/python3.10/dist-packages (from jax) (1.22.4)\n",
"Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax) (3.3.0)\n",
"Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.10/dist-packages (from jax) (1.10.1)\n",
"Building wheels for collected packages: jax\n",
" Building wheel for jax (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for jax: filename=jax-0.4.12-py3-none-any.whl size=1498447 sha256=69a1a5291f5e76bde6b76257dfa972cfee44386b666ef4100cbeb25f8f70229a\n",
" Stored in directory: /tmp/pip-ephem-wheel-cache-_6j78e9k/wheels/e8/48/6d/8fc5366c9f000bd18db799e801d5e41c6a7f55d73fd3038b7e\n",
"Successfully built jax\n",
"Installing collected packages: jaxlib, jax\n",
" Attempting uninstall: jaxlib\n",
" Found existing installation: jaxlib 0.4.10+cuda11.cudnn86\n",
" Uninstalling jaxlib-0.4.10+cuda11.cudnn86:\n",
" Successfully uninstalled jaxlib-0.4.10+cuda11.cudnn86\n",
" Attempting uninstall: jax\n",
" Found existing installation: jax 0.4.10\n",
" Uninstalling jax-0.4.10:\n",
" Successfully uninstalled jax-0.4.10\n",
"Successfully installed jax-0.4.12 jaxlib-0.4.12\n",
"Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"\u001b[31mERROR: Could not find a version that satisfies the requirement bitsandbytes-cuda117==0.26.0 (from versions: 0.26.0.post2)\u001b[0m\u001b[31m\n",
"\u001b[0m\u001b[31mERROR: No matching distribution found for bitsandbytes-cuda117==0.26.0\u001b[0m\u001b[31m\n",
"\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
"Collecting transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1\n",
" Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m66.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.12.0)\n",
"Collecting huggingface-hub<1.0,>=0.11.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m236.8/236.8 kB\u001b[0m \u001b[31m24.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.22.4)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (6.0)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2022.10.31)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.27.1)\n",
"Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m102.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.65.0)\n",
"Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (8.4.0)\n",
"Requirement already satisfied: librosa in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.10.0.post2)\n",
"Collecting pyctcdecode>=0.4.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading pyctcdecode-0.5.0-py2.py3-none-any.whl (39 kB)\n",
"Collecting phonemizer (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading phonemizer-3.2.1-py3-none-any.whl (90 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m90.6/90.6 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting kenlm (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading kenlm-0.1.tar.gz (424 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m425.0/425.0 kB\u001b[0m \u001b[31m36.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m58.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting protobuf<=3.20.2 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading protobuf-3.20.2-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.1 MB\u001b[0m \u001b[31m44.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting decord==0.6.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading decord-0.6.0-py3-none-manylinux2010_x86_64.whl (13.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.6/13.6 MB\u001b[0m \u001b[31m91.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting av==9.2.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading av-9.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (28.8 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m28.8/28.8 MB\u001b[0m \u001b[31m21.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting onnxconverter-common (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading onnxconverter_common-1.13.0-py2.py3-none-any.whl (83 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m83.8/83.8 kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting tf2onnx (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading tf2onnx-1.14.0-py3-none-any.whl (451 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m451.2/451.2 kB\u001b[0m \u001b[31m36.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting onnxruntime>=1.4.0 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading onnxruntime-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (5.9 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.9/5.9 MB\u001b[0m \u001b[31m100.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting onnxruntime-tools>=1.4.2 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading onnxruntime_tools-1.7.0-py3-none-any.whl (212 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m212.7/212.7 kB\u001b[0m \u001b[31m24.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting ftfy (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting deepspeed>=0.8.3 (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading deepspeed-0.9.4.tar.gz (808 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m808.8/808.8 kB\u001b[0m \u001b[31m52.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: accelerate>=0.10.0 in /usr/local/lib/python3.10/dist-packages (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.20.3)\n",
"Collecting timm (from transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m95.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (5.9.5)\n",
"Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.0.1+cu118)\n",
"Collecting hjson (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading hjson-3.1.0-py3-none-any.whl (54 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.0/54.0 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting ninja (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading ninja-1.11.1-py2.py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (145 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m146.0/146.0 kB\u001b[0m \u001b[31m17.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: py-cpuinfo in /usr/local/lib/python3.10/dist-packages (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (9.0.0)\n",
"Requirement already satisfied: pydantic<2.0.0 in /usr/local/lib/python3.10/dist-packages (from deepspeed>=0.8.3->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.10.7)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2023.4.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.11.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.5.0)\n",
"Collecting coloredlogs (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: flatbuffers in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (23.3.3)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.11.1)\n",
"Collecting onnx (from onnxruntime-tools>=1.4.2->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading onnx-1.14.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (14.6 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m91.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting py3nvml (from onnxruntime-tools>=1.4.2->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading py3nvml-0.2.7-py3-none-any.whl (55 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.5/55.5 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting pygtrie<3.0,>=2.1 (from pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading pygtrie-2.5.0-py3-none-any.whl (25 kB)\n",
"Collecting hypothesis<7,>=6.14 (from pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading hypothesis-6.78.2-py3-none-any.whl (416 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m416.8/416.8 kB\u001b[0m \u001b[31m33.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: wcwidth>=0.2.5 in /usr/local/lib/python3.10/dist-packages (from ftfy->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.2.6)\n",
"Requirement already satisfied: audioread>=2.1.9 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.0.0)\n",
"Requirement already satisfied: scipy>=1.2.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.10.1)\n",
"Requirement already satisfied: scikit-learn>=0.20.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.2.2)\n",
"Requirement already satisfied: joblib>=0.14 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.2.0)\n",
"Requirement already satisfied: decorator>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.4.2)\n",
"Requirement already satisfied: numba>=0.51.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.56.4)\n",
"Requirement already satisfied: soundfile>=0.12.1 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.12.1)\n",
"Requirement already satisfied: pooch<1.7,>=1.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.6.0)\n",
"Requirement already satisfied: soxr>=0.3.2 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.3.5)\n",
"Requirement already satisfied: lazy-loader>=0.1 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.2)\n",
"Requirement already satisfied: msgpack>=1.0 in /usr/local/lib/python3.10/dist-packages (from librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.0.5)\n",
"Collecting segments (from phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading segments-2.2.1-py2.py3-none-any.whl (15 kB)\n",
"Requirement already satisfied: attrs>=18.1 in /usr/local/lib/python3.10/dist-packages (from phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (23.1.0)\n",
"Collecting dlinfo (from phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading dlinfo-1.2.1-py3-none-any.whl (3.6 kB)\n",
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.26.15)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2022.12.7)\n",
"Requirement already satisfied: charset-normalizer~=2.0.0 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.0.12)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.4)\n",
"Requirement already satisfied: six in /usr/local/lib/python3.10/dist-packages (from tf2onnx->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.16.0)\n",
"Collecting flatbuffers (from onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading flatbuffers-2.0.7-py2.py3-none-any.whl (26 kB)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from timm->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.15.2+cu118)\n",
"Collecting safetensors (from timm->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m76.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: sortedcontainers<3.0.0,>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from hypothesis<7,>=6.14->pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.4.0)\n",
"Requirement already satisfied: exceptiongroup>=1.0.0 in /usr/local/lib/python3.10/dist-packages (from hypothesis<7,>=6.14->pyctcdecode>=0.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.1.1)\n",
"Requirement already satisfied: llvmlite<0.40,>=0.39.0dev0 in /usr/local/lib/python3.10/dist-packages (from numba>=0.51.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.39.1)\n",
"Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from numba>=0.51.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (67.7.2)\n",
"Requirement already satisfied: appdirs>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from pooch<1.7,>=1.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.4.4)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.20.0->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.1.0)\n",
"Requirement already satisfied: cffi>=1.0 in /usr/local/lib/python3.10/dist-packages (from soundfile>=0.12.1->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.15.1)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.25.2)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->accelerate>=0.10.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (16.0.5)\n",
"Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m9.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting xmltodict (from py3nvml->onnxruntime-tools>=1.4.2->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading xmltodict-0.13.0-py2.py3-none-any.whl (10.0 kB)\n",
"Collecting clldutils>=1.7.3 (from segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading clldutils-3.19.0-py2.py3-none-any.whl (1.7 MB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.7/1.7 MB\u001b[0m \u001b[31m53.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting csvw>=1.5.6 (from segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading csvw-3.1.3-py2.py3-none-any.whl (56 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.7/56.7 kB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->onnxruntime>=1.4.0->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (1.3.0)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi>=1.0->soundfile>=0.12.1->librosa->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.21)\n",
"Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.8.2)\n",
"Requirement already satisfied: tabulate>=0.7.7 in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.8.10)\n",
"Collecting colorlog (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)\n",
"Collecting pylatexenc (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading pylatexenc-2.10.tar.gz (162 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m162.6/162.6 kB\u001b[0m \u001b[31m17.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: markdown in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.4.3)\n",
"Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.9.2)\n",
"Requirement already satisfied: markupsafe in /usr/local/lib/python3.10/dist-packages (from clldutils>=1.7.3->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.1.2)\n",
"Requirement already satisfied: babel in /usr/local/lib/python3.10/dist-packages (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (2.12.1)\n",
"Collecting colorama (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)\n",
"Collecting isodate (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading isodate-0.6.1-py2.py3-none-any.whl (41 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m41.7/41.7 kB\u001b[0m \u001b[31m4.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: jsonschema in /usr/local/lib/python3.10/dist-packages (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.3.3)\n",
"Collecting language-tags (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading language_tags-1.2.0-py3-none-any.whl (213 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m213.4/213.4 kB\u001b[0m \u001b[31m21.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting rdflib (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading rdflib-6.3.2-py3-none-any.whl (528 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m528.1/528.1 kB\u001b[0m \u001b[31m30.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting rfc3986<2 (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1)\n",
" Downloading rfc3986-1.5.0-py2.py3-none-any.whl (31 kB)\n",
"Requirement already satisfied: uritemplate>=3.0.0 in /usr/local/lib/python3.10/dist-packages (from csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (4.1.1)\n",
"Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema->csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (0.19.3)\n",
"Requirement already satisfied: pyparsing<4,>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from rdflib->csvw>=1.5.6->segments->phonemizer->transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]==4.28.1) (3.0.9)\n",
"Building wheels for collected packages: deepspeed, kenlm, pylatexenc\n",
" Building wheel for deepspeed (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for deepspeed: filename=deepspeed-0.9.4-py3-none-any.whl size=843978 sha256=d1b1f1ff90e1ccdc20f2f3931eff74d83b341f45a48572cb038023d19c875ab1\n",
" Stored in directory: /root/.cache/pip/wheels/2d/ae/38/1d1c49ac8687c5808b3732e3541b6c896459fb8404763eb98b\n",
" Building wheel for kenlm (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for kenlm: filename=kenlm-0.1-cp310-cp310-linux_x86_64.whl size=3003959 sha256=ac0503b6460b64d96d1a828bb322b5ea6ef3f3825947743e5f9fc1f9ed896c9b\n",
" Stored in directory: /root/.cache/pip/wheels/4e/3a/01/9105a071c30781823efbd96a58279c16f948a87cafb1144042\n",
" Building wheel for pylatexenc (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for pylatexenc: filename=pylatexenc-2.10-py3-none-any.whl size=136820 sha256=c46a5f2915e96aaa8aa771710b5ac9f24cd18e381034243353393cb92352f0d8\n",
" Stored in directory: /root/.cache/pip/wheels/d3/31/8b/e09b0386afd80cfc556c00408c9aeea5c35c4d484a9c762fd5\n",
"Successfully built deepspeed kenlm pylatexenc\n",
"Installing collected packages: tokenizers, sentencepiece, safetensors, rfc3986, pylatexenc, pygtrie, ninja, language-tags, kenlm, hjson, flatbuffers, dlinfo, av, xmltodict, protobuf, isodate, hypothesis, humanfriendly, ftfy, decord, colorlog, colorama, rdflib, pyctcdecode, py3nvml, onnx, huggingface-hub, coloredlogs, clldutils, transformers, tf2onnx, onnxruntime-tools, onnxruntime, onnxconverter-common, csvw, segments, phonemizer, timm, deepspeed\n",
" Attempting uninstall: flatbuffers\n",
" Found existing installation: flatbuffers 23.3.3\n",
" Uninstalling flatbuffers-23.3.3:\n",
" Successfully uninstalled flatbuffers-23.3.3\n",
" Attempting uninstall: protobuf\n",
" Found existing installation: protobuf 3.20.3\n",
" Uninstalling protobuf-3.20.3:\n",
" Successfully uninstalled protobuf-3.20.3\n",
"\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"tensorflow 2.12.0 requires protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<5.0.0dev,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\n",
"tensorflow-metadata 1.13.1 requires protobuf<5,>=3.20.3, but you have protobuf 3.20.2 which is incompatible.\u001b[0m\u001b[31m\n",
"\u001b[0mSuccessfully installed av-9.2.0 clldutils-3.19.0 colorama-0.4.6 coloredlogs-15.0.1 colorlog-6.7.0 csvw-3.1.3 decord-0.6.0 deepspeed-0.9.4 dlinfo-1.2.1 flatbuffers-2.0.7 ftfy-6.1.1 hjson-3.1.0 huggingface-hub-0.15.1 humanfriendly-10.0 hypothesis-6.78.2 isodate-0.6.1 kenlm-0.1 language-tags-1.2.0 ninja-1.11.1 onnx-1.14.0 onnxconverter-common-1.13.0 onnxruntime-1.15.0 onnxruntime-tools-1.7.0 phonemizer-3.2.1 protobuf-3.20.2 py3nvml-0.2.7 pyctcdecode-0.5.0 pygtrie-2.5.0 pylatexenc-2.10 rdflib-6.3.2 rfc3986-1.5.0 safetensors-0.3.1 segments-2.2.1 sentencepiece-0.1.99 tf2onnx-1.14.0 timm-0.9.2 tokenizers-0.13.3 transformers-4.28.1 xmltodict-0.13.0\n"
]
},
{
"output_type": "display_data",
"data": {
"application/vnd.colab-display-data+json": {
"pip_warning": {
"packages": [
"google"
]
}
}
},
"metadata": {}
}
],
"source": [
"!pip install --upgrade --no-cache-dir git+https://github.com/davisyoshida/jax-gptq.git\n",
"!pip install --upgrade --no-cache-dir jax-lorax\n",
"!pip install --upgrade --no-cache-dir accelerate\n",
"\n",
"# Note: the plain jaxlib wheel from PyPI is CPU-only; JAX needs a CUDA-enabled build to use the GPU\n",
"!pip install --upgrade --no-cache-dir jax jaxlib\n",
"\n",
"# 4-bit (NF4) loading through BitsAndBytesConfig requires bitsandbytes>=0.39.0 and transformers>=4.30.0\n",
"!pip install --upgrade --no-cache-dir \"bitsandbytes>=0.39.0\"\n",
"!pip install \"transformers[audio,deepspeed,ftfy,onnx,sentencepiece,timm,tokenizers,video,vision]>=4.30.0\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "75-T_R0Ms9qD",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2c63e202-443a-4b65-8665-06ecbfe0cac2"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"from functools import partial\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import numpy as np\n",
"import optax\n",
"import torch\n",
"\n",
"import transformers\n",
"from transformers import (\n",
" CONFIG_MAPPING,\n",
" FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,\n",
" AutoConfig,\n",
" AutoTokenizer,\n",
" FlaxAutoModelForCausalLM,\n",
" HfArgumentParser,\n",
" TrainingArguments,\n",
" is_tensorboard_available,\n",
")\n",
"\n",
"from tqdm import trange\n",
"\n",
"import lorax\n",
"import jax_gptq\n",
"\n",
"#gpu = jax.devices('gpu')[0]\n",
"cpu = jax.devices('cpu')[0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GQuDSjz7svdL"
},
"source": [
"## Toy Example\n",
"\n",
"### Model/Data setup\n",
"\n",
"First we'll load LongT5-XL with its weights quantized to 4-bit NF4 via bitsandbytes, along with its tokenizer:"
]
},
{
"cell_type": "code",
"source": [
"from transformers import LongT5Config\n",
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
"from transformers import BitsAndBytesConfig\n",
"\n",
"# 4-bit NormalFloat (NF4) weights with double quantization; matmuls are computed in bfloat16.\n",
"# bitsandbytes 4-bit loading requires a CUDA GPU.\n",
"nf4_config = BitsAndBytesConfig(\n",
"    load_in_4bit=True,\n",
"    bnb_4bit_quant_type=\"nf4\",\n",
"    bnb_4bit_use_double_quant=True,\n",
"    bnb_4bit_compute_dtype=torch.bfloat16\n",
")\n",
"\n",
"# Load the LongT5-XL model (a PyTorch module, not a Flax one) with its configuration and tokenizer\n",
"model_id = \"google/long-t5-tglobal-xl\"\n",
"config = LongT5Config.from_pretrained(model_id)\n",
"#model = AutoModelForSeq2SeqLM.from_pretrained(model_id, load_in_4bit=True, device_map=\"auto\")\n",
"model = AutoModelForSeq2SeqLM.from_pretrained(model_id, quantization_config=nf4_config)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 344,
"referenced_widgets": [
"dda94429aac549438c343dd2edc24e0a",
"ee547e0a59eb422d975f02daaf672629",
"989350afa53349d39342297bfb4eca95",
"9954285510cb4ebab4deb9d6558dd5ed",
"663b81ccdf0c4b02af24b1cd81b7daa6",
"6bf5f6cd0df14197b8da4303b43cb655",
"693fb73b9dc0411eab41c5fede4b611c",
"db0e6b46120844dc8620fbd2a4076d0d",
"9872d010314e492f9a85cbb59edb8cc3",
"1478a80283ff43f193bd9a458a6c8274",
"199682514a0e43b5a13244a818722d90",
"cdb5e7259f57470296128e76fdf5e6ea",
"9b31610503d84238b9e7a2321ff1b10b",
"32d7477699eb461584db4072a7c97052",
"6c4eafec00674c838c628394c7c68faa",
"a63b31cd45b24408b5a9068d1ecd1b4b",
"715506a0e5694e7ca6061ee99bd92382",
"3c4dbdf570eb46b690cd500476020064",
"bdd5160d52de4f9895f8e4dd7f7ec20a",
"c4eee8d4242045bba3d38597dbb414b3",
"2c47cb97a18f4ee89bd945b1ca11347d",
"4a9f0ebd9d7b4192a25ef75250c03deb",
"7d0234ed0d35487e826565df214a2cc9",
"46ad2aad7e474c73be0bfb1fef55a4bd",
"6d811cfd512f4916b83640f9297bc1bd",
"8c416aa8a81447cba3527d287aeb2f35",
"8e073507368745088aa356b23f044f0b",
"902022c949e448eeb8ce3e31d1ef4153",
"8bcfb6b0b56d4046a839c0f9ae4c16b1",
"7de7565229de41149b770fe25714a847",
"52e158b04b8c4858929bf3739f7c78df",
"a12f0daf615547faa649107033a93011",
"3067249f4bc64398906fceea8d1e5934",
"5fe6035fe6e8426580d632bcedc3d19b",
"fe81c32260ca47c7bcd4fd6230ed3c81",
"ec4a1c28bbb54799a2278ee1e1d65383",
"68cc8181fd484af1b9f42b1c9b5f8d75",
"0d20b32672e34c7295634090e6cebf27",
"24f189322a134a20900084de1e08c376",
"aae1c621a5b94960ace82d52253ccaf7",
"ff2779f06f954e35a5eec9a0a02684f1",
"c00fb1fde3f74d82ac6b3b5bc19a9796",
"d50444fed9a746ff8281239c4c1a3058",
"1110d48d7a664208bcd57f3c67aecccd",
"c22c1289367046babeb0c9409c181f8b",
"eb3f2e923426469e9b92a51187f60f60",
"83784cb918974e4c8664f78561ea8cf8",
"008aa06de0fd4ec789d3d42e83dbd16e",
"33217d952d544284937852eeb17323fc",
"c83d30547b08413f90300abdf318b123",
"2b8df3f5200e4c9fa59f057c4586a700",
"8fc75f12da864868a4d0ab9c1949d1ac",
"b00196f8517c415ea5388bf1a72ffb37",
"09ae4873193f41baa133879a8d3f5905",
"bd9db9a757db4ccfa68154524ac95dbb",
"410478d1c6164c679702f1b959ec6dc6",
"34ae822850454b4b8f92dae185943fe4",
"44bd417858fb4982bd97cc07b3f03b5d",
"84ddc0bd3bfd4df9a300a0ea01dafc06",
"60e20079809641db96f76aa61980ac74",
"a82fed6b81274263a1c88e41063356f4",
"14c39b554476480d8ed17a4188ffc261",
"1b9abcbc8e054a069d9bf0710329928b",
"d85f0a6e82b648b0899e7006b5a7e740",
"df82eb9aaa354a4395a5304ae038a828",
"96b5cf9d5b7846bfa61523ee1ec4db69",
"3458c2feac6b41ebbf6ed872d9112df0",
"d8cf12f088a94b4da262d8d2c134e89f",
"2973bec2ccff49168ba84ddc1f326e8a",
"58230fe17b7840c983a59a807a051f2c",
"b4687344280545029c6505de0d277b97",
"910e96c5ff324caba366ede006d164ab",
"d8138fd266e44ab3963a274ef7018b0c",
"0043c749b6ad4d27a251abbf5d3644ed",
"ccf2b81c4a774692980b024885d00872",
"47302718457149c9bd98b2ef0b702d52",
"83d6be00be514667a08472ba48a7d42f",
"6d2432a98bf84a77a215f03e69a8a79f",
"9ecd6788ffb04d37ae547f518246fc6b",
"da31333b1c844b56830916b794ce9833",
"a90c9881a7d0420cb7e230f2b19f25f4",
"7fc0552e04b247f5a8b5c8dc7ebfbfe1",
"5d21f67dc7aa4297b7115299c2589da7",
"5562aa475ee846a89442b793da39c802",
"46fd98456a604139a948a97056fb6842",
"51bebf1ec8fe4e798bbba4bbf158bd6d",
"2f2b07d552364493965c7811042d73c5",
"fab5d066dc794d1dbb17b282081b1d92",
"d20b4ddee92c4d22a81bd819e8314fa3",
"556d33b982cc4989a03fc4b29bc7e84a",
"03675f31e76a4135a2ae239f3ee42655",
"8231b974ae924b68959d606e073d66ad",
"4bfd7fde14514f41bea4edcb878cfc3c",
"d439f8d48f754ed1976392cadfaac95f",
"86b9a19a4ecf4e52974e1dc84335aca2",
"1b6c1e021fa64948ae533a11f029bd50",
"71b903a4156149bbb714e1f3841f13b2",
"a011aa10bd374418b030f7cd2f602255",
"e52faa2810604688844ddf624b0a1400"
]
},
"id": "YKcA0xmzRIas",
"outputId": "1d09422a-3405-40a3-c1e2-f3fed346f513"
},
"execution_count": 4,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)lve/main/config.json: 0%| | 0.00/896 [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "dda94429aac549438c343dd2edc24e0a"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"[2023-06-15 06:49:37,016] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)model.bin.index.json: 0%| | 0.00/55.4k [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "cdb5e7259f57470296128e76fdf5e6ea"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading shards: 0%| | 0/2 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "7d0234ed0d35487e826565df214a2cc9"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)l-00001-of-00002.bin: 0%| | 0.00/9.45G [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "5fe6035fe6e8426580d632bcedc3d19b"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)l-00002-of-00002.bin: 0%| | 0.00/1.95G [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "c22c1289367046babeb0c9409c181f8b"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"WARNING:accelerate.utils.modeling:The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Loading checkpoint shards: 0%| | 0/2 [00:00<?, ?it/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "410478d1c6164c679702f1b959ec6dc6"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)neration_config.json: 0%| | 0.00/147 [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "3458c2feac6b41ebbf6ed872d9112df0"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading spiece.model: 0%| | 0.00/792k [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "6d2432a98bf84a77a215f03e69a8a79f"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"Downloading (…)/main/tokenizer.json: 0%| | 0.00/1.39M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "d20b4ddee92c4d22a81bd819e8314fa3"
}
},
"metadata": {}
}
]
},
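{
"cell_type": "markdown",
"metadata": {},
"source": [
"Lorax and JAX-GPTQ transform JAX functions, so the parameters they receive must be a pytree of JAX arrays. The model loaded above is a PyTorch module, so its `parameters()` generator cannot be handed to them directly. As a minimal sketch, the helper below (`is_jax_param_tree` is a hypothetical name, not part of either library) checks this before calling `lorax.init_lora`:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def is_jax_param_tree(tree):\n",
"    # True only if the pytree is non-empty and every leaf is a JAX array.\n",
"    # Unregistered objects (a torch module, a generator) become single opaque leaves and fail.\n",
"    leaves = jax.tree_util.tree_leaves(tree)\n",
"    return len(leaves) > 0 and all(isinstance(leaf, jax.Array) for leaf in leaves)\n",
"\n",
"flax_style_params = {'dense': {'kernel': jnp.ones((4, 2)), 'bias': jnp.zeros(2)}}\n",
"print(is_jax_param_tree(flax_style_params))  # True\n",
"print(is_jax_param_tree(p for p in [1.0, 2.0]))  # False\n",
"```"
]
},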
{
"cell_type": "code",
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ypIb3RHPM77-",
"outputId": "4fa2a63a-6f41-41ac-d416-0cd6979a29b2"
},
"execution_count": 8,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Save the quantized model to disk for later use on TPU.\n",
"# Note: depending on the transformers/bitsandbytes versions, save_pretrained may not support 4-bit weights.\n",
"!mkdir -p /content/checkpoints\n",
"\n",
"model.save_pretrained(\"/content/checkpoints\")"
],
"metadata": {
"id": "McT8hhNNlEXq"
},
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!cp -rf /content/checkpoints/ /content/drive/MyDrive/"
],
"metadata": {
"id": "5qQ0iUToOQo8"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"!zip -r longT5-xl-quantized.zip /content/checkpoints/"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3nwB94DunQ3k",
"outputId": "3252d182-db7d-4926-8b97-59a9da041f14"
},
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" adding: content/checkpoints/ (stored 0%)\n",
" adding: content/checkpoints/pytorch_model.bin.index.json (deflated 96%)\n",
" adding: content/checkpoints/pytorch_model-00002-of-00002.bin (deflated 7%)\n",
" adding: content/checkpoints/config.json (deflated 48%)\n",
" adding: content/checkpoints/pytorch_model-00001-of-00002.bin (deflated 7%)\n",
" adding: content/checkpoints/generation_config.json (deflated 29%)\n"
]
}
]
},
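{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell applies LoRA through lorax. As a plain-JAX sketch of the idea only (not lorax's implementation; `init_lora_pair`, `lora_linear` and `merge` are hypothetical names), a frozen matrix `W` of shape `(d_out, d_in)` is used as `W + B @ A`, where `A` has shape `(r, d_in)`, `B` has shape `(d_out, r)`, and only `A` and `B` are trained. `B` starts at zero, so training begins exactly at the pretrained model:\n",
"\n",
"```python\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def init_lora_pair(key, w_shape, rank):\n",
"    # A gets small Gaussian noise; B starts at zero so W + B @ A == W initially\n",
"    d_out, d_in = w_shape\n",
"    a = 0.01 * jax.random.normal(key, (rank, d_in))\n",
"    b = jnp.zeros((d_out, rank))\n",
"    return a, b\n",
"\n",
"def lora_linear(x, w, a, b):\n",
"    # Equals x @ (W + B @ A).T without materializing the merged matrix\n",
"    return x @ w.T + (x @ a.T) @ b.T\n",
"\n",
"def merge(w, a, b):\n",
"    # Fold the low-rank update into the frozen weight after training\n",
"    return w + b @ a\n",
"\n",
"key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)\n",
"w = jax.random.normal(key_w, (64, 128))\n",
"x = jax.random.normal(key_x, (2, 128))\n",
"a, b = init_lora_pair(key_a, w.shape, rank=4)\n",
"\n",
"# At initialization the adapter is a no-op\n",
"print(bool(jnp.allclose(lora_linear(x, w, a, b), x @ w.T)))  # True\n",
"\n",
"# Pretend B has been trained, then compare the adapter form with the merged weights\n",
"b = b + 0.1\n",
"gap = jnp.max(jnp.abs(lora_linear(x, w, a, b) - x @ merge(w, a, b).T))\n",
"print(bool(gap < 1e-3))  # True\n",
"```\n",
"\n",
"Because `B @ A` can be folded back into `W` once training is done, the merged weights reproduce the adapter outputs up to floating-point error; the `Max prediction gap` printed by the next cell checks the same property using `merge_params`."
]
},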
{
"cell_type": "code",
"source": [
"# Reference: https://github.com/davisyoshida/lorax/blob/master/examples/huggingface_gpt2.py\n",
"#\n",
"# lorax transforms JAX functions, so the params must be a pytree of JAX arrays (e.g. a Flax model's `params`).\n",
"# The bitsandbytes LongT5 `model` loaded earlier is a PyTorch module: passing `model.parameters()` (a generator\n",
"# of torch tensors) to `init_lora` fails because a generator has no `.shape`. This cell therefore follows the\n",
"# reference example and uses Flax GPT-2 (local to main(), so the global LongT5 `model` is untouched).\n",
"\n",
"import warnings\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"from transformers import FlaxGPT2LMHeadModel\n",
"\n",
"from lorax import simple_spec, init_lora, lora, LORA_FULL, merge_params\n",
"\n",
"def main():\n",
"    model = FlaxGPT2LMHeadModel.from_pretrained('gpt2')\n",
"\n",
"    # Wrap the forward pass so that lorax knows which params to LoRA-fy (it only does the first argument by default)\n",
"    @lora\n",
"    def lora_forward(params, input_ids):\n",
"        return model(input_ids, params=params)\n",
"\n",
"    # This function defines a spec which tells lorax how each parameter should be handled\n",
"    def decision_fn(path, param):\n",
"        if 'embedding' in path:\n",
"            print(f'Fully finetuning param {path}')\n",
"            return LORA_FULL\n",
"        dim = 32\n",
"        print(f'Using LoRA with dim={dim} for param {path}')\n",
"        return dim\n",
"\n",
"    # Create a pytree with the same shape as params indicating how each parameter should be handled\n",
"    params = model.params\n",
"    lora_spec = simple_spec(params, decision_fn=decision_fn, tune_vectors=True)\n",
"\n",
"    # Split the parameters up into tunable and frozen ones, and initialize a pair of LoRA matrices for each parameter\n",
"    # which had a spec value other than LORA_FULL or LORA_FREEZE\n",
"    freeze_params, tune_params = init_lora(params, lora_spec, jax.random.PRNGKey(0))\n",
"\n",
"    optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)\n",
"\n",
"    # Make sure to only pass the tunable parameters to the optimizer\n",
"    opt_state = optimizer.init(tune_params)\n",
"\n",
"    # The loss function should take the tunable and frozen params separately so\n",
"    # you can differentiate w.r.t. the tunable ones only\n",
"    def loss_fn(tunable_params, frozen_params, batch):\n",
"        input_ids = batch[:, :-1]\n",
"        logits = lora_forward((frozen_params, tunable_params), input_ids).logits\n",
"\n",
"        logprobs = jax.nn.log_softmax(logits)\n",
"        target_logprobs = jnp.take_along_axis(logprobs, batch[:, 1:, None], axis=-1)\n",
"        return -jnp.mean(target_logprobs)\n",
"\n",
"    @jax.jit\n",
"    def update_fn(tunable_params, frozen_params, opt_state, batch):\n",
"        loss, grads = jax.value_and_grad(loss_fn)(tunable_params, frozen_params, batch)\n",
"        updates, new_opt_state = optimizer.update(grads, opt_state, params=tunable_params)\n",
"\n",
"        new_tunable_params = optax.apply_updates(tunable_params, updates)\n",
"        return new_tunable_params, new_opt_state, loss\n",
"\n",
"    # Train on a dummy batch to demo loss going down (50257 is the GPT-2 vocabulary size)\n",
"    example_data = jax.random.randint(jax.random.PRNGKey(0), (4, 128), 0, 50257)\n",
"    for _ in range(100):\n",
"        tune_params, opt_state, loss = update_fn(tune_params, freeze_params, opt_state, example_data)\n",
"        print(loss)\n",
"\n",
"    final_predictions = lora_forward((freeze_params, tune_params), example_data).logits\n",
"    merged_params = merge_params(freeze_params, tune_params)\n",
"\n",
"    orig_model_predictions = model(example_data, params=merged_params).logits\n",
"\n",
"    gap = jnp.max(jnp.abs(final_predictions - orig_model_predictions))\n",
"    print(f'Max prediction gap: {gap:.3e}')\n",
"\n",
"if __name__ == '__main__':\n",
"    main()"
],
"metadata": {
"id": "R4arRIjcL_F2",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 788
},
"outputId": "95b74d2b-b819-441a-9f95-9f55b1843e5d"
},
"execution_count": 7,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"\u001b[31m╭─\u001b[0m\u001b[31m──────────────────────────────\u001b[0m\u001b[31m \u001b[0m\u001b[1;31mTraceback \u001b[0m\u001b[1;2;31m(most recent call last)\u001b[0m\u001b[31m \u001b[0m\u001b[31m───────────────────────────────\u001b[0m\u001b[31m─╮\u001b[0m\n",
"\u001b[31m│\u001b[0m in \u001b[92m<cell line: 80>\u001b[0m:\u001b[94m81\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m in \u001b[92mmain\u001b[0m:\u001b[94m41\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/lorax/\u001b[0m\u001b[1;33mhelpers.py\u001b[0m:\u001b[94m42\u001b[0m in \u001b[92minit_lora\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m39 \u001b[0m\u001b[2m│ \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m40 \u001b[0m\u001b[2m│ \u001b[0m\u001b[94mreturn\u001b[0m ( \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m41 \u001b[0m\u001b[2m│ │ \u001b[0mjax.tree_map(freeze_getter, param_tree, spec, is_leaf=is_leaf), \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m42 \u001b[2m│ │ \u001b[0mjax.tree_util.tree_map_with_path(tune_getter, param_tree, spec, is_leaf=is_leaf) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m43 \u001b[0m\u001b[2m│ \u001b[0m) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m44 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m45 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92msimple_spec\u001b[0m(params, decision_fn=\u001b[94mNone\u001b[0m, tune_vectors=\u001b[94mFalse\u001b[0m, is_leaf=\u001b[94mNone\u001b[0m): \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/jax/_src/\u001b[0m\u001b[1;33mtree_util.py\u001b[0m:\u001b[94m788\u001b[0m in \u001b[92mtree_map_with_path\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m785 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m786 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves = \u001b[96mlist\u001b[0m(\u001b[96mzip\u001b[0m(*keypath_leaves)) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m787 \u001b[0m\u001b[2m \u001b[0mall_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) \u001b[94mfor\u001b[0m r \u001b[95min\u001b[0m rest] \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m788 \u001b[2m \u001b[0m\u001b[94mreturn\u001b[0m treedef.unflatten(f(*xs) \u001b[94mfor\u001b[0m xs \u001b[95min\u001b[0m \u001b[96mzip\u001b[0m(*all_keypath_leaves)) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m789 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m790 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m791 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_child_keys\u001b[0m(pytree: Any) -> KeyPath: \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/jax/_src/\u001b[0m\u001b[1;33mtree_util.py\u001b[0m:\u001b[94m788\u001b[0m in \u001b[92m<genexpr>\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m785 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m786 \u001b[0m\u001b[2m \u001b[0mkeypath_leaves = \u001b[96mlist\u001b[0m(\u001b[96mzip\u001b[0m(*keypath_leaves)) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m787 \u001b[0m\u001b[2m \u001b[0mall_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) \u001b[94mfor\u001b[0m r \u001b[95min\u001b[0m rest] \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m788 \u001b[2m \u001b[0m\u001b[94mreturn\u001b[0m treedef.unflatten(f(*xs) \u001b[94mfor\u001b[0m xs \u001b[95min\u001b[0m \u001b[96mzip\u001b[0m(*all_keypath_leaves)) \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m789 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m790 \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m791 \u001b[0m\u001b[94mdef\u001b[0m \u001b[92m_child_keys\u001b[0m(pytree: Any) -> KeyPath: \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2;33m/usr/local/lib/python3.10/dist-packages/lorax/\u001b[0m\u001b[1;33mhelpers.py\u001b[0m:\u001b[94m20\u001b[0m in \u001b[92mtune_getter\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m17 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m spec_val == LORA_FULL: \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m18 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mreturn\u001b[0m param \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m19 \u001b[0m\u001b[2m│ │ \u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[31m❱ \u001b[0m20 \u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(param.shape) == \u001b[94m1\u001b[0m: \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m21 \u001b[0m\u001b[2m│ │ │ \u001b[0m\u001b[94mraise\u001b[0m \u001b[96mValueError\u001b[0m(\u001b[33mf\u001b[0m\u001b[33m'\u001b[0m\u001b[33mVectors must either be frozen or fully tuned, but got spe\u001b[0m \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m22 \u001b[0m\u001b[2m│ │ \u001b[0m\u001b[94mif\u001b[0m \u001b[96mlen\u001b[0m(param.shape) == \u001b[94m2\u001b[0m: \u001b[31m│\u001b[0m\n",
"\u001b[31m│\u001b[0m \u001b[2m23 \u001b[0m\u001b[2m│ │ │ \u001b[0mb_dim, a_dim = param.shape \u001b[31m│\u001b[0m\n",
"\u001b[31m╰──────────────────────────────────────────────────────────────────────────────────────────────────╯\u001b[0m\n",
"\u001b[1;91mAttributeError: \u001b[0m\u001b[32m'generator'\u001b[0m object has no attribute \u001b[32m'shape'\u001b[0m\n"
],
"text/html": [
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000\">╭─────────────────────────────── </span><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Traceback </span><span style=\"color: #bf7f7f; text-decoration-color: #bf7f7f; font-weight: bold\">(most recent call last)</span><span style=\"color: #800000; text-decoration-color: #800000\"> ────────────────────────────────╮</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">&lt;cell line: 80&gt;</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">81</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">main</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">41</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/lorax/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">helpers.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">42</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">init_lora</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">39 │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">40 │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> ( <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">41 │ │ </span>jax.tree_map(freeze_getter, param_tree, spec, is_leaf=is_leaf), <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>42 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span>jax.tree_util.tree_map_with_path(tune_getter, param_tree, spec, is_leaf=is_leaf) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">43 │ </span>) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">44 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">45 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">simple_spec</span>(params, decision_fn=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">None</span>, tune_vectors=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">False</span>, is_leaf=<span style=\"color: #0000ff; text-decoration-color: #0000ff\">None</span>): <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/jax/_src/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">tree_util.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">788</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">tree_map_with_path</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">785 </span>keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">786 </span>keypath_leaves = <span style=\"color: #00ffff; text-decoration-color: #00ffff\">list</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">787 </span>all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> r <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> rest] <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>788 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> treedef.unflatten(f(*xs) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> xs <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*all_keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">789 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">790 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">791 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">_child_keys</span>(pytree: Any) -&gt; KeyPath: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/jax/_src/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">tree_util.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">788</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">&lt;genexpr&gt;</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">785 </span>keypath_leaves, treedef = tree_flatten_with_path(tree, is_leaf) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">786 </span>keypath_leaves = <span style=\"color: #00ffff; text-decoration-color: #00ffff\">list</span>(<span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">787 </span>all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> r <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> rest] <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>788 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> treedef.unflatten(f(*xs) <span style=\"color: #0000ff; text-decoration-color: #0000ff\">for</span> xs <span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">in</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">zip</span>(*all_keypath_leaves)) <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">789 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">790 </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">791 </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">def</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00\">_child_keys</span>(pytree: Any) -&gt; KeyPath: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #bfbf7f; text-decoration-color: #bfbf7f\">/usr/local/lib/python3.10/dist-packages/lorax/</span><span style=\"color: #808000; text-decoration-color: #808000; font-weight: bold\">helpers.py</span>:<span style=\"color: #0000ff; text-decoration-color: #0000ff\">20</span> in <span style=\"color: #00ff00; text-decoration-color: #00ff00\">tune_getter</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">17 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> spec_val == LORA_FULL: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">18 │ │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">return</span> param <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">19 │ │ </span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #800000; text-decoration-color: #800000\">❱ </span>20 <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">│ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">len</span>(param.shape) == <span style=\"color: #0000ff; text-decoration-color: #0000ff\">1</span>: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">21 │ │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">raise</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">ValueError</span>(<span style=\"color: #808000; text-decoration-color: #808000\">f'Vectors must either be frozen or fully tuned, but got spe</span> <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">22 │ │ </span><span style=\"color: #0000ff; text-decoration-color: #0000ff\">if</span> <span style=\"color: #00ffff; text-decoration-color: #00ffff\">len</span>(param.shape) == <span style=\"color: #0000ff; text-decoration-color: #0000ff\">2</span>: <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">│</span> <span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\">23 │ │ │ </span>b_dim, a_dim = param.shape <span style=\"color: #800000; text-decoration-color: #800000\">│</span>\n",
"<span style=\"color: #800000; text-decoration-color: #800000\">╰──────────────────────────────────────────────────────────────────────────────────────────────────╯</span>\n",
"<span style=\"color: #ff0000; text-decoration-color: #ff0000; font-weight: bold\">AttributeError: </span><span style=\"color: #008000; text-decoration-color: #008000\">'generator'</span> object has no attribute <span style=\"color: #008000; text-decoration-color: #008000\">'shape'</span>\n",
"</pre>\n"
]
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Djyo_reAs26R"
},
"outputs": [],
"source": [
"'''\n",
"# Initialize the model parameters using JAX's PRNG key\n",
"rng_key = jax.random.PRNGKey(0)\n",
"input_ids = jnp.array([[1, 2, 3, 4, 5]])\n",
"decoder_input_ids = jnp.array([[1, 2, 3, 4, 5]])\n",
"params = model.parameters()\n",
"'''\n",
"\n",
"# Modify my_model to use the LongT5-XL model instead of the custom model defined earlier.\n",
"# Flax T5-style models are encoder-decoder and require decoder_input_ids, so the encoder\n",
"# inputs are reused as decoder inputs here for simplicity.\n",
"def my_model(params, x):\n",
"    logits = model(input_ids=x, decoder_input_ids=x, params=params).logits\n",
"    return jnp.mean(logits)\n",
"\n",
"# Token-level cross-entropy loss for the LongT5-XL model.\n",
"# Dropout stays off (train=False is the default); train=True would also need a dropout_rng.\n",
"def compute_loss(params, input_ids, decoder_input_ids, labels):\n",
"    logits = model(\n",
"        input_ids=input_ids,\n",
"        decoder_input_ids=decoder_input_ids,\n",
"        params=params,\n",
"    ).logits\n",
"    return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()\n",
"\n",
"# value_and_grad returns (loss, grads), in that order\n",
"grad_fn = jax.value_and_grad(compute_loss)\n",
"\n",
"# Define an optimizer to update the parameters using the gradients\n",
"optimizer = optax.adam(learning_rate=1e-3)\n",
"\n",
"# One training step: forward and backward pass, then the optimizer update.\n",
"# The optimizer is captured from the enclosing scope because it is not a valid jit argument;\n",
"# only its state is passed in and returned.\n",
"@jax.jit\n",
"def train_step(params, optimizer_state, input_ids, labels):\n",
"    loss, grads = grad_fn(params, input_ids, input_ids, labels)\n",
"    updates, optimizer_state = optimizer.update(grads, optimizer_state, params)\n",
"    new_params = optax.apply_updates(params, updates)\n",
"    return new_params, optimizer_state, loss\n",
"\n",
"# Generate a dummy batch of random token IDs and random target tokens\n",
"def generate_batch(batch_size, rng, seq_len=16):\n",
"    input_rng, label_rng = jax.random.split(rng)\n",
"    vocab_size = model.config.vocab_size\n",
"    input_ids = jax.random.randint(input_rng, (batch_size, seq_len), 0, vocab_size)\n",
"    labels = jax.random.randint(label_rng, (batch_size, seq_len), 0, vocab_size)\n",
"    return input_ids, labels\n",
"\n",
"# Initialize the optimizer state and the PRNG key\n",
"optimizer_state = optimizer.init(params)\n",
"rng = jax.random.PRNGKey(0)\n",
"\n",
"# Train the model\n",
"num_steps = 50\n",
"batch_size = 4\n",
"\n",
"for i in range(num_steps):\n",
"    # Split off a fresh key every step so each batch is different\n",
"    rng, batch_rng = jax.random.split(rng)\n",
"    input_ids, labels = generate_batch(batch_size, batch_rng)\n",
"\n",
"    # Update the parameters and optimizer state\n",
"    params, optimizer_state, loss = train_step(params, optimizer_state, input_ids, labels)\n",
"\n",
"    # Print the loss every 10 steps\n",
"    if i % 10 == 0:\n",
"        print(f'Step {i}, Loss: {loss}')\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RlCLAmjBvhnA"
},
"source": [
"GPT-Q needs input data for quantization. For an actual model we'd use real data, but here we'll just make some random inputs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "6govTMOZvgSC"
},
"outputs": [],
"source": [
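"# my_model wraps LongT5, so the calibration inputs must be integer token IDs rather than float vectors\n",
"quant_data = [\n",
"    jax.random.randint(key, (batch_size, DIM), 0, model.config.vocab_size)\n",
"    for key in jax.random.split(data_key, 64)\n",
"]\n",
"\n",
"# We'll save an output for later comparison since the quantization process will delete the original params\n",
"original_output = my_model(params, quant_data[0])"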
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Rjdb3h46vtsi"
},
"source": [
"### Run GPT-Q to get the quantized weights\n",
"That's all for the setup. We can now run GPT-Q (without any changes to the original model code):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "L1Mw9ZLpvrLa"
},
"outputs": [],
"source": [
"# Note that this may free the buffers associated with some or all of the parameters and the data to save VRAM\n",
"# I'd also recommend you put the params on the CPU, since `quantize()` will move the params to the GPU when necessary\n",
"quantized_params = jax_gptq.quantize(my_model, params, quant_data)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2NhVv8egwDQu"
},
"source": [
"The matrices have been quantized but the biases have been left alone:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "bWwXzTJyubbH"
},
"outputs": [],
"source": [
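"leaves, _ = jax.tree_util.tree_flatten_with_path(quantized_params, is_leaf=lambda x: isinstance(x, jax_gptq.QuantizedMatrix))\n",
"for path, leaf in leaves[:10]:  # weight matrices should be QuantizedMatrix, biases plain arrays\n",
"    print(f'{jax.tree_util.keystr(path)}: {type(leaf).__name__}')"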
]
},
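{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough illustration of what storing a matrix in 4 bits involves, here is a minimal round-to-nearest (RTN) quantization sketch in plain JAX. It is *not* the GPT-Q algorithm (GPT-Q also uses the calibration data to compensate for rounding error), and JAX-GPTQ's real `QuantizedMatrix` layout may differ; the helpers `quantize_rtn_4bit` and `dequantize_4bit` are hypothetical. Each column is mapped onto the 16 integer levels 0 to 15 with its own scale and offset, so the reconstruction error of every entry should stay within half a quantization step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Round-to-nearest 4-bit quantization sketch (hypothetical helpers, NOT the GPT-Q algorithm)\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def quantize_rtn_4bit(w):\n",
"    # Asymmetric per-column quantization onto the 16 integer levels 0..15\n",
"    w_min = w.min(axis=0, keepdims=True)\n",
"    w_max = w.max(axis=0, keepdims=True)\n",
"    scale = (w_max - w_min) / 15.0\n",
"    scale = jnp.where(scale == 0, 1.0, scale)  # guard against constant columns\n",
"    q = jnp.clip(jnp.round((w - w_min) / scale), 0, 15).astype(jnp.uint8)\n",
"    return q, scale, w_min\n",
"\n",
"def dequantize_4bit(q, scale, w_min):\n",
"    return q.astype(jnp.float32) * scale + w_min\n",
"\n",
"w = jax.random.normal(jax.random.PRNGKey(0), (256, 128))\n",
"q, scale, w_min = quantize_rtn_4bit(w)\n",
"w_hat = dequantize_4bit(q, scale, w_min)\n",
"\n",
"max_err = float(jnp.max(jnp.abs(w - w_hat)))\n",
"half_step = float(jnp.max(scale)) / 2\n",
"print(f'Max abs error: {max_err:.3e} (half of the largest step: {half_step:.3e})')"
]
},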
{
"cell_type": "markdown",
"metadata": {
"id": "QwYLTr6WwapB"
},
"source": [
"**Note**: The quantization procedure depends on the parameter being used in a matrix multiplication. Currently, JAX-GPTQ supports general dot operations (including ones on tensors with two or more dimensions) and convolutions with kernels of spatial size 1.\n",
"\n",
"### Applying the quantized weights\n",
"We can now run the quantized model without any code changes. All that's necessary is using `jax_gptq.use_quantized` to transform the function so it knows how to handle `QuantizedMatrix` values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I6aLdXqawQFs"
},
"outputs": [],
"source": [
"quantized_params = jax.device_put(quantized_params, gpu) # Move the params to the GPU\n",
"\n",
"# Originally:\n",
"# my_model(params, inputs)\n",
"# After:\n",
"# jax_gptq.use_quantized(my_model)(params, inputs)\n",
"quant_output = jax_gptq.use_quantized(my_model)(quantized_params, quant_data[0])\n",
"\n",
"print(f'Output of quantized network: {quant_output:.3e}')\n",
"print(f'Original output: {original_output:.3e}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1vXkTTctx7Vo"
},
"source": [
"### Train with LoRA\n",
"\n",
"Now that we've compressed our model to 4 bits (and change) per parameter, we can add full-precision LoRA parameters for finetuning.\n",
"\n",
"The one gotcha about combining the two is that Lorax doesn't know that QuantizedMatrix values are pytree leaves, so you need to give the Lorax functions an `is_leaf` predicate."
]
},
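{
"cell_type": "markdown",
"metadata": {},
"source": [
"For a sense of scale, the arithmetic below compares storage for one weight matrix the size of a `gpt2-medium` MLP projection (1024 x 4096) in three forms: fp32, packed 4-bit (two values per byte), and a pair of rank-4 LoRA factors. It ignores the small per-column scale overhead that makes real 4-bit storage slightly larger than half a byte per parameter. The rank of 4 matches the `lora_spec` used below; the matrix shape is only an example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Back-of-the-envelope storage for one example 1024 x 4096 weight matrix\n",
"d_in, d_out = 1024, 4096\n",
"rank = 4  # same inner rank as the lora_spec below\n",
"\n",
"n_params = d_in * d_out\n",
"fp32_bytes = n_params * 4\n",
"int4_bytes = n_params // 2  # two 4-bit values packed per byte, scales ignored\n",
"lora_params = rank * (d_in + d_out)  # A: d_in x rank, B: rank x d_out\n",
"lora_fp32_bytes = lora_params * 4\n",
"\n",
"print(f'fp32 weight:  {fp32_bytes / 2**20:.1f} MiB')\n",
"print(f'4-bit weight: {int4_bytes / 2**20:.1f} MiB')\n",
"print(f'LoRA factors: {lora_params} params, {lora_fp32_bytes / 2**10:.1f} KiB in fp32')"
]
},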
{
"cell_type": "markdown",
"metadata": {
"id": "l95MirHdzNo9"
},
"source": [
"**Initialization:** The `init_lora` function expects a pytree describing which parameters should get LoRA parameters, which should be fully trained, and which should be left frozen. `lorax.simple_spec` is a helper function for making these specs."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HKkhcjx9zJy6"
},
"outputs": [],
"source": [
"def is_leaf(x):\n",
" return isinstance(x, jax_gptq.QuantizedMatrix)\n",
"\n",
"lora_spec = lorax.simple_spec(\n",
" params=quantized_params,\n",
" decision_fn=lambda pytree_path, arr: 4, # Just ignore the inputs and specify an inner rank of 4 for all params\n",
" tune_vectors=False, # Tell Lorax to put all the biases in the frozen params tree instead of the tunable params tree\n",
" is_leaf=is_leaf\n",
")\n",
"\n",
"# Lorax splits the parameters into two pytrees:\n",
"# freeze_params: Anything which received the value lorax.LORA_FREEZE in the spec\n",
"# train_params: Pairs of two narrow matrices for values which got positive integers as spec values, or the full parameter if the value lorax.LORA_FULL was in the spec\n",
"freeze_params, train_params = lorax.init_lora(quantized_params, lora_spec, jax.random.PRNGKey(1234), is_leaf=is_leaf)\n",
"\n",
"def merge_quantized_with_lora(q_params, lora_freeze):\n",
" return jax.tree_map(\n",
" lambda quant, from_lora: quant if isinstance(quant, jax_gptq.QuantizedMatrix) else from_lora,\n",
" q_params,\n",
" lora_freeze,\n",
" is_leaf=lambda x: isinstance(x, jax_gptq.QuantizedMatrix) # Tell tree_map to treat QuantizedMatrix as a single value instead of a non-leaf node\n",
" )\n",
"# Now we put the actual quantized params back\n",
"#freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-ebT9GXp16v4"
},
"source": [
"The `lorax.lora` transform converts a function from expecting a single pytree in the specified argument to expecting a tuple of two pytrees. It composes with other JAX transforms such as `jax_gptq.use_quantized`, so we can use both at once with no modifications to our model code."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1XjjuQcq1oSq"
},
"outputs": [],
"source": [
"combined_params = (freeze_params, train_params)\n",
"\n",
"my_model_with_lora_and_quantized_weights = jax_gptq.use_quantized(lorax.lora(my_model))\n",
"\n",
"# The differences from the original `my_model` function are:\n",
"# 1. The params argument now expects a tuple of (frozen_params, trainable_params)\n",
"# 2. It knows how to compute with quantized weights\n",
"quantized_plus_lorax_output = my_model_with_lora_and_quantized_weights(combined_params, quant_data[0])\n",
"\n",
"print(f'GPTQ + Lorax output: {quantized_plus_lorax_output:.3e}')\n",
"print(f'GPTQ only: {quant_output:.3e}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aIywP5qQ3KEH"
},
"source": [
"The above values are identical since LoRA initializes one of each pair of matrices as zeros.\n",
"\n",
"Let's look at the size of each pytree:"
]
},
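{
"cell_type": "markdown",
"metadata": {},
"source": [
"The zero-initialization argument can be checked in a few lines of plain JAX. This is a hypothetical sketch of a single LoRA-adapted linear layer, not Lorax's implementation. Which of the two factors Lorax zero-initializes is an assumption here; this sketch zeros `B`. Because `(x @ A) @ B` is exactly zero at initialization, the adapted layer reproduces the frozen layer's output until training updates the factors."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Plain-JAX LoRA sketch (not Lorax's implementation): y = x @ W + (x @ A) @ B\n",
"d_in, d_out, rank = 64, 32, 4\n",
"key_w, key_a, key_x = jax.random.split(jax.random.PRNGKey(0), 3)\n",
"\n",
"W = jax.random.normal(key_w, (d_in, d_out))        # frozen base weight\n",
"A = jax.random.normal(key_a, (d_in, rank)) * 0.01  # random low-rank factor\n",
"B = jnp.zeros((rank, d_out))                       # zero-initialized low-rank factor\n",
"\n",
"def lora_linear(x, W, A, B):\n",
"    return x @ W + (x @ A) @ B\n",
"\n",
"x_demo = jax.random.normal(key_x, (8, d_in))\n",
"unchanged = bool(jnp.allclose(lora_linear(x_demo, W, A, B), x_demo @ W))\n",
"print(f'Output unchanged at initialization: {unchanged}')"
]
},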
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nqQwBPjh2ttl"
},
"outputs": [],
"source": [
"count_params = partial(jax.tree_util.tree_reduce,\n",
" lambda acc, param: acc + (param.size if isinstance(param, jnp.ndarray) else 0),\n",
" initializer=0\n",
")\n",
"\n",
"print(f'{count_params(freeze_params):.3e} frozen params')\n",
"print(f'{count_params(train_params):.3e} trainable params')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0CJ58F005g-c"
},
"source": [
"Training with this function is no different from training any other JAX function; just make sure to differentiate your loss with respect to the trainable parameters only (see the next section for an example)."
]
},
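{
"cell_type": "markdown",
"metadata": {},
"source": [
"The following toy example shows why splitting the arguments matters. `jax.grad` differentiates only with respect to its first positional argument by default (`argnums=0`). Passing the trainable pytree first and the frozen one second therefore keeps the frozen weights out of the gradient. The parameter names and shapes below are made up for the example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# Hypothetical toy stand-ins for the frozen and trainable pytrees\n",
"toy_frozen = {'w': jnp.ones((3, 2))}\n",
"toy_trainable = {'a': jnp.ones((3, 1)), 'b': jnp.zeros((1, 2))}\n",
"toy_x = jnp.ones((4, 3))\n",
"\n",
"def toy_loss(trainable, frozen, x):\n",
"    pred = x @ frozen['w'] + x @ trainable['a'] @ trainable['b']\n",
"    return jnp.mean(pred ** 2)\n",
"\n",
"# jax.grad differentiates only with respect to argument 0 by default (argnums=0),\n",
"# so the frozen pytree is treated as a constant and receives no gradient\n",
"grads = jax.grad(toy_loss)(toy_trainable, toy_frozen, toy_x)\n",
"print({name: g.shape for name, g in grads.items()})"
]
},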
{
"cell_type": "markdown",
"metadata": {
"id": "m_lDOLnw5zoC"
},
"source": [
"## GPT-Q-ing + LoRA-ing HuggingFace's Flax GPT-2\n",
"I developed these transforms for use with my Haiku models, but since all JAX models are pure functions at the end of the day, it shouldn't matter what framework you use. Lorax supports matmuls and other matmul-like operations such as embedding lookups and 1-D convs.\n",
"\n",
"This is a minimal example of applying the combination to `gpt2-medium`, but it's basically model-agnostic.\n",
"\n",
"First let's get the model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "czS5kDWO6XTv"
},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, FlaxAutoModelForCausalLM"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VnfmpQ6f6Yal"
},
"outputs": [],
"source": [
"model_name = 'gpt2-medium'\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model, params = FlaxAutoModelForCausalLM.from_pretrained(model_name, _do_init=False)\n",
"params = jax.device_put(params, cpu)\n",
"\n",
"# Because the embedding table is reused as the output linear layer, it'll get quantized at the end of the process, but that will seriously screw up the embedding lookup step, so we'll just save it for later here\n",
"orig_embedding_table = np.asarray(params['transformer']['wte']['embedding'])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "evCyWa787m_N"
},
"source": [
"The GPT-Q paper used real text data for quantization, but for this demo I'll just generate some random values."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ao_vTWAf7Tw-"
},
"outputs": [],
"source": [
"QUANT_BATCH_SIZE = 4\n",
"QUANT_EXAMPLE_LENGTH = 64 # I'd recommend making this bigger, but needs to be small to not crash colab\n",
"\n",
"quantization_data = []\n",
"key = jax.random.PRNGKey(0)\n",
"for _ in range(32):\n",
" batch = jax.random.randint(key, (QUANT_BATCH_SIZE, QUANT_EXAMPLE_LENGTH), 0, 50256)\n",
" quantization_data.append(batch)\n",
" key, = jax.random.split(key, 1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0x_pT_fT8Co8"
},
"source": [
"HuggingFace's models don't have quite the right call signature, so we'll make a wrapper which takes `(params, inputs)` as arguments:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true,
"id": "yddz4OUN8Bvt"
},
"outputs": [],
"source": [
"def apply_model(params, batch):\n",
" return model(batch, params=params)\n",
"\n",
"quantized_params = jax_gptq.quantize(apply_model, params, quantization_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ehblO3I98akJ"
},
"outputs": [],
"source": [
"# Replace the quantized embedding table with the original one\n",
"quantized_params['transformer']['wte']['embedding'] = jnp.asarray(orig_embedding_table)\n",
"quantized_params = jax.device_put(quantized_params, gpu)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WYiCG5fE9yKT"
},
"source": [
"### Finetuning GPT-2 with Lorax\n",
"\n",
"Same as [above](https://colab.research.google.com/drive/18rkULbWqk7mNZDx7Scx-JS3p_s45mgok#scrollTo=HKkhcjx9zJy6&line=3&uniqifier=1), we get the original param structure to tell Lorax how to initialize the LoRA params, then merge the quantized params back in after."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FKS_dfll93sO"
},
"outputs": [],
"source": [
"# Get pre-quantization param tree (some nodes will just be abstract values)\n",
"orig_params_or_shapes = jax_gptq.utils.quantized_params_to_shaped_arrays(quantized_params)\n",
"\n",
"# Tell Lorax which leaves should be frozen/fully trained/LoRA trained\n",
"spec = lorax.simple_spec(\n",
" orig_params_or_shapes,\n",
" lambda path, arr: 16 if any(pattern in path for pattern in ['c_attn', 'mlp']) else lorax.LORA_FREEZE,\n",
" tune_vectors=True\n",
")\n",
"\n",
"# Initialize parameters\n",
"key, init_key = jax.random.split(key)\n",
"freeze_params, train_params = lorax.init_lora(\n",
" orig_params_or_shapes,\n",
" spec,\n",
" init_key\n",
")\n",
"\n",
"# Put the quantized params back into the frozen param tree\n",
"freeze_params = merge_quantized_with_lora(quantized_params, freeze_params)\n",
"combined_params = freeze_params, train_params"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T8bJwqN2Bfqh"
},
"source": [
"Now we can just transform the `apply_model` function, and it will use both LoRA and 4-bit quantized parameters."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "glARn7Z0BX4g"
},
"outputs": [],
"source": [
"quantized_plus_lora_fn = jax_gptq.use_quantized(lorax.lora(apply_model))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Y1G-d0yDBn8y"
},
"source": [
"### Training\n",
"Training isn't actually any different from normal training, since you can just think of `freeze_params` as being a constant argument, but here's a demo for completeness.\n",
"\n",
"First I'll define a toy corpus which demonstrates Alan's love of cats and Grace's dislike of them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "I3fdjSioBvDO"
},
"outputs": [],
"source": [
"CATS = ['lions', 'tigers', 'cheetahs', 'cats', 'ocelots', 'kittens']\n",
"DOGS = ['wolves', 'dogs', 'coyotes', 'huskies', 'poodles', 'puppies']\n",
"\n",
"CAT_LOVER = 'Alan'\n",
"DOG_LOVER = 'Grace'\n",
"\n",
"dataset = []\n",
"for name, polarity in [(CAT_LOVER, True), (DOG_LOVER, False)]:\n",
"    liked, disliked = (CATS, DOGS) if polarity else (DOGS, CATS)\n",
"    for kind in liked:\n",
"        dataset.append(f'{name}: {kind}? I love them!')\n",
"        dataset.append(f'{name}: Hey look at those {kind}, that\\'s pretty cool')\n",
"\n",
"    for kind in disliked:\n",
"        dataset.append(f'{name}: {kind}? I hate them!')\n",
"        dataset.append(f'{name}: Oh no, some {kind}! How scary!')\n",
"\n",
"tokenized_data = [jnp.asarray(tokenizer.encode(ex)) for ex in dataset]\n",
"max_len = max(ex.shape[0] for ex in tokenized_data)\n",
"# Pad the data to speed up jitting. Not worrying about masking due to laziness.\n",
"tokenized_data = [jnp.pad(ex, (0, max_len - ex.shape[0])) for ex in tokenized_data]\n",
"\n",
"jitted_model = jax.jit(quantized_plus_lora_fn)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "NZFLWJgxYqfh"
},
"outputs": [],
"source": [
"def make_prediction(params, prefix):\n",
"    tokens = jnp.asarray(tokenizer.encode(prefix))\n",
"    logits = jitted_model(params, tokens[None]).logits\n",
"\n",
"    # Next-token probability distribution at the final position\n",
"    probs = jax.nn.softmax(logits[0, -1])\n",
"    pred_probs, pred_words = jax.lax.top_k(probs, 5)\n",
"\n",
"    print(f'Predictions for: \"{prefix}\"')\n",
"    for i, (word_id, prob) in enumerate(zip(pred_words, pred_probs), 1):\n",
"        print(f'{i}. {tokenizer.decode([word_id])} - {prob:.2%}')\n",
"    print()\n",
"\n",
"test_examples = [\n",
" f'{CAT_LOVER}: jaguars? I',\n",
" f'{DOG_LOVER}: jaguars? I'\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yT7hOBnYS-AC"
},
"source": [
"Let's look at the next word predictions of the unmodified model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eew7ihGJTD85"
},
"outputs": [],
"source": [
"for ex in test_examples:\n",
" make_prediction(combined_params, ex)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BrSL1MgSDXfO"
},
"source": [
"Next we set up a standard training loop. The only difference is that we keep the train/freeze params separate for the optimizer. No changes are needed for the quantization.\n",
"\n",
"I'll just train with a batch size of 1 here since I don't want to bother with masking, but the transformed model function is fully compatible with vmap etc."
]
},
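{
"cell_type": "markdown",
"metadata": {},
"source": [
"The inputs/targets shift in `loss_fn` below can be sanity-checked on random logits. So can the claim that a per-example loss composes with `vmap`. This is a standalone sketch with made-up sizes and it does not call the model. `next_token_nll` is a hypothetical helper that drops the last position's logits because that position has no target."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def next_token_nll(logits, seq):\n",
"    # Position t predicts token t + 1 (the same shift as inputs/targets in loss_fn);\n",
"    # the last position has no target, so its logits are dropped\n",
"    logprobs = jax.nn.log_softmax(logits[:-1])\n",
"    targets = seq[1:]\n",
"    return -jnp.mean(jnp.take_along_axis(logprobs, targets[:, None], axis=-1))\n",
"\n",
"vocab_size, seq_len, n_seqs = 10, 6, 3\n",
"key_logits, key_seqs = jax.random.split(jax.random.PRNGKey(0))\n",
"toy_logits = jax.random.normal(key_logits, (n_seqs, seq_len, vocab_size))\n",
"toy_seqs = jax.random.randint(key_seqs, (n_seqs, seq_len), 0, vocab_size)\n",
"\n",
"# vmap maps the per-sequence loss over the leading batch axis\n",
"batched = jax.vmap(next_token_nll)(toy_logits, toy_seqs)\n",
"looped = jnp.stack([next_token_nll(lg, sq) for lg, sq in zip(toy_logits, toy_seqs)])\n",
"print(batched.shape, bool(jnp.allclose(batched, looped)))"
]
},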
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "52QdkmIxDHk-"
},
"outputs": [],
"source": [
"def loss_fn(train_params, freeze_params, seq):\n",
" inputs = seq[:-1]\n",
" targets = seq[1:]\n",
"\n",
" combined_params = (freeze_params, train_params)\n",
" logits = quantized_plus_lora_fn(combined_params, inputs[None]).logits[0]\n",
" logprobs = jax.nn.log_softmax(logits)\n",
" losses = -jnp.take_along_axis(logprobs, targets[:, None], axis=-1)\n",
" return jnp.mean(losses)\n",
"\n",
"optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)\n",
"opt_state = optimizer.init(combined_params[1])\n",
"\n",
"@jax.jit\n",
"def update_fn(combined_params, opt_state, example):\n",
" freeze_params, train_params = combined_params\n",
"\n",
" # The main thing is that we have to split up the params here so that JAX knows what to differentiate with respect to\n",
" loss, grads = jax.value_and_grad(loss_fn)(train_params, freeze_params, example)\n",
"\n",
" updates, opt_state = optimizer.update(grads, opt_state, params=train_params)\n",
" new_train_params = optax.apply_updates(train_params, updates)\n",
" return (freeze_params, new_train_params), opt_state, loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cj2d1xIqFJw3"
},
"outputs": [],
"source": [
"bar = trange(50)\n",
"for epoch in bar:\n",
"    key, = jax.random.split(key, 1)\n",
"    permutation = jax.random.permutation(key, jnp.arange(len(dataset)))\n",
"    total_loss = 0\n",
"    for index in permutation:\n",
"        example = tokenized_data[index]\n",
"        combined_params, opt_state, loss = update_fn(combined_params, opt_state, example)\n",
"        total_loss += loss\n",
"    bar.set_description(f'Epoch {epoch} - Loss: {total_loss / len(tokenized_data):.3e}')"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IMFZwE8qeSUl"
},
"source": [
"The trained LoRA parameters give us a model which predicts that Alan will love jaguars, and Grace will hate them:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GIgThnapFQS6"
},
"outputs": [],
"source": [
"for example in test_examples:\n",
" make_prediction(combined_params, example)\n",
" print()"
]
}
],
"metadata": {
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"dda94429aac549438c343dd2edc24e0a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_ee547e0a59eb422d975f02daaf672629",
"IPY_MODEL_989350afa53349d39342297bfb4eca95",
"IPY_MODEL_9954285510cb4ebab4deb9d6558dd5ed"
],
"layout": "IPY_MODEL_663b81ccdf0c4b02af24b1cd81b7daa6"
}
},
"ee547e0a59eb422d975f02daaf672629": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_6bf5f6cd0df14197b8da4303b43cb655",
"placeholder": "​",
"style": "IPY_MODEL_693fb73b9dc0411eab41c5fede4b611c",
"value": "Downloading (…)lve/main/config.json: 100%"
}
},
"989350afa53349d39342297bfb4eca95": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_db0e6b46120844dc8620fbd2a4076d0d",
"max": 896,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_9872d010314e492f9a85cbb59edb8cc3",
"value": 896
}
},
"9954285510cb4ebab4deb9d6558dd5ed": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1478a80283ff43f193bd9a458a6c8274",
"placeholder": "​",
"style": "IPY_MODEL_199682514a0e43b5a13244a818722d90",
"value": " 896/896 [00:00&lt;00:00, 13.1kB/s]"
}
},
"663b81ccdf0c4b02af24b1cd81b7daa6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6bf5f6cd0df14197b8da4303b43cb655": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"693fb73b9dc0411eab41c5fede4b611c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"db0e6b46120844dc8620fbd2a4076d0d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"9872d010314e492f9a85cbb59edb8cc3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"1478a80283ff43f193bd9a458a6c8274": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"199682514a0e43b5a13244a818722d90": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"cdb5e7259f57470296128e76fdf5e6ea": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9b31610503d84238b9e7a2321ff1b10b",
"IPY_MODEL_32d7477699eb461584db4072a7c97052",
"IPY_MODEL_6c4eafec00674c838c628394c7c68faa"
],
"layout": "IPY_MODEL_a63b31cd45b24408b5a9068d1ecd1b4b"
}
},
"9b31610503d84238b9e7a2321ff1b10b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_715506a0e5694e7ca6061ee99bd92382",
"placeholder": "​",
"style": "IPY_MODEL_3c4dbdf570eb46b690cd500476020064",
"value": "Downloading (…)model.bin.index.json: 100%"
}
},
"32d7477699eb461584db4072a7c97052": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_bdd5160d52de4f9895f8e4dd7f7ec20a",
"max": 55432,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_c4eee8d4242045bba3d38597dbb414b3",
"value": 55432
}
},
"6c4eafec00674c838c628394c7c68faa": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2c47cb97a18f4ee89bd945b1ca11347d",
"placeholder": "​",
"style": "IPY_MODEL_4a9f0ebd9d7b4192a25ef75250c03deb",
"value": " 55.4k/55.4k [00:00&lt;00:00, 691kB/s]"
}
},
"a63b31cd45b24408b5a9068d1ecd1b4b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"715506a0e5694e7ca6061ee99bd92382": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3c4dbdf570eb46b690cd500476020064": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"bdd5160d52de4f9895f8e4dd7f7ec20a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c4eee8d4242045bba3d38597dbb414b3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"2c47cb97a18f4ee89bd945b1ca11347d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"4a9f0ebd9d7b4192a25ef75250c03deb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7d0234ed0d35487e826565df214a2cc9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_46ad2aad7e474c73be0bfb1fef55a4bd",
"IPY_MODEL_6d811cfd512f4916b83640f9297bc1bd",
"IPY_MODEL_8c416aa8a81447cba3527d287aeb2f35"
],
"layout": "IPY_MODEL_8e073507368745088aa356b23f044f0b"
}
},
"46ad2aad7e474c73be0bfb1fef55a4bd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_902022c949e448eeb8ce3e31d1ef4153",
"placeholder": "​",
"style": "IPY_MODEL_8bcfb6b0b56d4046a839c0f9ae4c16b1",
"value": "Downloading shards: 100%"
}
},
"6d811cfd512f4916b83640f9297bc1bd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_7de7565229de41149b770fe25714a847",
"max": 2,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_52e158b04b8c4858929bf3739f7c78df",
"value": 2
}
},
"8c416aa8a81447cba3527d287aeb2f35": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a12f0daf615547faa649107033a93011",
"placeholder": "​",
"style": "IPY_MODEL_3067249f4bc64398906fceea8d1e5934",
"value": " 2/2 [04:42&lt;00:00, 125.83s/it]"
}
},
"8e073507368745088aa356b23f044f0b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"902022c949e448eeb8ce3e31d1ef4153": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"8bcfb6b0b56d4046a839c0f9ae4c16b1": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"7de7565229de41149b770fe25714a847": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"52e158b04b8c4858929bf3739f7c78df": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"a12f0daf615547faa649107033a93011": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"3067249f4bc64398906fceea8d1e5934": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"5fe6035fe6e8426580d632bcedc3d19b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_fe81c32260ca47c7bcd4fd6230ed3c81",
"IPY_MODEL_ec4a1c28bbb54799a2278ee1e1d65383",
"IPY_MODEL_68cc8181fd484af1b9f42b1c9b5f8d75"
],
"layout": "IPY_MODEL_0d20b32672e34c7295634090e6cebf27"
}
},
"fe81c32260ca47c7bcd4fd6230ed3c81": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_24f189322a134a20900084de1e08c376",
"placeholder": "​",
"style": "IPY_MODEL_aae1c621a5b94960ace82d52253ccaf7",
"value": "Downloading (…)l-00001-of-00002.bin: 100%"
}
},
"ec4a1c28bbb54799a2278ee1e1d65383": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_ff2779f06f954e35a5eec9a0a02684f1",
"max": 9449929179,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_c00fb1fde3f74d82ac6b3b5bc19a9796",
"value": 9449929179
}
},
"68cc8181fd484af1b9f42b1c9b5f8d75": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d50444fed9a746ff8281239c4c1a3058",
"placeholder": "​",
"style": "IPY_MODEL_1110d48d7a664208bcd57f3c67aecccd",
"value": " 9.45G/9.45G [03:47&lt;00:00, 41.6MB/s]"
}
},
"0d20b32672e34c7295634090e6cebf27": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"24f189322a134a20900084de1e08c376": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"aae1c621a5b94960ace82d52253ccaf7": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"ff2779f06f954e35a5eec9a0a02684f1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c00fb1fde3f74d82ac6b3b5bc19a9796": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"d50444fed9a746ff8281239c4c1a3058": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1110d48d7a664208bcd57f3c67aecccd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"c22c1289367046babeb0c9409c181f8b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_eb3f2e923426469e9b92a51187f60f60",
"IPY_MODEL_83784cb918974e4c8664f78561ea8cf8",
"IPY_MODEL_008aa06de0fd4ec789d3d42e83dbd16e"
],
"layout": "IPY_MODEL_33217d952d544284937852eeb17323fc"
}
},
"eb3f2e923426469e9b92a51187f60f60": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c83d30547b08413f90300abdf318b123",
"placeholder": "​",
"style": "IPY_MODEL_2b8df3f5200e4c9fa59f057c4586a700",
"value": "Downloading (…)l-00002-of-00002.bin: 100%"
}
},
"83784cb918974e4c8664f78561ea8cf8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_8fc75f12da864868a4d0ab9c1949d1ac",
"max": 1949494999,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b00196f8517c415ea5388bf1a72ffb37",
"value": 1949494999
}
},
"008aa06de0fd4ec789d3d42e83dbd16e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_09ae4873193f41baa133879a8d3f5905",
"placeholder": "​",
"style": "IPY_MODEL_bd9db9a757db4ccfa68154524ac95dbb",
"value": " 1.95G/1.95G [00:53&lt;00:00, 44.3MB/s]"
}
},
"33217d952d544284937852eeb17323fc": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c83d30547b08413f90300abdf318b123": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2b8df3f5200e4c9fa59f057c4586a700": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"8fc75f12da864868a4d0ab9c1949d1ac": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b00196f8517c415ea5388bf1a72ffb37": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"09ae4873193f41baa133879a8d3f5905": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bd9db9a757db4ccfa68154524ac95dbb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"410478d1c6164c679702f1b959ec6dc6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_34ae822850454b4b8f92dae185943fe4",
"IPY_MODEL_44bd417858fb4982bd97cc07b3f03b5d",
"IPY_MODEL_84ddc0bd3bfd4df9a300a0ea01dafc06"
],
"layout": "IPY_MODEL_60e20079809641db96f76aa61980ac74"
}
},
"34ae822850454b4b8f92dae185943fe4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a82fed6b81274263a1c88e41063356f4",
"placeholder": "​",
"style": "IPY_MODEL_14c39b554476480d8ed17a4188ffc261",
"value": "Loading checkpoint shards: 100%"
}
},
"44bd417858fb4982bd97cc07b3f03b5d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1b9abcbc8e054a069d9bf0710329928b",
"max": 2,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_d85f0a6e82b648b0899e7006b5a7e740",
"value": 2
}
},
"84ddc0bd3bfd4df9a300a0ea01dafc06": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_df82eb9aaa354a4395a5304ae038a828",
"placeholder": "​",
"style": "IPY_MODEL_96b5cf9d5b7846bfa61523ee1ec4db69",
"value": " 2/2 [01:32&lt;00:00, 43.54s/it]"
}
},
"60e20079809641db96f76aa61980ac74": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"a82fed6b81274263a1c88e41063356f4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"14c39b554476480d8ed17a4188ffc261": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"1b9abcbc8e054a069d9bf0710329928b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d85f0a6e82b648b0899e7006b5a7e740": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"df82eb9aaa354a4395a5304ae038a828": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"96b5cf9d5b7846bfa61523ee1ec4db69": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"3458c2feac6b41ebbf6ed872d9112df0": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_d8cf12f088a94b4da262d8d2c134e89f",
"IPY_MODEL_2973bec2ccff49168ba84ddc1f326e8a",
"IPY_MODEL_58230fe17b7840c983a59a807a051f2c"
],
"layout": "IPY_MODEL_b4687344280545029c6505de0d277b97"
}
},
"d8cf12f088a94b4da262d8d2c134e89f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_910e96c5ff324caba366ede006d164ab",
"placeholder": "​",
"style": "IPY_MODEL_d8138fd266e44ab3963a274ef7018b0c",
"value": "Downloading (…)neration_config.json: 100%"
}
},
"2973bec2ccff49168ba84ddc1f326e8a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_0043c749b6ad4d27a251abbf5d3644ed",
"max": 147,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_ccf2b81c4a774692980b024885d00872",
"value": 147
}
},
"58230fe17b7840c983a59a807a051f2c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_47302718457149c9bd98b2ef0b702d52",
"placeholder": "​",
"style": "IPY_MODEL_83d6be00be514667a08472ba48a7d42f",
"value": " 147/147 [00:00&lt;00:00, 4.48kB/s]"
}
},
"b4687344280545029c6505de0d277b97": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"910e96c5ff324caba366ede006d164ab": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d8138fd266e44ab3963a274ef7018b0c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"0043c749b6ad4d27a251abbf5d3644ed": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"ccf2b81c4a774692980b024885d00872": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"47302718457149c9bd98b2ef0b702d52": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"83d6be00be514667a08472ba48a7d42f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"6d2432a98bf84a77a215f03e69a8a79f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9ecd6788ffb04d37ae547f518246fc6b",
"IPY_MODEL_da31333b1c844b56830916b794ce9833",
"IPY_MODEL_a90c9881a7d0420cb7e230f2b19f25f4"
],
"layout": "IPY_MODEL_7fc0552e04b247f5a8b5c8dc7ebfbfe1"
}
},
"9ecd6788ffb04d37ae547f518246fc6b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_5d21f67dc7aa4297b7115299c2589da7",
"placeholder": "​",
"style": "IPY_MODEL_5562aa475ee846a89442b793da39c802",
"value": "Downloading spiece.model: 100%"
}
},
"da31333b1c844b56830916b794ce9833": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_46fd98456a604139a948a97056fb6842",
"max": 791656,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_51bebf1ec8fe4e798bbba4bbf158bd6d",
"value": 791656
}
},
"a90c9881a7d0420cb7e230f2b19f25f4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2f2b07d552364493965c7811042d73c5",
"placeholder": "​",
"style": "IPY_MODEL_fab5d066dc794d1dbb17b282081b1d92",
"value": " 792k/792k [00:00&lt;00:00, 3.68MB/s]"
}
},
"7fc0552e04b247f5a8b5c8dc7ebfbfe1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5d21f67dc7aa4297b7115299c2589da7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"5562aa475ee846a89442b793da39c802": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"46fd98456a604139a948a97056fb6842": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"51bebf1ec8fe4e798bbba4bbf158bd6d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"2f2b07d552364493965c7811042d73c5": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"fab5d066dc794d1dbb17b282081b1d92": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"d20b4ddee92c4d22a81bd819e8314fa3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_556d33b982cc4989a03fc4b29bc7e84a",
"IPY_MODEL_03675f31e76a4135a2ae239f3ee42655",
"IPY_MODEL_8231b974ae924b68959d606e073d66ad"
],
"layout": "IPY_MODEL_4bfd7fde14514f41bea4edcb878cfc3c"
}
},
"556d33b982cc4989a03fc4b29bc7e84a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_d439f8d48f754ed1976392cadfaac95f",
"placeholder": "​",
"style": "IPY_MODEL_86b9a19a4ecf4e52974e1dc84335aca2",
"value": "Downloading (…)/main/tokenizer.json: 100%"
}
},
"03675f31e76a4135a2ae239f3ee42655": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_1b6c1e021fa64948ae533a11f029bd50",
"max": 1389353,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_71b903a4156149bbb714e1f3841f13b2",
"value": 1389353
}
},
"8231b974ae924b68959d606e073d66ad": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_a011aa10bd374418b030f7cd2f602255",
"placeholder": "​",
"style": "IPY_MODEL_e52faa2810604688844ddf624b0a1400",
"value": " 1.39M/1.39M [00:00&lt;00:00, 15.9MB/s]"
}
},
"4bfd7fde14514f41bea4edcb878cfc3c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"d439f8d48f754ed1976392cadfaac95f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"86b9a19a4ecf4e52974e1dc84335aca2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"1b6c1e021fa64948ae533a11f029bd50": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"71b903a4156149bbb714e1f3841f13b2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"a011aa10bd374418b030f7cd2f602255": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"e52faa2810604688844ddf624b0a1400": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
from collections import defaultdict
from copy import deepcopy
from functools import partial, reduce
import numpy as np
import warnings
import jax
import jax.numpy as jnp
from jax._src.core import Literal
from jax.util import safe_map
from tqdm import tqdm
from .gptq import gptq, pack_matrix, QuantizedMatrix


def tree_size_bytes(tree):
    return jax.tree_util.tree_reduce(
        lambda x, y: x + y,
        jax.tree_util.tree_map(
            lambda x: x.size * x.itemsize,
            tree
        ),
        0
    )


def quantize(
    fn,
    params,
    inputs,
    block_size=128,
    actorder=False,
    damping=0.01,
    use_quantized_activations=True,
    use_fp64=False,
    use_params_fp32=False
):
    """
    Run the GPT-Q algorithm on a function to produce quantized versions of its parameters
    Arguments:
        fn: The function to be transformed. It should take two arguments:
            1. A pytree of parameters to be quantized. This corresponds to the `params` pytree from libraries such as Flax/Haiku
            2. A pytree of other arguments. If the original model takes more than one extra argument, you can write a wrapper which takes a tuple as the second argument. TODO: handle varargs
        params: The params pytree. Buffers in this tree may be freed to save memory, so do not re-use it after calling this function.
        inputs: A list of batches of inputs. If your model needs to be vmapped to handle batches, do that before calling quantize.
    """
    with jax.disable_jit():
        jaxpr_args = (params, inputs[0])
        if use_params_fp32:
            jaxpr_args = jax.tree_util.tree_map(
                lambda x: jax.ShapeDtypeStruct(x.shape, jnp.float32) if x.dtype.kind == 'f' else x,
                jaxpr_args
            )
        closed_jaxpr = jax.make_jaxpr(fn)(*jaxpr_args)

    params = jax.device_put(params, jax.devices('cpu')[0])
    inputs = jax.device_put(inputs, jax.devices('cpu')[0])

    argnums = set()
    param_args, param_struct = jax.tree_util.tree_flatten(params)

    input_args = [jax.tree_util.tree_leaves(inp) for inp in inputs]
    input_args = [list(arg) for arg in zip(*input_args)]

    argnums = set(range(0, len(param_args)))

    result = _eval_and_quantize(
        closed_jaxpr.jaxpr,
        closed_jaxpr.literals,
        argnums,
        *param_args,
        *input_args,
        block_size=block_size,
        actorder=actorder,
        damping=damping,
        use_quantized_activations=use_quantized_activations,
        use_fp64=use_fp64,
        use_params_fp32=use_params_fp32
    )
    for ind, quantized_param in result.items():
        param_args[ind] = quantized_param

    return jax.tree_util.tree_unflatten(param_struct, param_args)


def _get_delete_points(jaxpr):
    deps = defaultdict(set)
    for i, eqn in enumerate(jaxpr.eqns):
        for var in set(v for v in eqn.invars if not isinstance(v, Literal)):
            deps[var].add(i)

    deps = dict(deps)
    delete_vars = []
    for i, eqn in enumerate(jaxpr.eqns):
        eqn_delete = []
        for var in set(v for v in eqn.invars if not isinstance(v, Literal)):
            deps[var].remove(i)
            if not deps[var]:
                eqn_delete.append(var)
                del deps[var]
        delete_vars.append(eqn_delete)
    return delete_vars


def _maybe_delete(val):
    if not val.is_deleted():
        val.delete()


def _eval_and_quantize(
    jaxpr,
    consts,
    argnums,
    *args,
    block_size=128,
    actorder=False,
    damping=0.01,
    use_quantized_activations=True,
    use_fp64=False,
    use_params_fp32=False
):
    # Compute on the first default-backend device (TPU or GPU);
    # offload large results and quantized weights to host CPU memory.
    device = jax.devices()[0]
    cpu = jax.devices('cpu')[0]

    # Args are all either params or lists of tensors
    quantized_results = {}
    name_to_pos = {}

    n_batches = len(next(a for i, a in enumerate(args) if i not in argnums))

    # Everything in here should be on the accelerator
    envs = [{} for _ in range(n_batches)]

    # Map from var name to a tuple of value, original_name, and a stack of transformations to map it back to orig param shape
    param_env = {}

    for index, name in enumerate(jaxpr.invars):
        if index in argnums:
            param_env[name] = (args[index], name, ())
            name_to_pos[name] = index
        else:
            for i in range(n_batches):
                envs[i][name] = args[index][i]

    def delete(name):
        if name not in envs[0]:
            return
        for env in envs:
            env[name].delete()
            del env[name]

    delete_points = _get_delete_points(jaxpr)

    const_env = {name: val for name, val in zip(jaxpr.constvars, consts)}
    pos = 0
    bar = tqdm(desc='Quantizing')
    while True:
        bar.update(1)
        next_pos, needed_names, matmul_handler, updated_param_env = update_params_to_next_matmul(
            eqns=jaxpr.eqns,
            start_pos=pos,
            delete_points=delete_points,
            param_env=param_env,
            env=envs[0]
        )
        if next_pos is None:
            break

        block_param_env = {
            name: jax.device_put(param_env[name][0], device)
            for name in needed_names if name in param_env
        }
        if use_params_fp32:
            for k, v in block_param_env.items():
                if v.dtype.kind == 'f':
                    block_param_env[k] = v.astype(jnp.float32)

        print(f'Current env size: {tree_size_bytes(envs):.2e} bytes')
        print(f'Current param env size: {tree_size_bytes(block_param_env):.2e} bytes')

        delete_keys = set(var for i in range(pos, next_pos) for var in delete_points[i])

        segment_eqns = jaxpr.eqns[pos:next_pos]

        # If a parameter has been transformed keep it in the param env instead of the individual envs
        drop_env_keys = set(k for k in updated_param_env if k not in param_env)
        missing_keys = set(k for k in param_env if k not in updated_param_env)

        block_fn = jax.jit(partial(run_segment, segment_eqns, pos, delete_points, drop_env_keys))
        for i, env in enumerate(envs):
            #gpu_env = jax.device_put(env, gpu)
            device_env = jax.device_put(env, device)
            new_env = block_fn(block_param_env, device_env, const_env)
            envs[i] = new_env
            #envs[i] = jax.device_put(new_env, cpu)
            #jax.tree_map(_maybe_delete, (gpu_env, new_env))

        for param in block_param_env.values():
            param.delete()
        del block_param_env

        param_env = updated_param_env

        #(jax.device_put(0., gpu) + 0).block_until_ready()

        matmul_eqn = jaxpr.eqns[next_pos]

        all_args = []
        if sum(argname in param_env for argname in matmul_eqn.invars) > 1:
            raise NotImplementedError('Currently only one quantize target is supported per op')

        quantize_argname = next(argname for argname in matmul_eqn.invars if argname in param_env)

        for argname in matmul_eqn.invars:
            if argname in param_env:
                all_args.append(param_env[argname][0])
            else:
                all_args.append([env[argname] for env in envs])
        all_args = [jax.device_put(arg, device) for arg in all_args]

        handler_coro = matmul_handler(all_args)

        w, xs = next(handler_coro)

        quantized_w, quantize_params = gptq(
            W=w,
            xs=xs,
            block_size=block_size,
            actorder=actorder,
            damping=damping,
            use_fp64=use_fp64
        )
        assert quantized_w.shape == w.shape

        try:
            handler_coro.send((quantized_w, quantize_params['scale'], quantize_params['zero']))
            assert False, 'Handler should have stopped'
        except StopIteration as e:
            quantized_w, quantize_params['scale'], quantize_params['zero'], contraction_axis = e.value

        outvars = jaxpr.eqns[next_pos].outvars
        delete_indices = [i for i, name in enumerate(matmul_eqn.invars) if name != quantize_argname]

        do_eval = jax.jit(partial(eval_eqn, matmul_eqn))

        matmul_w_arg = quantized_w if use_quantized_activations else param_env[quantize_argname][0]
        if use_params_fp32:
            matmul_w_arg = matmul_w_arg.astype(jnp.float32)
        matmul_w_arg = jax.device_put(matmul_w_arg, device)

        for env in envs:
            gpu_args = [
                matmul_w_arg
                if argname == quantize_argname else
                env[argname]
                for argname in matmul_eqn.invars
            ]
            gpu_args = jax.device_put(gpu_args, device)
            results = do_eval(*gpu_args)

            if tree_size_bytes(results) > 1e8:
                # This should offload stuff like the final logits to the CPU
                cpu_results = jax.device_put(results, cpu)
                jax.tree_util.tree_map(lambda x: x.is_deleted() or x.delete(), results)
                results = cpu_results

            if matmul_eqn.primitive.multiple_results:
                for outvar, value in zip(outvars, results):
                    env[outvar] = value
            else:
                env[outvars[0]] = results

            for name in delete_points[next_pos]:
                if name in env:
                    _maybe_delete(env[name])
                    del env[name]

        #for i in delete_indices:
        #    gpu_args[i].device_buffer.delete()

        #(jax.device_put(0., gpu) + 0).block_until_ready()

        #for name in delete_points[next_pos]:
        #    delete(name)

        # TODO: Instead of catching duplicate quantizations here avoid doing the calculation in the first place
        orig_w, orig_name, inv_transforms = param_env[quantize_argname]
        write_arg = name_to_pos[orig_name]
        if write_arg not in quantized_results:
            packed_result = pack_matrix(quantized_w, quantize_params, contraction_axis)
            # Inverse transforms were recorded in forward order, so undo them last-to-first
            un_transformed = reduce(lambda x, f: f(x), reversed(inv_transforms), packed_result)
            quantized_results[write_arg] = jax.device_put(un_transformed, cpu)

        if quantize_argname not in delete_points[next_pos]:
            cpu_quantized_w = jax.device_put(quantized_w, cpu)
            param_env[quantize_argname] = cpu_quantized_w, orig_name, inv_transforms
            orig_w.delete()
        elif quantize_argname in delete_points[next_pos]:
            orig_w.delete()
            del param_env[quantize_argname]

        quantized_w.delete()

        #(jax.device_put(0., gpu) + 0).block_until_ready()
        pos = next_pos + 1
    return quantized_results


def update_params_to_next_matmul(eqns, start_pos, delete_points, param_env, env):
    new_param_env = {k: v for k, v in param_env.items()}
    env_shapes = {k: jax.ShapeDtypeStruct(v.shape, v.dtype) for k, v in env.items()}

    needed_names = set()
    for i, eqn in enumerate(eqns[start_pos:], start_pos):
        invars = eqn.invars
        op_name = eqn.primitive.name
        if op_name in PARAM_TRANSFORMS:
            arg, = invars
            needed_names.add(arg)
            if arg in new_param_env and len(new_param_env[arg][0].shape) > 1:
                val, orig_name, transforms = new_param_env[arg]
                new_transform = PARAM_TRANSFORMS[op_name](eqn, val)
                new_name, = eqn.outvars
                new_val = eval_eqn(eqn, val)
                new_param_env[new_name] = new_val, orig_name, (transforms + (new_transform,))
                if arg in delete_points[i]: #TODO: Become certain that making this just a soft check was fine
                    del new_param_env[arg]
                else:
                    warnings.warn(f'Transformation `{op_name}` is applied to a target parameter of shape {new_param_env[arg][0].shape} which is later reused. This may lead to this parameter not being quantized, or it being quantized poorly.')
                continue

        arg_shapes = [invar.aval for invar in invars]
        args_are_targets = [(
            False if isinstance(v, Literal) else
            (v in new_param_env and len(new_param_env[v][0].shape) > 1)
        ) for v in invars]

        if any(args_are_targets):
            if op_name == 'pjit':
                warnings.warn('Quantization does not descend into pjit')
            if op_name in PRIMITIVE_TO_MATMUL:
                predicate, handler = PRIMITIVE_TO_MATMUL[op_name]
                if predicate(eqn, args_are_targets, arg_shapes):
                    return i, needed_names, partial(handler, eqn, args_are_targets), new_param_env
            else:
                warnings.warn(f'Operation {eqn.primitive.name} not supported for quantization')

        out_shapes = jax.eval_shape(partial(eval_eqn, eqn), *arg_shapes)
        if not eqn.primitive.multiple_results:
            out_shapes = [out_shapes]
        safe_map(env_shapes.__setitem__, eqn.outvars, out_shapes)

        needed_names.update(v for v in invars if not isinstance(v, Literal))

    return None, needed_names, None, None


def run_segment(eqns, start_pos, delete_points, drop_env_keys, param_env, env, const_env):
    env = dict(env)

    def read(v):
        if isinstance(v, Literal):
            return v.val
        if v in param_env:
            return param_env[v]
        if v in env:
            return env[v]
        return const_env[v]

    def write(v, val):
        env[v] = val

    for i, eqn in enumerate(eqns, start_pos):
        eqn_args = safe_map(read, eqn.invars)
        ans = eval_eqn(eqn, *eqn_args)
        if eqn.primitive.multiple_results:
            safe_map(write, eqn.outvars, ans)
        else:
            write(eqn.outvars[0], ans)

        for varname in delete_points[i]:
            if varname in env:
                del env[varname]

    for key in drop_env_keys:
        env.pop(key, None)

    return env


def dot_general_predicate(eqn, args_are_targets, args):
    (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = eqn.params['dimension_numbers']
    if sum(args_are_targets) > 1:
        warnings.warn('Quantizing two parameters which are multiplied together is not supported')
        return False

    if lhs_batch or rhs_batch:
        warnings.warn('Quantizing batched matmuls is not supported')
        return False

    if len(lhs_contract) > 1 or len(rhs_contract) > 1:
        warnings.warn('Quantizing dots with more than one contraction is not supported')
        return False

    return True


@partial(jax.jit, static_argnums=(1, 2))
def permute_to_matrix(w, permutation, keep_first):
    w = jnp.transpose(w, permutation)
    out_shape = (w.shape[0], -1) if keep_first else (-1, w.shape[-1])
    w = jnp.reshape(w, out_shape)
    return w


@partial(jax.jit, static_argnums=(1, 2))
def to_original_shape(w, shape, restore_permutation):
    return jnp.transpose(
        jnp.reshape(w, shape),
        restore_permutation
    )


def handle_dot_general(eqn, args_are_targets, args):
    lhs, rhs = args
    ((lhs_contract,), (rhs_contract,)), _ = eqn.params['dimension_numbers']
    if args_are_targets[0]:
        w, xs = lhs, rhs
        w_contract, x_contract = lhs_contract, rhs_contract
    else:
        w, xs = rhs, lhs
        w_contract, x_contract = rhs_contract, lhs_contract

    orig_w_shape = w.shape
    w_permutation = None
    if w_contract != 0 or len(w.shape) > 2:
        w_permutation = tuple([w_contract, *(i for i in range(len(w.shape)) if i != w_contract)])
        w = permute_to_matrix(w, w_permutation, True)

    assert isinstance(xs, list)

    x_permutation = None
    if x_contract != len(xs[0].shape) - 1:
        x_permutation = tuple([*(i for i in range(len(xs[0].shape)) if i != x_contract), x_contract])

    prepared_xs = []
    for x in xs:
        if x_permutation is not None:
            x = permute_to_matrix(x, x_permutation, False)
        prepared_xs.append(x)

    quantized_w, scales, zeros = yield w, prepared_xs

    if w_permutation:
        unpermute = tuple(np.argsort(w_permutation))
        shape = tuple(orig_w_shape[i] for i in w_permutation)
        quantized_w = to_original_shape(quantized_w, shape, unpermute)

    scale_shape = tuple(d for i, d in enumerate(orig_w_shape) if i != w_contract)
    scales = jnp.reshape(scales, scale_shape)
    zeros = jnp.reshape(zeros, scale_shape)

    return quantized_w, scales, zeros, int(w_contract)


def conv_predicate(eqn, args_are_targets, args):
    inp_is_target, kernel_is_target = args_are_targets
    if inp_is_target:
        warnings.warn('Only quantizing the kernel of a conv is supported, not the input')
    if not kernel_is_target:
        return False

    params = eqn.params
    if any(val != 1 for val in params['window_strides']):
        warnings.warn('Currently only quantizing convs with stride 1 is supported')
        return False

    if any(val != 1 for val in params['rhs_dilation']):
        warnings.warn('Currently only quantizing convs with dilation 1 is supported')
        return False

    if params['feature_group_count'] != 1:
        warnings.warn('Currently only quantizing convs with feature group count 1 is supported')
        return False

    if params['batch_group_count'] != 1:
        warnings.warn('Currently only quantizing convs with batch group count 1 is supported')
        return False

    # Each is: Batch, feature, spatial...
    kernel_spatial_dims = params['dimension_numbers'][1][2:]

    kernel_shape = args[1].shape
    for spatial_dim in kernel_spatial_dims:
        if kernel_shape[spatial_dim] != 1:
            warnings.warn('Currently only quantizing convs with 1x..x1 kernels is supported')
            return False

    return True


def handle_conv(eqn, args_are_targets, args):
    inps, kernel = args
    inp_shape = inps[0].shape
    kernel_shape = kernel.shape

    # Input spec is (batch, feature, *spatial); kernel spec is (out, in, *spatial)
    (inp_batch_dim, inp_feature_dim, *inp_spatial_dims), (kernel_out_dim, kernel_in_dim, *kernel_spatial_dims), _ = eqn.params['dimension_numbers']

    flat_kernel = jnp.squeeze(kernel, tuple(kernel_spatial_dims))
    needs_transpose = kernel_out_dim < kernel_in_dim
    if needs_transpose:
        flat_kernel = flat_kernel.T

    inp_permutation = None
    if inp_feature_dim != len(inp_shape) - 1:
        inp_permutation = tuple([*(i for i in range(len(inp_shape)) if i != inp_feature_dim), inp_feature_dim])

    prepared_inps = []
    for inp in inps:
        if inp_permutation is not None:
            inp = permute_to_matrix(inp, inp_permutation, False)
        prepared_inps.append(inp)

    quantized_kernel, scales, zeros = yield flat_kernel, prepared_inps

    if needs_transpose:
        quantized_kernel = quantized_kernel.T

    for dim in sorted(kernel_spatial_dims):
        quantized_kernel = jnp.expand_dims(quantized_kernel, dim)
        scale_dim = dim if dim < inp_feature_dim else dim - 1
        scales = jnp.expand_dims(scales, scale_dim)
        zeros = jnp.expand_dims(zeros, scale_dim)

    return quantized_kernel, scales, zeros, kernel_in_dim


def eval_eqn(eqn, *args):
    subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)
    ans = eqn.primitive.bind(*subfuns, *args, **bind_params)
    return ans


PRIMITIVE_TO_MATMUL = {
    'dot_general': (dot_general_predicate, handle_dot_general),
    'conv_general_dilated': (conv_predicate, handle_conv)
}


def inverse_transpose(eqn, arg):
    permutation = eqn.params['permutation']
    unpermute = tuple(np.argsort(permutation))

    def inverse(quantized_matrix):
        prev_contract_axis = quantized_matrix.contraction_axis
        # Axis k of the transposed weight is axis permutation[k] of the original weight
        new_contraction_axis = int(permutation[prev_contract_axis])
        new_int_weight = jax.lax.transpose(quantized_matrix.int_weight, permutation=unpermute)

        # Scales and zeros lack the contraction axis: drop it from the inverse
        # permutation and shift the axis indices after it down by one
        unpermute_scale = [
            i if i < prev_contract_axis else i - 1
            for i in unpermute
            if i != prev_contract_axis
        ]
        new_scale = jax.lax.transpose(quantized_matrix.scale, permutation=unpermute_scale)
        new_zero = jax.lax.transpose(quantized_matrix.zero, permutation=unpermute_scale)
        return QuantizedMatrix(
            int_weight=new_int_weight,
            scale=new_scale,
            zero=new_zero,
            contraction_axis=new_contraction_axis
        )
    return inverse


def inverse_convert_type(eqn, arg):
    return lambda x: x


PARAM_TRANSFORMS = {
    'transpose': inverse_transpose,
    'convert_element_type': inverse_convert_type,
}
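The `quantize` docstring asks for a function of exactly two arguments, `(params, inputs)`, plus a Python list of calibration batches. Here is a minimal sketch of that calling convention. It assumes a hypothetical `toy_apply` model with two extra arguments (standing in for something like a HuggingFace apply function taking `input_ids` and `attention_mask`) and uses NumPy so it runs without JAX. Each "batch" below is a single sequence; a model that only handles one example at a time would be wrapped with `jax.vmap` first, as the docstring says.

```python
import numpy as np


def toy_apply(params, input_ids, attention_mask):
    """Hypothetical model with two non-param arguments."""
    emb = params["embed"][input_ids]          # (seq, d_model) embedding lookup
    emb = emb * attention_mask[:, None]       # zero out padded positions
    return emb @ params["proj"]               # (seq, vocab) logits


def fn(params, batch):
    """Adapter with the (params, inputs) signature that `quantize` expects."""
    input_ids, attention_mask = batch         # the extra args travel as one tuple
    return toy_apply(params, input_ids, attention_mask)


rng = np.random.default_rng(0)
params = {
    "embed": rng.standard_normal((16, 8)).astype(np.float32),
    "proj": rng.standard_normal((8, 16)).astype(np.float32),
}
# `inputs` is a plain list; each element is one calibration batch in the form `fn` takes.
inputs = [
    (rng.integers(0, 16, size=5), np.ones(5, dtype=np.float32))
    for _ in range(3)
]
out = fn(params, inputs[0])
print(out.shape)  # (5, 16)

# With this module importable, the call would look like (not executed here):
# quantized_params = quantize(fn, params, inputs)
```

The tuple-packing wrapper is exactly the workaround the docstring's "TODO: handle varargs" note refers to.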
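`_get_delete_points` is a liveness pass. It first records every equation that reads each variable, then walks forward and reports, per equation, the variables whose last read happens there, so their buffers can be freed right after. The sketch below replays the same two-pass idea on a hypothetical `ToyEqn` tuple instead of a real jaxpr (so it runs without JAX). It has no literals to skip, and it sorts names only to make the output deterministic.

```python
from collections import defaultdict
from typing import List, NamedTuple, Tuple


class ToyEqn(NamedTuple):
    """Hypothetical stand-in for a jaxpr equation."""
    invars: Tuple[str, ...]
    outvars: Tuple[str, ...]


def last_use_points(eqns: List[ToyEqn]) -> List[List[str]]:
    """For each equation, list the variables read for the final time there."""
    # Pass 1: every equation index that reads each variable.
    readers = defaultdict(set)
    for i, eqn in enumerate(eqns):
        for var in set(eqn.invars):
            readers[var].add(i)

    # Pass 2: consume reads in order; a variable is dead once none remain.
    points = []
    for i, eqn in enumerate(eqns):
        dead = []
        for var in sorted(set(eqn.invars)):
            readers[var].discard(i)
            if not readers[var]:
                dead.append(var)
        points.append(dead)
    return points


# y = x @ w;  z = y + b;  out = z * x
eqns = [
    ToyEqn(("x", "w"), ("y",)),
    ToyEqn(("y", "b"), ("z",)),
    ToyEqn(("z", "x"), ("out",)),
]
print(last_use_points(eqns))  # [['w'], ['b', 'y'], ['x', 'z']]
```

Note that `x` survives the first equation because the third one still reads it; that is the situation where the quantizer must keep a buffer alive across segments.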
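`handle_dot_general` and `handle_conv` are generators. They `yield` the matrix view of the weight and the prepared activations, receive the quantized weight through `.send(...)`, and hand back their post-processed result as the generator's return value, which the caller reads from `StopIteration.value`. A minimal pure-Python sketch of that handshake, with a hypothetical `toy_quantize` (round to a fixed step) standing in for `gptq`:

```python
def toy_handler(w, xs):
    """Hypothetical handler: hand data out, get the quantized weight back, return a tuple."""
    prepared = [x * 1.0 for x in xs]           # stand-in for reshaping the activations
    quantized_w, scale = yield w, prepared     # pause until the driver calls .send(...)
    return quantized_w, scale, "contraction_axis=0"


def toy_quantize(w, step):
    """Hypothetical quantizer: round every entry to a multiple of `step`."""
    return [round(v / step) * step for v in w], step


def drive(handler):
    """Driver: next() -> quantize -> send() -> read StopIteration.value."""
    w, xs = next(handler)
    q_w, scale = toy_quantize(w, step=0.5)
    try:
        handler.send((q_w, scale))
    except StopIteration as stop:
        return stop.value
    raise RuntimeError("handler should have stopped after one send()")


print(drive(toy_handler([0.2, 0.9, 1.6], xs=[1.0, 2.0])))
# ([0.0, 1.0, 1.5], 0.5, 'contraction_axis=0')
```

Using a generator lets each handler keep its local layout bookkeeping (permutations, original shapes) alive across the quantization call without a separate "undo" function.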
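For weights with more than two axes, `handle_dot_general` moves the contraction axis to the front and flattens the remaining axes (`permute_to_matrix(..., keep_first=True)`), then restores the original layout with `to_original_shape`, reshaping to the permuted shape and applying `np.argsort` of the permutation. A NumPy sketch of that round trip; the helper names `to_matrix` and `from_matrix` are hypothetical:

```python
import numpy as np


def to_matrix(w, contract_axis):
    """Contraction axis first, all other axes flattened into columns."""
    perm = (contract_axis, *(i for i in range(w.ndim) if i != contract_axis))
    mat = np.transpose(w, perm).reshape(w.shape[contract_axis], -1)
    return mat, perm


def from_matrix(mat, orig_shape, perm):
    """Undo `to_matrix`: reshape to the permuted shape, then apply the inverse permutation."""
    permuted_shape = tuple(orig_shape[i] for i in perm)
    unpermute = tuple(np.argsort(perm))
    return np.transpose(mat.reshape(permuted_shape), unpermute)


w = np.arange(2 * 3 * 4).reshape(2, 3, 4)    # a 3-axis kernel
mat, perm = to_matrix(w, contract_axis=1)
print(perm, mat.shape)                        # (1, 0, 2) (3, 8)
restored = from_matrix(mat, w.shape, perm)
print(np.array_equal(restored, w))            # True
```

Because only the contraction axis is kept as rows, each column of the matrix view corresponds to one entry of the scale and zero arrays, which is why those are later reshaped to the weight shape with the contraction axis dropped.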
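`conv_predicate` only accepts convolutions with stride 1, dilation 1, group counts of 1, and a 1x..x1 spatial kernel. Under those conditions a convolution just mixes channels independently at every pixel, so it is a matmul. That is why `handle_conv` can squeeze the spatial axes out of the kernel and move the input's feature axis last. A NumPy check of that equivalence, assuming NCHW inputs and an OIHW kernel (the layouts are an assumption for the example; `conv2d_nchw` is a hypothetical naive reference):

```python
import numpy as np


def conv2d_nchw(x, k):
    """Naive 'VALID', stride-1 cross-correlation for NCHW input and OIHW kernel."""
    n, c, h, w = x.shape
    o, _, kh, kw = k.shape
    out = np.zeros((n, o, h - kh + 1, w - kw + 1))
    for i in range(h - kh + 1):
        for j in range(w - kw + 1):
            patch = x[:, :, i:i + kh, j:j + kw]              # (n, c, kh, kw)
            out[:, :, i, j] = np.einsum("nchw,ochw->no", patch, k)
    return out


rng = np.random.default_rng(0)
batch, c_in, c_out, h, w = 2, 4, 3, 5, 5
x = rng.standard_normal((batch, c_in, h, w))
kernel = rng.standard_normal((c_out, c_in, 1, 1))           # 1x1 window
conv_out = conv2d_nchw(x, kernel)

# Same result as a matmul: one row per pixel (feature axis last),
# squeezed kernel transposed so the contraction (input-channel) axis comes first.
pixels = np.moveaxis(x, 1, -1).reshape(-1, c_in)             # (N*H*W, C_in)
flat_kernel = kernel[:, :, 0, 0].T                           # (C_in, C_out)
matmul_out = (pixels @ flat_kernel).reshape(batch, h, w, c_out)
print(np.allclose(np.moveaxis(matmul_out, -1, 1), conv_out))  # True
```

With a larger window or a stride, each output pixel would depend on several input pixels, and the matrix view of the kernel would no longer match the data GPTQ sees.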
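When a target parameter is transposed before its matmul, the quantized result has to be mapped back to the parameter's original layout, which is what the `transpose` entry in `PARAM_TRANSFORMS` handles. A NumPy sketch of the axis bookkeeping involved, using a 3-axis permutation that is not its own inverse so the two directions can be told apart. Axis `a` of the transposed array is axis `perm[a]` of the original, and the per-column scales (a toy `std` here, standing in for GPTQ scales) are restored by dropping `a` from the inverse permutation and shifting the later axis indices down by one:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((2, 3, 4))       # original parameter
perm = (1, 2, 0)                          # forward transpose: t = w.transpose(perm)
t = np.transpose(w, perm)                 # shape (3, 4, 2)
a = 0                                     # contraction axis, in t's coordinates

unpermute = tuple(np.argsort(perm))       # inverse permutation
assert np.array_equal(np.transpose(t, unpermute), w)

# Axis `a` of t is axis perm[a] of w (sizes agree: t.shape[a] == w.shape[perm[a]]).
w_axis = perm[a]
print(t.shape[a], w.shape[w_axis])        # 3 3

# Scales cover every axis except the contraction axis.
t_scale = t.std(axis=a)                   # shape (4, 2)
scale_perm = [i if i < a else i - 1 for i in unpermute if i != a]
w_scale = np.transpose(t_scale, scale_perm)
print(np.allclose(w_scale, w.std(axis=w_axis)))  # True
```

For a plain 2-D transpose, `perm` and `np.argsort(perm)` coincide, so the distinction only shows up with three or more axes.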