Skip to content

Instantly share code, notes, and snippets.

@texasdave2
Created September 18, 2023 20:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save texasdave2/ab3720476c4b68cc6ce081bc137bc1fb to your computer and use it in GitHub Desktop.
Save texasdave2/ab3720476c4b68cc6ce081bc137bc1fb to your computer and use it in GitHub Desktop.
simple_nlp_example.ipynb
{
"cells": [
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"## currently gives error:\n",
"\n",
"##ValueError: To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized inside your training function. Restart your notebook and make sure no cells initializes an `Accelerator`.\n",
"\n",
"##raised issue in github \n",
"\n",
"## https://github.com/huggingface/accelerate/issues/1985#issuecomment-1723770308\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PdpySQpViE_U"
},
"source": [
"Before we can browse the rest of the notebook, we need to install the dependencies: this example uses `datasets` and `transformers`. To use TPUs on colab, we need to install `torch_xla` and the last line install `accelerate` from source since we the features we are using are very recent and not released yet."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: huggingface-hub in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (0.17.2)\n",
"Requirement already satisfied: filelock in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (3.9.0)\n",
"Requirement already satisfied: fsspec in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (2023.6.0)\n",
"Requirement already satisfied: requests in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (4.66.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (6.0.1)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (4.7.1)\n",
"Requirement already satisfied: packaging>=20.9 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub) (23.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub) (2.0.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub) (2023.7.22)\n",
"Token is valid (permission: write).\n",
"Your token has been saved in your configured git credential helpers (store).\n",
"Your token has been saved to /home/demouser/.cache/huggingface/token\n",
"Login successful\n"
]
}
],
"source": [
"## code to auto login to hugging face, avoid the login prompt\n",
"\n",
"!pip install -U huggingface-hub\n",
"\n",
"# get your account token from https://huggingface.co/settings/tokens\n",
"token = 'XXXXXXXXXXXXXXXXXXXXXXX'\n",
"\n",
"from huggingface_hub import login\n",
"login(token=token, add_to_git_credential=True)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mon Sep 18 15:41:04 2023 \n",
"+---------------------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 535.86.05 Driver Version: 535.86.05 CUDA Version: 12.2 |\n",
"|-----------------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|=========================================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:0B:00.0 Off | Off |\n",
"| N/A 28C P8 9W / 70W | 2MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 1 Tesla T4 Off | 00000000:14:00.0 Off | Off |\n",
"| N/A 30C P8 9W / 70W | 2MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
"| 2 Tesla T4 Off | 00000000:1D:00.0 Off | Off |\n",
"| N/A 32C P8 9W / 70W | 2MiB / 16384MiB | 0% Default |\n",
"| | | N/A |\n",
"+-----------------------------------------+----------------------+----------------------+\n",
" \n",
"+---------------------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=======================================================================================|\n",
"| No running processes found |\n",
"+---------------------------------------------------------------------------------------+\n"
]
}
],
"source": [
"! nvidia-smi"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cJrsKX5AsnHp",
"outputId": "93ea91d5-7c1b-4531-c75a-5c1f4cb4b43d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: datasets in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (2.14.5)\n",
"Requirement already satisfied: transformers in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (4.34.0.dev0)\n",
"Requirement already satisfied: evaluate in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (0.4.0)\n",
"Requirement already satisfied: torch==2.0.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (2.0.0)\n",
"Requirement already satisfied: filelock in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (3.9.0)\n",
"Requirement already satisfied: typing-extensions in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (4.7.1)\n",
"Requirement already satisfied: sympy in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (1.11.1)\n",
"Requirement already satisfied: networkx in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (3.1)\n",
"Requirement already satisfied: jinja2 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (3.1.2)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.7.99)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.7.99)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.7.101)\n",
"Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (8.5.0.96)\n",
"Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.10.3.66)\n",
"Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (10.9.0.58)\n",
"Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (10.2.10.91)\n",
"Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.4.0.1)\n",
"Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.7.4.91)\n",
"Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (2.14.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (11.7.91)\n",
"Requirement already satisfied: triton==2.0.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch==2.0.0) (2.0.0)\n",
"Requirement already satisfied: setuptools in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.0) (68.0.0)\n",
"Requirement already satisfied: wheel in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch==2.0.0) (0.38.4)\n",
"Requirement already satisfied: cmake in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from triton==2.0.0->torch==2.0.0) (3.27.4.1)\n",
"Requirement already satisfied: lit in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from triton==2.0.0->torch==2.0.0) (16.0.6)\n",
"Requirement already satisfied: numpy>=1.17 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (1.25.2)\n",
"Requirement already satisfied: pyarrow>=8.0.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (13.0.0)\n",
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (0.3.6)\n",
"Requirement already satisfied: pandas in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (2.1.0)\n",
"Requirement already satisfied: requests>=2.19.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.62.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (4.66.1)\n",
"Requirement already satisfied: xxhash in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (3.3.0)\n",
"Requirement already satisfied: multiprocess in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (0.70.14)\n",
"Requirement already satisfied: fsspec[http]<2023.9.0,>=2023.1.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (2023.6.0)\n",
"Requirement already satisfied: aiohttp in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (3.8.5)\n",
"Requirement already satisfied: huggingface-hub<1.0.0,>=0.14.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (0.17.2)\n",
"Requirement already satisfied: packaging in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from datasets) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from transformers) (2023.8.8)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from transformers) (0.3.3)\n",
"Requirement already satisfied: responses<0.19 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from evaluate) (0.18.0)\n",
"Requirement already satisfied: attrs>=17.3.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (23.1.0)\n",
"Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (2.0.4)\n",
"Requirement already satisfied: multidict<7.0,>=4.5 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (6.0.4)\n",
"Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (4.0.3)\n",
"Requirement already satisfied: yarl<2.0,>=1.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (1.9.2)\n",
"Requirement already satisfied: frozenlist>=1.1.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (1.4.0)\n",
"Requirement already satisfied: aiosignal>=1.1.2 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from aiohttp->datasets) (1.3.1)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests>=2.19.0->datasets) (2023.7.22)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from jinja2->torch==2.0.0) (2.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from pandas->datasets) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from pandas->datasets) (2023.3.post1)\n",
"Requirement already satisfied: tzdata>=2022.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from pandas->datasets) (2023.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from sympy->torch==2.0.0) (1.3.0)\n",
"Requirement already satisfied: six>=1.5 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n",
"Collecting git+https://github.com/huggingface/accelerate\n",
" Cloning https://github.com/huggingface/accelerate to /tmp/pip-req-build-33y4srd1\n",
" Running command git clone --filter=blob:none --quiet https://github.com/huggingface/accelerate /tmp/pip-req-build-33y4srd1\n",
" Resolved https://github.com/huggingface/accelerate to commit 629d02c8446354860c9bdf58b6bc006186cbc818\n",
" Installing build dependencies ... \u001b[?25ldone\n",
"\u001b[?25h Getting requirements to build wheel ... \u001b[?25ldone\n",
"\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.17 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from accelerate==0.24.0.dev0) (1.25.2)\n",
"Requirement already satisfied: packaging>=20.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from accelerate==0.24.0.dev0) (23.1)\n",
"Requirement already satisfied: psutil in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from accelerate==0.24.0.dev0) (5.9.5)\n",
"Requirement already satisfied: pyyaml in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from accelerate==0.24.0.dev0) (6.0.1)\n",
"Requirement already satisfied: torch>=1.10.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from accelerate==0.24.0.dev0) (2.0.0)\n",
"Requirement already satisfied: huggingface-hub in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from accelerate==0.24.0.dev0) (0.17.2)\n",
"Requirement already satisfied: filelock in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (3.9.0)\n",
"Requirement already satisfied: typing-extensions in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (4.7.1)\n",
"Requirement already satisfied: sympy in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (1.11.1)\n",
"Requirement already satisfied: networkx in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (3.1)\n",
"Requirement already satisfied: jinja2 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (3.1.2)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.7.99)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.7.99)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.7.101)\n",
"Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (8.5.0.96)\n",
"Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.10.3.66)\n",
"Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (10.9.0.58)\n",
"Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (10.2.10.91)\n",
"Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.4.0.1)\n",
"Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.7.4.91)\n",
"Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (2.14.3)\n",
"Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (11.7.91)\n",
"Requirement already satisfied: triton==2.0.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from torch>=1.10.0->accelerate==0.24.0.dev0) (2.0.0)\n",
"Requirement already satisfied: setuptools in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.10.0->accelerate==0.24.0.dev0) (68.0.0)\n",
"Requirement already satisfied: wheel in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.10.0->accelerate==0.24.0.dev0) (0.38.4)\n",
"Requirement already satisfied: cmake in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from triton==2.0.0->torch>=1.10.0->accelerate==0.24.0.dev0) (3.27.4.1)\n",
"Requirement already satisfied: lit in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from triton==2.0.0->torch>=1.10.0->accelerate==0.24.0.dev0) (16.0.6)\n",
"Requirement already satisfied: fsspec in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub->accelerate==0.24.0.dev0) (2023.6.0)\n",
"Requirement already satisfied: requests in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub->accelerate==0.24.0.dev0) (2.31.0)\n",
"Requirement already satisfied: tqdm>=4.42.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from huggingface-hub->accelerate==0.24.0.dev0) (4.66.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from jinja2->torch>=1.10.0->accelerate==0.24.0.dev0) (2.1.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub->accelerate==0.24.0.dev0) (2.0.4)\n",
"Requirement already satisfied: idna<4,>=2.5 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub->accelerate==0.24.0.dev0) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub->accelerate==0.24.0.dev0) (1.26.16)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from requests->huggingface-hub->accelerate==0.24.0.dev0) (2023.7.22)\n",
"Requirement already satisfied: mpmath>=0.19 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from sympy->torch>=1.10.0->accelerate==0.24.0.dev0) (1.3.0)\n",
"Requirement already satisfied: scikit-learn in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (1.3.0)\n",
"Requirement already satisfied: numpy>=1.17.3 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from scikit-learn) (1.25.2)\n",
"Requirement already satisfied: scipy>=1.5.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from scikit-learn) (1.11.2)\n",
"Requirement already satisfied: joblib>=1.1.1 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from scikit-learn) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages (from scikit-learn) (3.2.0)\n"
]
}
],
"source": [
"! pip install datasets transformers evaluate torch==2.0.0\n",
"! pip install git+https://github.com/huggingface/accelerate\n",
"! pip install scikit-learn\n",
"\n",
"## colab specific \n",
"#! pip install cloud-tpu-client==0.10 \n",
"#! pip install https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-2.0-cp310-cp310-linux_x86_64.whl\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R1FESu6IiqOV"
},
"source": [
"Here are all the imports we will need for this notebook."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4rvtD4_mskaW",
"outputId": "4adf87e2-2eea-4845-f3fc-af9b239ea95c"
},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader\n",
"\n",
"from accelerate import Accelerator, DistributedType\n",
"from datasets import load_dataset, load_metric\n",
"from transformers import (\n",
" AdamW,\n",
" AutoModelForSequenceClassification,\n",
" AutoTokenizer,\n",
" get_linear_schedule_with_warmup,\n",
" set_seed,\n",
")\n",
"\n",
"from tqdm.auto import tqdm\n",
"\n",
"import datasets\n",
"import transformers"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Copy-and-paste the text below in your GitHub issue\n",
"\n",
"- `Accelerate` version: 0.24.0.dev0\n",
"- Platform: Linux-5.15.0-83-generic-x86_64-with-glibc2.35\n",
"- Python version: 3.9.17\n",
"- Numpy version: 1.25.2\n",
"- PyTorch version (GPU?): 2.0.0+cu117 (True)\n",
"- PyTorch XPU available: False\n",
"- PyTorch NPU available: False\n",
"- System RAM: 62.79 GB\n",
"- GPU type: Tesla T4\n",
"- `Accelerate` default config:\n",
"\t- compute_environment: LOCAL_MACHINE\n",
"\t- distributed_type: MULTI_GPU\n",
"\t- mixed_precision: no\n",
"\t- use_cpu: False\n",
"\t- debug: False\n",
"\t- num_processes: 3\n",
"\t- machine_rank: 0\n",
"\t- num_machines: 1\n",
"\t- gpu_ids: all\n",
"\t- rdzv_backend: static\n",
"\t- same_network: True\n",
"\t- main_training_function: main\n",
"\t- downcast_bf16: no\n",
"\t- tpu_use_cluster: False\n",
"\t- tpu_use_sudo: False\n",
"\t- tpu_env: []\n"
]
}
],
"source": [
"## version of accelerate\n",
"\n",
"! accelerate env"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3_Daio_mlZtt"
},
"source": [
"This notebook can run with any model checkpoint on the [model hub](https://huggingface.co/models) that has a version with a classification head. Here we select [`bert-base-cased`](https://huggingface.co/bert-base-cased)."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "aUWu0U_plwMD"
},
"outputs": [],
"source": [
"model_checkpoint = \"bert-base-cased\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "AtBUDbklowDi"
},
"source": [
"The next two sections explain how we load and prepare our data for our model, If you are only interested on seeing how 🤗 Accelerate works, feel free to skip them (but make sure to execute all cells!)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e9LqtHw7otJl"
},
"source": [
"## Load the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NKjBrsTGiyf3"
},
"source": [
"To load the dataset, we use the `load_dataset` function from 🤗 Datasets. It will download and cache it (so the download won't happen if we restart the notebook)."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YpUUqTXQiwur",
"outputId": "f53cd1cd-48c6-4f5d-a36c-0e7b2a29d7b7"
},
"outputs": [],
"source": [
"raw_datasets = load_dataset(\"glue\", \"mrpc\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P_j3Tnd6J_4d"
},
"source": [
"The `raw_datasets` object itself is [`DatasetDict`](https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasetdict), which contains one key for the training, validation and test set (with more keys for the mismatched validation and test set in the special case of `mnli`).\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Be6GMC1PKIts",
"outputId": "0a5966cf-7f32-49bd-bfb7-92dde0c529c4"
},
"outputs": [
{
"data": {
"text/plain": [
"DatasetDict({\n",
" train: Dataset({\n",
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
" num_rows: 3668\n",
" })\n",
" validation: Dataset({\n",
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
" num_rows: 408\n",
" })\n",
" test: Dataset({\n",
" features: ['sentence1', 'sentence2', 'label', 'idx'],\n",
" num_rows: 1725\n",
" })\n",
"})"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"raw_datasets"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tPT8uVIrKLxm"
},
"source": [
"To access an actual element, you need to select a split first, then give an index:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xl-dxrpbKNdz",
"outputId": "c3ded223-4687-4d62-919e-2b2c9293f589"
},
"outputs": [
{
"data": {
"text/plain": [
"{'sentence1': 'Amrozi accused his brother , whom he called \" the witness \" , of deliberately distorting his evidence .',\n",
" 'sentence2': 'Referring to him as only \" the witness \" , Amrozi accused his brother of deliberately distorting his evidence .',\n",
" 'label': 1,\n",
" 'idx': 0}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"raw_datasets[\"train\"][0]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2YAh1I2tKPJU"
},
"source": [
"To get a sense of what the data looks like, the following function will show some examples picked randomly in the dataset."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 495
},
"id": "Yklu4bOuKReC",
"outputId": "692d395c-9386-4f56-cc06-a6fb2cf64443"
},
"outputs": [
{
"data": {
"text/html": [
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sentence1</th>\n",
" <th>sentence2</th>\n",
" <th>label</th>\n",
" <th>idx</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>The tech-heavy Nasdaq composite index fell 3.99 , or 0.2 percent , to 1,682.72 , following a two-day win of 55.93 .</td>\n",
" <td>The technology-laced Nasdaq Composite Index .IXIC eased 8.52 points , or 0.51 percent , to 1,670.21 .</td>\n",
" <td>not_equivalent</td>\n",
" <td>1146</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>But he will meet the French President , Jacques Chirac , privately .</td>\n",
" <td>They include French President Jacques Chirac and German Chancellor Gerhard Schroeder .</td>\n",
" <td>not_equivalent</td>\n",
" <td>504</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Feelings about current business conditions improved substantially from the first quarter , jumping from 40 to 55 .</td>\n",
" <td>Assessment of current business conditions improved substantially , the Conference Board said , jumping to 55 from 40 in the first quarter .</td>\n",
" <td>equivalent</td>\n",
" <td>4054</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>The moment of reckoning has arrived for this West African country founded by freed American slaves in the 19th century .</td>\n",
" <td>Taylor is now expected to leave the broken shell of a nation founded by freed American slaves in the 19th century .</td>\n",
" <td>not_equivalent</td>\n",
" <td>365</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>All three were studied for fingerprints , DNA and other traces of evidence , but prosecutors have not yet testified to what , if anything , they yielded .</td>\n",
" <td>All three were studied for fingerprints , DNA and other traces of evidence , but there has been no testimony yet about what the tests might have yielded .</td>\n",
" <td>equivalent</td>\n",
" <td>1008</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>Selenski 's partner in the jailbreak , Scott Bolton , injured his ankle , pelvis and ribs during the escape attempt , Warden Gene Fischi said .</td>\n",
" <td>Selenski 's partner in the Friday jailbreak , Scott Bolton , was injured in the escape and hospitalized .</td>\n",
" <td>equivalent</td>\n",
" <td>3566</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>Before completion , the group will take surplus cash of 16.5m from TCG to reduce its net borrowings .</td>\n",
" <td>Prior to completion , CCG said it will also extract surplus cash of $ 27 million to reduce net borrowings .</td>\n",
" <td>equivalent</td>\n",
" <td>3049</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>This is what Dr. Dean said : \" I still want to be the candidate for guys with Confederate flags in their pickup trucks .</td>\n",
" <td>He told the Register : \" I still want to be the candidate for guys with Confederate flags in their pickup trucks . \"</td>\n",
" <td>equivalent</td>\n",
" <td>3247</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>Costner plays Charlie Waite , a cowboy with a violent past .</td>\n",
" <td>In \" Open Range , \" Costner plays a cowboy who works with cattleman Robert Duvall .</td>\n",
" <td>not_equivalent</td>\n",
" <td>3206</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>Moose frequently criticizes reporters and news organizations in the book , especially those that reported on leaks from investigators .</td>\n",
" <td>He frequently criticizes the press , especially reporters and news organizations that reported on leaks from investigators .</td>\n",
" <td>equivalent</td>\n",
" <td>3536</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import datasets\n",
"import random\n",
"import pandas as pd\n",
"from IPython.display import display, HTML\n",
"\n",
"def show_random_elements(dataset, num_examples=10):\n",
" assert num_examples <= len(dataset), \"Can't pick more elements than there are in the dataset.\"\n",
" picks = []\n",
" for _ in range(num_examples):\n",
" pick = random.randint(0, len(dataset)-1)\n",
" while pick in picks:\n",
" pick = random.randint(0, len(dataset)-1)\n",
" picks.append(pick)\n",
" \n",
" df = pd.DataFrame(dataset[picks])\n",
" for column, typ in dataset.features.items():\n",
" if isinstance(typ, datasets.ClassLabel):\n",
" df[column] = df[column].transform(lambda i: typ.names[i])\n",
" display(HTML(df.to_html()))\n",
"\n",
"show_random_elements(raw_datasets[\"train\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_FKFTuBKm8J"
},
"source": [
"## Preprocess the data"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6StNGqH8KqLG"
},
"source": [
"Before we can feed those texts to our model, we need to preprocess them. This is done by a 🤗 Transformers `Tokenizer` which will (as the name indicates) tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put it in a format the model expects, as well as generate the other inputs that model requires.\n",
"\n",
"To do all of this, we instantiate our tokenizer with the `AutoTokenizer.from_pretrained` method, which will ensure:\n",
"\n",
"we get a tokenizer that corresponds to the model architecture we want to use,\n",
"we download the vocabulary used when pretraining this specific checkpoint.\n",
"That vocabulary will be cached, so it's not downloaded again the next time we run the cell."
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "rPmZCsmdKsAG"
},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QdUdoLF6K8Vk"
},
"source": [
"By default (unless you pass `use_fast=Fast` to the call above) it will use one of the fast tokenizers (backed by Rust) from the 🤗 Tokenizers library. Those fast tokenizers are available for almost all models, but if you got an error with the previous call, remove that argument.\n",
"\n",
"You can directly call this tokenizer on one sentence or a pair of sentences:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "yNuR5fCYLCr9",
"outputId": "d903f051-b4c3-407c-9b54-d7a377675999"
},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [101, 8667, 117, 1142, 1141, 5650, 106, 102, 1262, 1142, 5650, 2947, 1114, 1122, 119, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I6YulfZrLF7C"
},
"source": [
"Depending on the model you selected, you will see different keys in the dictionary returned by the cell above. They don't matter much for what we're doing here (just know they are required by the model we will instantiate later), you can learn more about them in [this tutorial](https://huggingface.co/transformers/preprocessing.html) if you're interested.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LDG_YTsBLSZi"
},
"source": [
"We can them write the function that will preprocess our samples. We just feed them to the `tokenizer` with the argument `truncation=True`. We also need all of our samples to have the same length (we will train on TPU and they need fixed shapes so we won't pad to the maximum length of a batch) which is done with `padding=True`. The `max_length` argument is used both for the truncation and padding (short inputs are padded to that length and long inputs are truncated to it).\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "sDZmE_5cLUqE"
},
"outputs": [],
"source": [
"def tokenize_function(examples):\n",
" outputs = tokenizer(examples[\"sentence1\"], examples[\"sentence2\"], truncation=True, padding=\"max_length\", max_length=128)\n",
" return outputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdPLbQGWLYAJ"
},
"source": [
"This function works with one or several examples. In the case of several examples, the tokenizer will return a list of lists for each key:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "U3Th9LBtLaGb",
"outputId": "c5938803-b758-4486-f8ff-b5d9a59ca31e"
},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [[101, 7277, 2180, 5303, 4806, 1117, 1711, 117, 2292, 1119, 1270, 107, 1103, 7737, 107, 117, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102, 11336, 6732, 3384, 1106, 1140, 1112, 1178, 107, 1103, 7737, 107, 117, 7277, 2180, 5303, 4806, 1117, 1711, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 10684, 2599, 9717, 1161, 2205, 11288, 1377, 112, 188, 1196, 4147, 1103, 4129, 1106, 19770, 2787, 1107, 1772, 1111, 109, 123, 119, 126, 3775, 119, 102, 10684, 2599, 9717, 1161, 3306, 11288, 1377, 112, 188, 1107, 1876, 1111, 109, 5691, 1495, 1550, 1105, 1962, 1122, 1106, 19770, 2787, 1111, 109, 122, 119, 129, 3775, 1107, 1772, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1220, 1125, 1502, 1126, 16355, 1113, 1103, 4639, 1113, 1340, 1275, 117, 4733, 1103, 6527, 1111, 4688, 117, 1119, 1896, 119, 102, 1212, 1340, 1275, 117, 1103, 2062, 112, 188, 5032, 1125, 1502, 1126, 16355, 1113, 1103, 4639, 117, 4733, 1103, 16454, 1111, 4688, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 5596, 5347, 19297, 14748, 1942, 117, 22515, 1830, 6117, 1127, 1146, 1627, 18748, 117, 1137, 125, 119, 125, 110, 117, 1120, 138, 109, 125, 119, 4376, 117, 1515, 2206, 1383, 170, 1647, 1344, 1104, 138, 109, 125, 119, 4667, 119, 102, 22515, 1830, 6117, 4874, 1406, 18748, 117, 1137, 125, 119, 127, 110, 117, 1106, 1383, 170, 1647, 5134, 1344, 1120, 138, 109, 125, 119, 4667, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 1109, 4482, 3152, 109, 123, 119, 1429, 117, 1137, 1164, 1429, 3029, 117, 1106, 1601, 5286, 1120, 109, 1626, 119, 4062, 1113, 1103, 1203, 1365, 9924, 7855, 119, 102, 153, 2349, 111, 142, 13619, 119, 6117, 4874, 109, 122, 119, 5519, 1137, 129, 3029, 1106, 109, 1626, 119, 5347, 1113, 1103, 1203, 1365, 9924, 7855, 1113, 5286, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenize_function(raw_datasets['train'][:5])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "W_ZAk-uQLdlG"
},
"source": [
"To apply this function on all the sentences (or pairs of sentences) in our dataset, we just use the `map` method of our `dataset` object we created earlier. This will apply the function on all the elements of all the splits in `dataset`, so our training, validation and testing data will be preprocessed in one single command.\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 164,
"referenced_widgets": [
"54fa72d8df5d471688bf58ce8d7d994d",
"10cbaf06eb6e4dab836e92e51e543e1a",
"9c95e32846454e07987c69b53b49aedd",
"9a9e13893ca748d98077333b15b466ac",
"f65d00a348454dfbad7ca4d3ffee1064",
"c1de1bf5955943b78e29e15c47df291c",
"45dbedeedfb7447ea942b231091bb29f",
"2ec7de6c4e16478f80e7802ab8e5a2e4",
"6c23a5fe1ed549329b15431817918f8d",
"0a53a19926ec44b495fe0108b8a53d5b",
"f0be5d04d3f1495585ccc6b46619a0bf",
"1c7ee04d39714ca597fef65a65d8acdf",
"33e6e73d07dd41cdba8282a0b7b2e17a",
"c34061fd048741c7947d89bbe5daafa5",
"6ae30e0fa7224d59ae3819fd89527770",
"b92465e23aa94cbb92fe0631bc5ba2e0",
"8dfc01750fb149bb837d8a2c51fc4198",
"8efcba41f8f049f4bdc30c0416d52ad8",
"31c3257c5d1d4120a1030c385b47ca45",
"96f2dce18ed84dd799efb1780c7d81de",
"480b04299b9f405e8f9a0f621451d8d4",
"24f7e1b9446b43e8b17b4c54546ba20d",
"b4757b916a094c76916e64d7841d2245",
"c00a3e9d092f43e88d1bb9bf1c680bb9"
]
},
"id": "cphvac5ILezw",
"outputId": "84b4a0fa-c727-4f73-bff8-0b9e7f2e22b1"
},
"outputs": [],
"source": [
"tokenized_datasets = raw_datasets.map(tokenize_function, batched=True, remove_columns=[\"idx\", \"sentence1\", \"sentence2\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "H4FaM4ZGLgad"
},
"source": [
"Even better, the results are automatically cached by the 🤗 Datasets library to avoid spending time on this step the next time you run your notebook. The 🤗 Datasets library is normally smart enough to detect when the function you pass to map has changed (and thus requires to not use the cache data). For instance, it will properly detect if you change the task in the first cell and rerun the notebook. 🤗 Datasets warns you when it uses cached files, you can pass `load_from_cache_file=False` in the call to `map` to not use the cached files and force the preprocessing to be applied again.\n",
"\n",
"Note that we passed `batched=True` to encode the texts by batches together. This is to leverage the full benefit of the fast tokenizer we loaded earlier, which will use multi-threading to treat the texts in a batch concurrently.\n",
"\n",
"Lastly, we remove the columns that our model will not use. We also need to rename the `label` column to `labels` as this is what our model will expect."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "vfbSmdoWk5rG"
},
"outputs": [],
"source": [
"tokenized_datasets = tokenized_datasets.rename_column(\"label\", \"labels\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Hlw6EGsXlIHv"
},
"source": [
"To double-check we only have columns that are accepted as arguments for the model we will instantiate, we can look at them here."
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TmqrhXwYlD4l",
"outputId": "173f8b3f-d1b0-46f0-f072-b886614d55a6"
},
"outputs": [
{
"data": {
"text/plain": [
"{'labels': ClassLabel(names=['not_equivalent', 'equivalent'], id=None),\n",
" 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),\n",
" 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),\n",
" 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized_datasets[\"train\"].features"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s_2D3L6kn7fn"
},
"source": [
"The model we will be using is a `BertModelForSequenceClassification`. We can check its signature in the [Transformers documentation](https://huggingface.co/transformers/model_doc/bert.html#transformers.BertForSequenceClassification) and all seems to be right! The last step is to set our datasets in the `\"torch\"` format, so that each item in it is now a dictionary with tensor values."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "Qy1VJb2dsJus"
},
"outputs": [],
"source": [
"tokenized_datasets.set_format(\"torch\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CSoNVxwUo_jt"
},
"source": [
"## A first look at the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e6dcSOA_pzPg"
},
"source": [
"Now that our data is ready, we can download the pretrained model and fine-tune it. Since all our tasks are about sentence classification, we use the `AutoModelForSequenceClassification` class. Like with the tokenizer, the from_pretrained method will download and cache the model for us. The only thing we have to specify is the number of labels for our problem (which is 2 here):"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tlYXqxvtp-8Y",
"outputId": "6c39c938-a3b3-423c-bc32-2a7c93cc120d"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
}
],
"source": [
"from transformers import AutoModelForSequenceClassification\n",
"\n",
"model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lZpkm-f_qGya"
},
"source": [
"The warning is telling us we are throwing away some weights (the vocab_transform and vocab_layer_norm layers) and randomly initializing some other (the pre_classifier and classifier layers). This is absolutely normal in this case, because we are removing the head used to pretrain the model on a masked language modeling objective and replacing it with a new head for which we don't have pretrained weights, so the library warns us we should fine-tune this model before using it for inference, which is exactly what we are going to do.\n",
"\n",
"Note that we will are only creating the model here to look at it and debug problems. We will create the model we will train inside our training function: to train on TPU in colab, we have to create a big training function that will be executed on each code of the TPU. It's fine to do use the datasets defined before (they will be copied to each TPU core) but the model itself will need to be re-instantiated and placed on each device for it to work.\n",
"\n",
"Now to get the data we need to define our training and evaluation dataloaders. Again, we only create them here for debugging purposes, they will be re-instantiated in our training function, which is why we define a function that builds them."
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"id": "Ll4PjC5nqFqk"
},
"outputs": [],
"source": [
"def create_dataloaders(train_batch_size=8, eval_batch_size=32):\n",
" train_dataloader = DataLoader(\n",
" tokenized_datasets[\"train\"], shuffle=True, batch_size=train_batch_size\n",
" )\n",
" eval_dataloader = DataLoader(\n",
" tokenized_datasets[\"validation\"], shuffle=False, batch_size=eval_batch_size\n",
" )\n",
" return train_dataloader, eval_dataloader"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4a_rrVs0siR7"
},
"source": [
"Let's have a look at our train and evaluation dataloaders to check a batch can go through the model."
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"id": "l2krwQzMrdK2"
},
"outputs": [],
"source": [
"#train_dataloader, eval_dataloader = create_dataloaders()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "dwrNQ1Cbsp5i"
},
"source": [
"We just loop through one batch. Since our datasets elements are dictionaries of tensors, it's the same for our batch and we can have a quick look at all the shapes. Note that this cell takes a bit of time to execute since we run a batch of our data through the model on the CPU (if you changed the checkpoint to a bigger model, it might take too much time so comment it out).\n",
"\n",
"⚠ **WARNING: Running this cell will cause training_function to malfunction, as model will be used before notebook_launcher**"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RA-_yUyrritM",
"outputId": "2717efe0-5081-42af-93a3-3bdd88334a25"
},
"outputs": [],
"source": [
"#for batch in train_dataloader:\n",
"# print({k: v.shape for k, v in batch.items()})\n",
"# outputs = model(**batch)\n",
"# break"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVCmBAxbtCU_"
},
"source": [
"The output of our model is a `SequenceClassifierOutput`, with the `loss` (since we provided labels) and `logits` (of shape 8, our batch size, by 2, the number of labels)."
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bccUGpZbsXdS",
"outputId": "f79e2882-3813-4340-c2f5-703bc196743a"
},
"outputs": [],
"source": [
"#outputs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-9hiuBqMtU5W"
},
"source": [
"The last piece we will need for the model evaluation is the metric. The `datasets` library provides a function `load_metric` that allows us to easily create a `datasets.Metric` object we can use."
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "brJ6K4oBtpST"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_77293/442538619.py:1: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate\n",
" metric = load_metric(\"glue\", \"mrpc\")\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n",
"To disable this warning, you can either:\n",
"\t- Avoid using `tokenizers` before the fork if possible\n",
"\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
]
}
],
"source": [
"metric = load_metric(\"glue\", \"mrpc\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MG2tjiSEttNO"
},
"source": [
"To use this object on some predictions we call the `compute` methode to get our metric results:"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "YVAPgl1BuB3s",
"outputId": "c92e4117-8111-4006-f2b6-27e9765ed998"
},
"outputs": [],
"source": [
"#predictions = outputs.logits.detach().argmax(dim=-1)\n",
"#metric.compute(predictions=predictions, references=batch[\"labels\"])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rq1_mw0mucQV"
},
"source": [
"Unsurpringly, our model with its random head does not perform well, which is why we need to fine-tune it!"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "akoJx8puvTfV"
},
"source": [
"## Fine-tuning the model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "93vWWJa2wVEj"
},
"source": [
"We are now ready to fine-tune this model on our dataset. As mentioned before, everything related to training needs to be in one big training function that will be executed on each TPU core, thanks to our `notebook_launcher`.\n",
"\n",
"It will use this dictionary of hyperparameters, so tweak anything you like in here!"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {
"id": "rqyYx8SK8e0O"
},
"outputs": [],
"source": [
"hyperparameters = {\n",
" \"learning_rate\": 2e-5,\n",
" \"num_epochs\": 3,\n",
" \"train_batch_size\": 8, # Actual batch size will this x 8\n",
" \"eval_batch_size\": 32, # Actual batch size will this x 8\n",
" \"seed\": 42,\n",
"}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZR1F1G_n863W"
},
"source": [
"The two most important things to remember for training on TPUs is that your accelerator object has to be defined inside your training function, and your model should be created outside the training function. \n",
"\n",
"If you define your Accelerator in another cell that gets executed before the final launch (for debugging), you will need to restart your notebook as the line `accelerator = Accelerator()` needs to be executed for the first time inside the training function spwaned on each TPU core.\n",
"\n",
"This is because that line will look for a TPU device, and if you set it outside of the distributed training launched by `notebook_launcher`, it will perform setup that cannot be undone in your runtime and you will only have access to one TPU core until you restart the notebook.\n",
"\n",
"The reason we declare the model outside the loop is because on a TPU when launched from a notebook the same singular model object is used, and it is passed back and forth between all the cores automatically. \n",
"\n",
"Since we can't explore each piece in separate cells, comments have been left in the code. This is all pretty standard and you will notice how little the code changes from a regular training loop! The main lines added are:\n",
"\n",
"- `accelerator = Accelerator()` to initalize the distributed setup,\n",
"- sending all objects to `accelerator.prepare`,\n",
"- replace `loss.backward()` with `accelerator.backward(loss)`,\n",
"- use `accelerator.gather` to gather all predictions and labels before storing them in our list of predictions/labels,\n",
"- truncate predictions and labels as the prepared evaluation dataloader has a few more samples to make batches of the same size on each process.\n",
"\n",
"The first three are for distributed training, the last two for distributed evaluation. If you don't care about distributed evaluation, you can also just replace that part by your standard evaluation loop launched on the main process only.\n",
"\n",
"Other changes (which are purely cosmetic to make the output of the training readable) are:\n",
"\n",
"- some logging behavior behind a `if accelerator.is_main_process:`,\n",
"- disable the progress bar if `accelerator.is_main_process` is `False`,\n",
"- use `accelerator.print` instead of `print`."
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {
"id": "QueNpNJF9eGq"
},
"outputs": [],
"source": [
"def training_function(model):\n",
" # Initialize accelerator\n",
" accelerator = Accelerator()\n",
"\n",
" # To have only one message (and not 8) per logs of Transformers or Datasets, we set the logging verbosity\n",
" # to INFO for the main process only.\n",
" if accelerator.is_main_process:\n",
" datasets.utils.logging.set_verbosity_warning()\n",
" transformers.utils.logging.set_verbosity_info()\n",
" else:\n",
" datasets.utils.logging.set_verbosity_error()\n",
" transformers.utils.logging.set_verbosity_error()\n",
"\n",
" train_dataloader, eval_dataloader = create_dataloaders(\n",
" train_batch_size=hyperparameters[\"train_batch_size\"], eval_batch_size=hyperparameters[\"eval_batch_size\"]\n",
" )\n",
" # The seed need to be set before we instantiate the model, as it will determine the random head.\n",
" set_seed(hyperparameters[\"seed\"])\n",
"\n",
" # Instantiate optimizer\n",
" optimizer = AdamW(params=model.parameters(), lr=hyperparameters[\"learning_rate\"])\n",
"\n",
" # Prepare everything\n",
" # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the\n",
" # prepare method.\n",
" model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(\n",
" model, optimizer, train_dataloader, eval_dataloader\n",
" )\n",
"\n",
" num_epochs = hyperparameters[\"num_epochs\"]\n",
" # Instantiate learning rate scheduler after preparing the training dataloader as the prepare method\n",
" # may change its length.\n",
" lr_scheduler = get_linear_schedule_with_warmup(\n",
" optimizer=optimizer,\n",
" num_warmup_steps=100,\n",
" num_training_steps=len(train_dataloader) * num_epochs,\n",
" )\n",
"\n",
" # Instantiate a progress bar to keep track of training. Note that we only enable it on the main\n",
" # process to avoid having 8 progress bars.\n",
" progress_bar = tqdm(range(num_epochs * len(train_dataloader)), disable=not accelerator.is_main_process)\n",
" # Now we train the model\n",
" for epoch in range(num_epochs):\n",
" model.train()\n",
" for step, batch in enumerate(train_dataloader):\n",
" outputs = model(**batch)\n",
" loss = outputs.loss\n",
" accelerator.backward(loss)\n",
" \n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
" progress_bar.update(1)\n",
"\n",
" model.eval()\n",
" all_predictions = []\n",
" all_labels = []\n",
"\n",
" for step, batch in enumerate(eval_dataloader):\n",
" with torch.no_grad():\n",
" outputs = model(**batch)\n",
" predictions = outputs.logits.argmax(dim=-1)\n",
"\n",
" # We gather predictions and labels from the 8 TPUs to have them all.\n",
" all_predictions.append(accelerator.gather(predictions))\n",
" all_labels.append(accelerator.gather(batch[\"labels\"]))\n",
"\n",
" # Concatenate all predictions and labels.\n",
" # The last thing we need to do is to truncate the predictions and labels we concatenated\n",
" # together as the prepared evaluation dataloader has a little bit more elements to make\n",
" # batches of the same size on each process.\n",
" all_predictions = torch.cat(all_predictions)[:len(tokenized_datasets[\"validation\"])]\n",
" all_labels = torch.cat(all_labels)[:len(tokenized_datasets[\"validation\"])]\n",
"\n",
" eval_metric = metric.compute(predictions=all_predictions, references=all_labels)\n",
"\n",
" # Use accelerator.print to print only on the main process.\n",
" accelerator.print(f\"epoch {epoch}:\", eval_metric)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K6LUeuUlEOvL"
},
"source": [
"And we're ready for launch! It's super easy with the `notebook_launcher` from the Accelerate library."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"from accelerate import notebook_launcher"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 647,
"referenced_widgets": [
"a91fb54c8a3d48d6a094c2a83f34c432",
"55b8e8cb8646431eb67443739529200f",
"07396159436a4202abf23156c1c5a0c6",
"bf64d0787a1b4b5eac8a2db8f3bc69d6",
"7b09dd3fb5e045428091b12bd82877c8",
"ae0fc7ad41264c94898e22090ca22422",
"fe4dd099be15498ea29b1013db02b503",
"f9d76b1743c643939f5a59a30d378a2f"
]
},
"id": "ZBZdzN3lOEuL",
"outputId": "340c815e-c6f0-47f9-e83e-87f98788c2c8"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/demouser/miniconda3/envs/pytorch/lib/python3.9/site-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c285c4c9be5e4b9dba8126ba29b67c55",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
" 0%| | 0/1377 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch 0: {'accuracy': 0.7965686274509803, 'f1': 0.8369351669941061}\n",
"epoch 1: {'accuracy': 0.8333333333333334, 'f1': 0.8896103896103896}\n",
"epoch 2: {'accuracy': 0.8651960784313726, 'f1': 0.9050086355785838}\n"
]
},
{
"ename": "ValueError",
"evalue": "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized inside your training function. Restart your notebook and make sure no cells initializes an `Accelerator`.",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[29], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mnotebook_launcher\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtraining_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_processes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m#notebook_launcher(training_function, (model,), num_processes=3)\u001b[39;00m\n",
"File \u001b[0;32m~/miniconda3/envs/pytorch/lib/python3.9/site-packages/accelerate/launchers.py:140\u001b[0m, in \u001b[0;36mnotebook_launcher\u001b[0;34m(function, args, num_processes, mixed_precision, use_port, master_addr, node_rank, num_nodes)\u001b[0m\n\u001b[1;32m 137\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mmultiprocessing\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mspawn\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m ProcessRaisedException\n\u001b[1;32m 139\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(AcceleratorState\u001b[38;5;241m.\u001b[39m_shared_state) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m--> 140\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\n\u001b[1;32m 141\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTo launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 142\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124minside your training function. Restart your notebook and make sure no cells initializes an \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 143\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m`Accelerator`.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 144\u001b[0m )\n\u001b[1;32m 145\u001b[0m \u001b[38;5;66;03m# torch.distributed will expect a few environment variable to be here. We set the ones common to each\u001b[39;00m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;66;03m# process here (the other ones will be set be the launcher).\u001b[39;00m\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m patch_environment(\n\u001b[1;32m 148\u001b[0m nproc\u001b[38;5;241m=\u001b[39mnum_processes,\n\u001b[1;32m 149\u001b[0m node_rank\u001b[38;5;241m=\u001b[39mnode_rank,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 154\u001b[0m ):\n\u001b[1;32m 155\u001b[0m \u001b[38;5;66;03m# First dummy launch\u001b[39;00m\n",
"\u001b[0;31mValueError\u001b[0m: To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized inside your training function. Restart your notebook and make sure no cells initializes an `Accelerator`."
]
}
],
"source": [
"notebook_launcher(training_function(model), num_processes=3)\n",
"#notebook_launcher(training_function, (model,), num_processes=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Qaa-dIlEvCwY"
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "TPU",
"colab": {
"collapsed_sections": [],
"name": "Simple NLP Example",
"provenance": []
},
"kernelspec": {
"display_name": "pytorch",
"language": "python",
"name": "pytorch"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"07396159436a4202abf23156c1c5a0c6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "100%",
"description_tooltip": null,
"layout": "IPY_MODEL_ae0fc7ad41264c94898e22090ca22422",
"max": 174,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_7b09dd3fb5e045428091b12bd82877c8",
"value": 174
}
},
"0a53a19926ec44b495fe0108b8a53d5b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"10cbaf06eb6e4dab836e92e51e543e1a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"1c7ee04d39714ca597fef65a65d8acdf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_b92465e23aa94cbb92fe0631bc5ba2e0",
"placeholder": "​",
"style": "IPY_MODEL_6ae30e0fa7224d59ae3819fd89527770",
"value": " 1/1 [00:00&lt;00:00, 2.72ba/s]"
}
},
"24f7e1b9446b43e8b17b4c54546ba20d": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"2ec7de6c4e16478f80e7802ab8e5a2e4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"31c3257c5d1d4120a1030c385b47ca45": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "100%",
"description_tooltip": null,
"layout": "IPY_MODEL_24f7e1b9446b43e8b17b4c54546ba20d",
"max": 2,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_480b04299b9f405e8f9a0f621451d8d4",
"value": 2
}
},
"33e6e73d07dd41cdba8282a0b7b2e17a": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"45dbedeedfb7447ea942b231091bb29f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"480b04299b9f405e8f9a0f621451d8d4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"54fa72d8df5d471688bf58ce8d7d994d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_9c95e32846454e07987c69b53b49aedd",
"IPY_MODEL_9a9e13893ca748d98077333b15b466ac"
],
"layout": "IPY_MODEL_10cbaf06eb6e4dab836e92e51e543e1a"
}
},
"55b8e8cb8646431eb67443739529200f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"6ae30e0fa7224d59ae3819fd89527770": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"6c23a5fe1ed549329b15431817918f8d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_f0be5d04d3f1495585ccc6b46619a0bf",
"IPY_MODEL_1c7ee04d39714ca597fef65a65d8acdf"
],
"layout": "IPY_MODEL_0a53a19926ec44b495fe0108b8a53d5b"
}
},
"7b09dd3fb5e045428091b12bd82877c8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"8dfc01750fb149bb837d8a2c51fc4198": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_31c3257c5d1d4120a1030c385b47ca45",
"IPY_MODEL_96f2dce18ed84dd799efb1780c7d81de"
],
"layout": "IPY_MODEL_8efcba41f8f049f4bdc30c0416d52ad8"
}
},
"8efcba41f8f049f4bdc30c0416d52ad8": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"96f2dce18ed84dd799efb1780c7d81de": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_c00a3e9d092f43e88d1bb9bf1c680bb9",
"placeholder": "​",
"style": "IPY_MODEL_b4757b916a094c76916e64d7841d2245",
"value": " 2/2 [00:09&lt;00:00, 4.66s/ba]"
}
},
"9a9e13893ca748d98077333b15b466ac": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_2ec7de6c4e16478f80e7802ab8e5a2e4",
"placeholder": "​",
"style": "IPY_MODEL_45dbedeedfb7447ea942b231091bb29f",
"value": " 4/4 [00:10&lt;00:00, 2.56s/ba]"
}
},
"9c95e32846454e07987c69b53b49aedd": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "100%",
"description_tooltip": null,
"layout": "IPY_MODEL_c1de1bf5955943b78e29e15c47df291c",
"max": 4,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_f65d00a348454dfbad7ca4d3ffee1064",
"value": 4
}
},
"a91fb54c8a3d48d6a094c2a83f34c432": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_07396159436a4202abf23156c1c5a0c6",
"IPY_MODEL_bf64d0787a1b4b5eac8a2db8f3bc69d6"
],
"layout": "IPY_MODEL_55b8e8cb8646431eb67443739529200f"
}
},
"ae0fc7ad41264c94898e22090ca22422": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b4757b916a094c76916e64d7841d2245": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"b92465e23aa94cbb92fe0631bc5ba2e0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"bf64d0787a1b4b5eac8a2db8f3bc69d6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_f9d76b1743c643939f5a59a30d378a2f",
"placeholder": "​",
"style": "IPY_MODEL_fe4dd099be15498ea29b1013db02b503",
"value": " 174/174 [02:28&lt;00:00, 1.47it/s]"
}
},
"c00a3e9d092f43e88d1bb9bf1c680bb9": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c1de1bf5955943b78e29e15c47df291c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"c34061fd048741c7947d89bbe5daafa5": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"f0be5d04d3f1495585ccc6b46619a0bf": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "100%",
"description_tooltip": null,
"layout": "IPY_MODEL_c34061fd048741c7947d89bbe5daafa5",
"max": 1,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_33e6e73d07dd41cdba8282a0b7b2e17a",
"value": 1
}
},
"f65d00a348454dfbad7ca4d3ffee1064": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": "initial"
}
},
"f9d76b1743c643939f5a59a30d378a2f": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"fe4dd099be15498ea29b1013db02b503": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment