Last active
February 23, 2020 18:50
-
-
Save ceshine/1fe607938c4caef863dab056b0c0048b to your computer and use it in GitHub Desktop.
Train huggingface/transformers BERT model on Colab TPU with TF 2.1
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.5" | |
}, | |
"colab": { | |
"name": "63ddc9198a1976512bef5e1b92610ea3", | |
"provenance": [], | |
"include_colab_link": true | |
}, | |
"accelerator": "TPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/ceshine/1fe607938c4caef863dab056b0c0048b/run_tf_glue.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "1m94kz61TKbh", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Adapted from [transformers/examples/run_tf_glue.py](https://github.com/huggingface/transformers/blob/master/examples/run_tf_glue.py)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "rBTOd3i6TYb_", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 607 | |
}, | |
"outputId": "6c33758d-9a07-4149-ea71-950888ff6a33" | |
}, | |
"source": [ | |
"%tensorflow_version 2.x\n", | |
"!pip install transformers==2.3.0" | |
], | |
"execution_count": 1, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"TensorFlow 2.x selected.\n", | |
"Collecting transformers==2.3.0\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/50/10/aeefced99c8a59d828a92cc11d213e2743212d3641c87c82d61b035a7d5c/transformers-2.3.0-py3-none-any.whl (447kB)\n", | |
"\u001b[K |████████████████████████████████| 450kB 3.5MB/s \n", | |
"\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0) (1.10.47)\n", | |
"Collecting sentencepiece\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n", | |
"\u001b[K |████████████████████████████████| 1.0MB 41.0MB/s \n", | |
"\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0) (2019.12.20)\n", | |
"Requirement already satisfied: requests in /tensorflow-2.1.0/python3.6 (from transformers==2.3.0) (2.22.0)\n", | |
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0) (4.28.1)\n", | |
"Requirement already satisfied: numpy in /tensorflow-2.1.0/python3.6 (from transformers==2.3.0) (1.18.1)\n", | |
"Collecting sacremoses\n", | |
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)\n", | |
"\u001b[K |████████████████████████████████| 870kB 23.5MB/s \n", | |
"\u001b[?25hRequirement already satisfied: botocore<1.14.0,>=1.13.47 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0) (1.13.47)\n", | |
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0) (0.2.1)\n", | |
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0) (0.9.4)\n", | |
"Requirement already satisfied: idna<2.9,>=2.5 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (2.8)\n", | |
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (1.25.7)\n", | |
"Requirement already satisfied: certifi>=2017.4.17 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (2019.11.28)\n", | |
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (3.0.4)\n", | |
"Requirement already satisfied: six in /tensorflow-2.1.0/python3.6 (from sacremoses->transformers==2.3.0) (1.13.0)\n", | |
"Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.3.0) (7.0)\n", | |
"Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.3.0) (0.14.1)\n", | |
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->transformers==2.3.0) (0.15.2)\n", | |
"Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->transformers==2.3.0) (2.6.1)\n", | |
"Building wheels for collected packages: sacremoses\n", | |
" Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n", | |
" Created wheel for sacremoses: filename=sacremoses-0.0.38-cp36-none-any.whl size=884629 sha256=c44e1b6a02cb9b4c392345d16020e59b7dcabd88422ff5e32b2ca97e15b59ba5\n", | |
" Stored in directory: /root/.cache/pip/wheels/6d/ec/1a/21b8912e35e02741306f35f66c785f3afe94de754a0eaf1422\n", | |
"Successfully built sacremoses\n", | |
"Installing collected packages: sentencepiece, sacremoses, transformers\n", | |
"Successfully installed sacremoses-0.0.38 sentencepiece-0.1.85 transformers-2.3.0\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "g5gigYg_TKbl", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"import os\n", | |
"import json\n", | |
"import math\n", | |
"\n", | |
"import tensorflow as tf\n", | |
"import tensorflow_datasets\n", | |
"\n", | |
"from transformers import (\n", | |
" BertConfig,\n", | |
" BertTokenizer,\n", | |
" TFBertForSequenceClassification,\n", | |
" glue_convert_examples_to_features,\n", | |
" glue_processors\n", | |
")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "xN3SYuc9TKbv", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Set up the TPU" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "OEbiVyCdTKb2", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 764 | |
}, | |
"outputId": "1e35c687-22cc-44da-8dd9-39d582a19aba" | |
}, | |
"source": [ | |
"try:\n", | |
" tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection\n", | |
" print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n", | |
"except ValueError:\n", | |
" tpu = None\n", | |
"strategy = tf.distribute.get_strategy()\n", | |
"if tpu:\n", | |
" tf.config.experimental_connect_to_cluster(tpu)\n", | |
" tf.tpu.experimental.initialize_tpu_system(tpu)\n", | |
" strategy = tf.distribute.experimental.TPUStrategy(tpu)\n", | |
"print(\"REPLICAS: \", strategy.num_replicas_in_sync)" | |
], | |
"execution_count": 3, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Running on TPU ['10.27.197.186:8470']\n", | |
"INFO:tensorflow:Initializing the TPU system: 10.27.197.186:8470\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Initializing the TPU system: 10.27.197.186:8470\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Clearing out eager caches\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Clearing out eager caches\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Finished initializing TPU system.\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Finished initializing TPU system.\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Found TPU system:\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:Found TPU system:\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Num TPU Cores: 8\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Num TPU Cores: 8\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Num TPU Workers: 1\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Num TPU Workers: 1\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" | |
], | |
"name": "stdout" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n" | |
], | |
"name": "stderr" | |
}, | |
{ | |
"output_type": "stream", | |
"text": [ | |
"REPLICAS: 8\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "43e1NIDDTKb8", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 34 | |
}, | |
"outputId": "96d472ad-9e32-4acf-eba7-8e3f08a424a8" | |
}, | |
"source": [ | |
"if strategy.num_replicas_in_sync == 8: # single TPU\n", | |
" BATCH_SIZE = 16 * strategy.num_replicas_in_sync\n", | |
" EVAL_BATCH_SIZE = BATCH_SIZE * 4\n", | |
" EPOCHS = 3\n", | |
"else:\n", | |
" BATCH_SIZE = 32\n", | |
" EVAL_BATCH_SIZE = BATCH_SIZE * 2\n", | |
" EPOCHS = 3 \n", | |
"\n", | |
"TASK = \"mrpc\"\n", | |
"if TASK == \"sst-2\":\n", | |
" TFDS_TASK = \"sst2\"\n", | |
"elif TASK == \"sts-b\":\n", | |
" TFDS_TASK = \"stsb\"\n", | |
"else:\n", | |
" TFDS_TASK = TASK\n", | |
" \n", | |
"num_labels = len(glue_processors[TASK]().get_labels())\n", | |
"print(num_labels)" | |
], | |
"execution_count": 4, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"2\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "fObqeVjBTKcB", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"## Dataset" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "BwJmW3RkTKcB", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"Reference: [tensorflow/models/official/nlp/bert](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "5nfNw3uzTKcE", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"GLUE_DIR = \"gs://cloud-tpu-checkpoints/bert/classification\"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "uO5kQ0rlTKcI", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def single_file_dataset(input_file, name_to_features):\n", | |
" \"\"\"Creates a single-file dataset to be passed for BERT custom training.\"\"\"\n", | |
" # For training, we want a lot of parallel reading and shuffling.\n", | |
" # For eval, we want no shuffling and parallel reading doesn't matter.\n", | |
" d = tf.data.TFRecordDataset(input_file)\n", | |
" d = d.map(lambda record: decode_record(record, name_to_features))\n", | |
"\n", | |
" # When `input_file` is a path to a single file or a list\n", | |
" # containing a single path, disable auto sharding so that\n", | |
" # same input file is sent to all workers.\n", | |
" if isinstance(input_file, str) or len(input_file) == 1:\n", | |
" options = tf.data.Options()\n", | |
" options.experimental_distribute.auto_shard_policy = (\n", | |
" tf.data.experimental.AutoShardPolicy.OFF)\n", | |
" d = d.with_options(options)\n", | |
" return d" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "kOb5O4rETKcL", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def decode_record(record, name_to_features):\n", | |
" \"\"\"Decodes a record to a TensorFlow example.\"\"\"\n", | |
" example = tf.io.parse_single_example(record, name_to_features)\n", | |
"\n", | |
" # tf.Example only supports tf.int64, but the TPU only supports tf.int32.\n", | |
" # So cast all int64 to int32.\n", | |
" for name in list(example.keys()):\n", | |
" t = example[name]\n", | |
" if t.dtype == tf.int64:\n", | |
" t = tf.cast(t, tf.int32)\n", | |
" example[name] = t\n", | |
"\n", | |
" return example" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "f4-FAqH_TKcQ", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def create_classifier_dataset(file_path,\n", | |
" seq_length,\n", | |
" batch_size,\n", | |
" is_training=True,\n", | |
" input_pipeline_context=None):\n", | |
" \"\"\"Creates input dataset from (tf)records files for train/eval.\"\"\"\n", | |
" name_to_features = {\n", | |
" 'input_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n", | |
" 'input_mask': tf.io.FixedLenFeature([seq_length], tf.int64),\n", | |
" 'segment_ids': tf.io.FixedLenFeature([seq_length], tf.int64),\n", | |
" 'label_ids': tf.io.FixedLenFeature([], tf.int64),\n", | |
" 'is_real_example': tf.io.FixedLenFeature([], tf.int64),\n", | |
" }\n", | |
" dataset = single_file_dataset(file_path, name_to_features)\n", | |
"\n", | |
" # The dataset is always sharded by number of hosts.\n", | |
" # num_input_pipelines is the number of hosts rather than number of cores.\n", | |
" if input_pipeline_context and input_pipeline_context.num_input_pipelines > 1:\n", | |
" dataset = dataset.shard(input_pipeline_context.num_input_pipelines,\n", | |
" input_pipeline_context.input_pipeline_id)\n", | |
"\n", | |
" def _select_data_from_record(record):\n", | |
" x = {\n", | |
" 'input_ids': record['input_ids'],\n", | |
" 'attention_mask': record['input_mask'],\n", | |
" 'token_type_ids': record['segment_ids']\n", | |
" }\n", | |
" y = record['label_ids']\n", | |
" return (x, y)\n", | |
"\n", | |
" dataset = dataset.map(_select_data_from_record)\n", | |
"\n", | |
" if is_training:\n", | |
" dataset = dataset.shuffle(100)\n", | |
" dataset = dataset.repeat()\n", | |
"\n", | |
" dataset = dataset.batch(batch_size, drop_remainder=is_training)\n", | |
" dataset = dataset.prefetch(1024)\n", | |
" return dataset" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2TuCxgcyTKcT", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,\n", | |
" is_training):\n", | |
" \"\"\"Gets a closure to create a dataset.\"\"\"\n", | |
"\n", | |
" def _dataset_fn(ctx=None):\n", | |
" \"\"\"Returns tf.data.Dataset for distributed BERT pretraining.\"\"\"\n", | |
" batch_size = ctx.get_per_replica_batch_size(\n", | |
" global_batch_size) if ctx else global_batch_size\n", | |
" dataset = create_classifier_dataset(\n", | |
" input_file_pattern,\n", | |
" max_seq_length,\n", | |
" batch_size,\n", | |
" is_training=is_training,\n", | |
" input_pipeline_context=ctx)\n", | |
" return dataset\n", | |
"\n", | |
" return _dataset_fn" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "AC6wB97dTKcX", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"with tf.io.gfile.GFile(f'{GLUE_DIR}/{TASK}_meta_data', 'rb') as reader:\n", | |
" input_meta_data = json.loads(reader.read().decode('utf-8'))\n", | |
"\n", | |
"max_seq_length = input_meta_data['max_seq_length']\n", | |
"train_input_fn = get_dataset_fn(\n", | |
" f\"{GLUE_DIR}/{TASK}_train.tf_record\",\n", | |
" max_seq_length,\n", | |
" BATCH_SIZE,\n", | |
" is_training=True)\n", | |
"eval_input_fn = get_dataset_fn(\n", | |
" f\"{GLUE_DIR}/{TASK}_eval.tf_record\",\n", | |
" max_seq_length,\n", | |
" EVAL_BATCH_SIZE,\n", | |
" is_training=False)" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "XAdN8D8TTKcc", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Load tokenizer and model from pretrained model/vocabulary. Specify the number of labels to classify (2+: classification, 1: regression)\n", | |
"config = BertConfig.from_pretrained(\"bert-base-cased\", num_labels=num_labels)\n", | |
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")\n", | |
"with strategy.scope():\n", | |
" training_dataset = train_input_fn()\n", | |
" evaluation_dataset = eval_input_fn()\n", | |
" \n", | |
" model = TFBertForSequenceClassification.from_pretrained(\"bert-base-cased\", config=config)\n", | |
" # Prepare training: Compile tf.keras model with optimizer, loss and learning rate schedule\n", | |
" opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)\n", | |
"\n", | |
" if num_labels == 1:\n", | |
" loss = tf.keras.losses.MeanSquaredError()\n", | |
" else:\n", | |
" loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", | |
"\n", | |
" metric = tf.keras.metrics.SparseCategoricalAccuracy(\"accuracy\")\n", | |
" model.compile(optimizer=opt, loss=loss, metrics=[metric])" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "CgVAFq0dTKcf", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"train_data_size = input_meta_data['train_data_size']\n", | |
"steps_per_epoch = int(train_data_size / BATCH_SIZE)\n", | |
"eval_steps = int(math.ceil(input_meta_data['eval_data_size'] / EVAL_BATCH_SIZE))" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "prFX26vXTKci", | |
"colab_type": "code", | |
"colab": { | |
"base_uri": "https://localhost:8080/", | |
"height": 139 | |
}, | |
"outputId": "af3b582b-659e-4f22-8c09-a27d2f29608a" | |
}, | |
"source": [ | |
"# Train and evaluate using tf.keras.Model.fit()\n", | |
"# train_steps = train_examples // BATCH_SIZE\n", | |
"# valid_steps = valid_examples // EVAL_BATCH_SIZE\n", | |
"history = model.fit(\n", | |
" training_dataset,\n", | |
" epochs=EPOCHS,\n", | |
" steps_per_epoch=steps_per_epoch,\n", | |
" validation_data=evaluation_dataset,\n", | |
" validation_steps=eval_steps,\n", | |
")" | |
], | |
"execution_count": 13, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"text": [ | |
"Train for 28 steps, validate for 1 steps\n", | |
"Epoch 1/3\n", | |
"28/28 [==============================] - 77s 3s/step - loss: 0.6308 - accuracy: 0.6610 - val_loss: 0.5626 - val_accuracy: 0.6838\n", | |
"Epoch 2/3\n", | |
"28/28 [==============================] - 5s 173ms/step - loss: 0.5520 - accuracy: 0.6936 - val_loss: 0.4901 - val_accuracy: 0.7451\n", | |
"Epoch 3/3\n", | |
"28/28 [==============================] - 5s 171ms/step - loss: 0.4735 - accuracy: 0.7673 - val_loss: 0.4695 - val_accuracy: 0.7696\n" | |
], | |
"name": "stdout" | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "2xLtE34nTKcl", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"# Save TF2 model\n", | |
"os.makedirs(\"./save/\", exist_ok=True)\n", | |
"model.save_pretrained(\"./save/\")" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
}, | |
{ | |
"cell_type": "code", | |
"metadata": { | |
"id": "LVTbpe3KTKco", | |
"colab_type": "code", | |
"colab": {} | |
}, | |
"source": [ | |
"" | |
], | |
"execution_count": 0, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment