Skip to content

Instantly share code, notes, and snippets.

@ceshine
Last active February 23, 2020 18:50
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ceshine/1fe607938c4caef863dab056b0c0048b to your computer and use it in GitHub Desktop.
Save ceshine/1fe607938c4caef863dab056b0c0048b to your computer and use it in GitHub Desktop.
Train a huggingface/transformers BERT model on a Colab TPU with TF 2.1
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.5"
},
"colab": {
"name": "63ddc9198a1976512bef5e1b92610ea3",
"provenance": [],
"include_colab_link": true
},
"accelerator": "TPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/ceshine/1fe607938c4caef863dab056b0c0048b/run_tf_glue.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1m94kz61TKbh",
"colab_type": "text"
},
"source": [
"Adapted from [transformers/examples/run_tf_glue.py](https://github.com/huggingface/transformers/blob/master/examples/run_tf_glue.py)"
]
},
{
"cell_type": "code",
"metadata": {
"id": "rBTOd3i6TYb_",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 607
},
"outputId": "6c33758d-9a07-4149-ea71-950888ff6a33"
},
"source": [
"# Colab-only magic: select the TensorFlow 2.x runtime before TF is imported below.\n",
"%tensorflow_version 2.x\n",
"# Pin transformers to the version this notebook was run with.\n",
"!pip install transformers==2.3.0"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"text": [
"TensorFlow 2.x selected.\n",
"Collecting transformers==2.3.0\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/50/10/aeefced99c8a59d828a92cc11d213e2743212d3641c87c82d61b035a7d5c/transformers-2.3.0-py3-none-any.whl (447kB)\n",
"\u001b[K |████████████████████████████████| 450kB 3.5MB/s \n",
"\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0) (1.10.47)\n",
"Collecting sentencepiece\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/74/f4/2d5214cbf13d06e7cb2c20d84115ca25b53ea76fa1f0ade0e3c9749de214/sentencepiece-0.1.85-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
"\u001b[K |████████████████████████████████| 1.0MB 41.0MB/s \n",
"\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0) (2019.12.20)\n",
"Requirement already satisfied: requests in /tensorflow-2.1.0/python3.6 (from transformers==2.3.0) (2.22.0)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from transformers==2.3.0) (4.28.1)\n",
"Requirement already satisfied: numpy in /tensorflow-2.1.0/python3.6 (from transformers==2.3.0) (1.18.1)\n",
"Collecting sacremoses\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/a6/b4/7a41d630547a4afd58143597d5a49e07bfd4c42914d8335b2a5657efc14b/sacremoses-0.0.38.tar.gz (860kB)\n",
"\u001b[K |████████████████████████████████| 870kB 23.5MB/s \n",
"\u001b[?25hRequirement already satisfied: botocore<1.14.0,>=1.13.47 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0) (1.13.47)\n",
"Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0) (0.2.1)\n",
"Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers==2.3.0) (0.9.4)\n",
"Requirement already satisfied: idna<2.9,>=2.5 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (2.8)\n",
"Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (1.25.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (2019.11.28)\n",
"Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /tensorflow-2.1.0/python3.6 (from requests->transformers==2.3.0) (3.0.4)\n",
"Requirement already satisfied: six in /tensorflow-2.1.0/python3.6 (from sacremoses->transformers==2.3.0) (1.13.0)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.3.0) (7.0)\n",
"Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers==2.3.0) (0.14.1)\n",
"Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->transformers==2.3.0) (0.15.2)\n",
"Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->transformers==2.3.0) (2.6.1)\n",
"Building wheels for collected packages: sacremoses\n",
" Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for sacremoses: filename=sacremoses-0.0.38-cp36-none-any.whl size=884629 sha256=c44e1b6a02cb9b4c392345d16020e59b7dcabd88422ff5e32b2ca97e15b59ba5\n",
" Stored in directory: /root/.cache/pip/wheels/6d/ec/1a/21b8912e35e02741306f35f66c785f3afe94de754a0eaf1422\n",
"Successfully built sacremoses\n",
"Installing collected packages: sentencepiece, sacremoses, transformers\n",
"Successfully installed sacremoses-0.0.38 sentencepiece-0.1.85 transformers-2.3.0\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "g5gigYg_TKbl",
"colab_type": "code",
"colab": {}
},
"source": [
"import os\n",
"import json\n",
"import math\n",
"\n",
"import tensorflow as tf\n",
"import tensorflow_datasets\n",
"\n",
"from transformers import (\n",
" BertConfig,\n",
" BertTokenizer,\n",
" TFBertForSequenceClassification,\n",
" glue_convert_examples_to_features,\n",
" glue_processors\n",
")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "xN3SYuc9TKbv",
"colab_type": "text"
},
"source": [
"## Set up the TPU"
]
},
{
"cell_type": "code",
"metadata": {
"id": "OEbiVyCdTKb2",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 764
},
"outputId": "1e35c687-22cc-44da-8dd9-39d582a19aba"
},
"source": [
"# A ValueError from the resolver is treated as: no TPU attached. In that case\n",
"# the default (single-device) distribution strategy is used.\n",
"try:\n",
"    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection\n",
"    print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])\n",
"except ValueError:\n",
"    tpu = None\n",
"\n",
"if tpu:\n",
"    # Connect to the TPU worker and initialise it before building the strategy.\n",
"    tf.config.experimental_connect_to_cluster(tpu)\n",
"    tf.tpu.experimental.initialize_tpu_system(tpu)\n",
"    strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
"else:\n",
"    strategy = tf.distribute.get_strategy()\n",
"print(\"REPLICAS: \", strategy.num_replicas_in_sync)"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"text": [
"INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"Running on TPU ['10.27.197.186:8470']\n",
"INFO:tensorflow:Initializing the TPU system: 10.27.197.186:8470\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Initializing the TPU system: 10.27.197.186:8470\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Clearing out eager caches\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Clearing out eager caches\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Finished initializing TPU system.\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Finished initializing TPU system.\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Found TPU system:\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:Found TPU system:\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Num TPU Cores: 8\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Num TPU Cores: 8\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Num TPU Workers: 1\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Num TPU Workers: 1\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
],
"name": "stdout"
},
{
"output_type": "stream",
"text": [
"INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
],
"name": "stderr"
},
{
"output_type": "stream",
"text": [
"REPLICAS: 8\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "43e1NIDDTKb8",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 34
},
"outputId": "96d472ad-9e32-4acf-eba7-8e3f08a424a8"
},
"source": [
"# Batch sizes are global, i.e. summed over all replicas of the strategy.\n",
"EPOCHS = 3\n",
"if strategy.num_replicas_in_sync == 8:  # single TPU\n",
"    BATCH_SIZE = 16 * strategy.num_replicas_in_sync\n",
"    EVAL_BATCH_SIZE = 4 * BATCH_SIZE\n",
"else:\n",
"    BATCH_SIZE = 32\n",
"    EVAL_BATCH_SIZE = 2 * BATCH_SIZE\n",
"\n",
"TASK = \"mrpc\"\n",
"# Some GLUE task names are spelled without the hyphen on the TFDS side.\n",
"# NOTE(review): TFDS_TASK is not referenced by the TFRecord pipeline in this notebook.\n",
"TFDS_TASK = {\"sst-2\": \"sst2\", \"sts-b\": \"stsb\"}.get(TASK, TASK)\n",
"\n",
"# The transformers GLUE processor for the task defines the label set.\n",
"num_labels = len(glue_processors[TASK]().get_labels())\n",
"print(num_labels)"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"text": [
"2\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fObqeVjBTKcB",
"colab_type": "text"
},
"source": [
"## Dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BwJmW3RkTKcB",
"colab_type": "text"
},
"source": [
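"The input pipeline below is adapted from [tensorflow/models/official/nlp/bert](https://github.com/tensorflow/models/blob/master/official/nlp/bert/run_classifier.py). It reads GLUE data that has already been tokenized and serialized to TFRecord files, so no text tokenization happens in this notebook."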
]
},
{
"cell_type": "code",
"metadata": {
"id": "5nfNw3uzTKcE",
"colab_type": "code",
"colab": {}
},
"source": [
"# GCS prefix holding the GLUE TFRecord files and the *_meta_data JSON read below.\n",
"GLUE_DIR = \"gs://cloud-tpu-checkpoints/bert/classification\""
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "uO5kQ0rlTKcI",
"colab_type": "code",
"colab": {}
},
"source": [
"def single_file_dataset(input_file, name_to_features):\n",
"    \"\"\"Reads TFRecord file(s) and decodes every record with `decode_record`.\n",
"\n",
"    Args:\n",
"        input_file: a TFRecord path, or a list of paths.\n",
"        name_to_features: feature spec forwarded to `decode_record`.\n",
"\n",
"    Returns:\n",
"        A `tf.data.Dataset` of decoded feature dicts. No shuffling is done\n",
"        here. When exactly one file is given, tf.data auto-sharding is turned\n",
"        off so that the same input file is sent to all workers.\n",
"    \"\"\"\n",
"    records = tf.data.TFRecordDataset(input_file)\n",
"    dataset = records.map(\n",
"        lambda serialized: decode_record(serialized, name_to_features))\n",
"\n",
"    single_file = isinstance(input_file, str) or len(input_file) == 1\n",
"    if not single_file:\n",
"        return dataset\n",
"\n",
"    options = tf.data.Options()\n",
"    options.experimental_distribute.auto_shard_policy = (\n",
"        tf.data.experimental.AutoShardPolicy.OFF)\n",
"    return dataset.with_options(options)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "kOb5O4rETKcL",
"colab_type": "code",
"colab": {}
},
"source": [
"def decode_record(record, name_to_features):\n",
"    \"\"\"Parses one serialized tf.Example and narrows int64 features to int32.\n",
"\n",
"    tf.Example stores integers only as tf.int64, but the TPU only supports\n",
"    tf.int32, so every int64 tensor in the parsed example is cast down. All\n",
"    other tensors are returned unchanged.\n",
"    \"\"\"\n",
"    example = tf.io.parse_single_example(record, name_to_features)\n",
"    for key, tensor in list(example.items()):\n",
"        if tensor.dtype == tf.int64:\n",
"            example[key] = tf.cast(tensor, tf.int32)\n",
"    return example"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "f4-FAqH_TKcQ",
"colab_type": "code",
"colab": {}
},
"source": [
"def create_classifier_dataset(file_path,\n",
"                              seq_length,\n",
"                              batch_size,\n",
"                              is_training=True,\n",
"                              input_pipeline_context=None):\n",
"    \"\"\"Builds a batched `(features, label)` dataset from classifier TFRecords.\n",
"\n",
"    Args:\n",
"        file_path: TFRecord path (or list of paths) handed to `single_file_dataset`.\n",
"        seq_length: fixed length of the `input_ids`, `input_mask` and\n",
"            `segment_ids` features stored in each record.\n",
"        batch_size: number of examples per batch.\n",
"        is_training: if True, shuffle with a buffer of 100, repeat forever and\n",
"            drop the final partial batch; if False, make one ordered pass and\n",
"            keep the partial batch.\n",
"        input_pipeline_context: optional `tf.distribute.InputContext`; when it\n",
"            reports more than one input pipeline the dataset is sharded by it.\n",
"\n",
"    Returns:\n",
"        A `tf.data.Dataset` of `(x, y)` pairs, where `x` is a dict with keys\n",
"        `input_ids`, `attention_mask` and `token_type_ids`, and `y` is the\n",
"        `label_ids` value.\n",
"    \"\"\"\n",
"    int_sequence = tf.io.FixedLenFeature([seq_length], tf.int64)\n",
"    int_scalar = tf.io.FixedLenFeature([], tf.int64)\n",
"    name_to_features = {\n",
"        'input_ids': int_sequence,\n",
"        'input_mask': int_sequence,\n",
"        'segment_ids': int_sequence,\n",
"        'label_ids': int_scalar,\n",
"        # Parsed to match the record schema but not forwarded to the model.\n",
"        'is_real_example': int_scalar,\n",
"    }\n",
"    dataset = single_file_dataset(file_path, name_to_features)\n",
"\n",
"    # Sharding is per host: num_input_pipelines counts hosts, not TPU cores.\n",
"    ctx = input_pipeline_context\n",
"    if ctx and ctx.num_input_pipelines > 1:\n",
"        dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id)\n",
"\n",
"    def _to_model_inputs(record):\n",
"        # Rename the BERT TFRecord feature names to the input names fed to\n",
"        # the transformers model.\n",
"        features = {\n",
"            'input_ids': record['input_ids'],\n",
"            'attention_mask': record['input_mask'],\n",
"            'token_type_ids': record['segment_ids'],\n",
"        }\n",
"        return features, record['label_ids']\n",
"\n",
"    dataset = dataset.map(_to_model_inputs)\n",
"\n",
"    if is_training:\n",
"        dataset = dataset.shuffle(100).repeat()\n",
"\n",
"    return (dataset\n",
"            .batch(batch_size, drop_remainder=is_training)\n",
"            .prefetch(1024))"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2TuCxgcyTKcT",
"colab_type": "code",
"colab": {}
},
"source": [
"def get_dataset_fn(input_file_pattern, max_seq_length, global_batch_size,\n",
"                   is_training):\n",
"    \"\"\"Returns a `dataset_fn(ctx=None)` closure over the classifier pipeline.\n",
"\n",
"    Called with no argument, the closure batches by `global_batch_size`. Given a\n",
"    `tf.distribute.InputContext`, it batches by the per-replica size returned\n",
"    from `ctx.get_per_replica_batch_size(global_batch_size)` and passes the\n",
"    context on for sharding.\n",
"    \"\"\"\n",
"\n",
"    def _dataset_fn(ctx=None):\n",
"        \"\"\"Builds the GLUE classifier `tf.data.Dataset` for fine-tuning.\"\"\"\n",
"        if ctx:\n",
"            batch_size = ctx.get_per_replica_batch_size(global_batch_size)\n",
"        else:\n",
"            batch_size = global_batch_size\n",
"        return create_classifier_dataset(\n",
"            input_file_pattern,\n",
"            max_seq_length,\n",
"            batch_size,\n",
"            is_training=is_training,\n",
"            input_pipeline_context=ctx)\n",
"\n",
"    return _dataset_fn"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "AC6wB97dTKcX",
"colab_type": "code",
"colab": {}
},
"source": [
"# Dataset metadata (max_seq_length, train/eval sizes) is a JSON file under GLUE_DIR.\n",
"with tf.io.gfile.GFile(f'{GLUE_DIR}/{TASK}_meta_data', 'rb') as reader:\n",
"    input_meta_data = json.loads(reader.read().decode('utf-8'))\n",
"max_seq_length = input_meta_data['max_seq_length']\n",
"\n",
"# The training closure shuffles and repeats; the evaluation closure makes one ordered pass.\n",
"train_input_fn = get_dataset_fn(\n",
"    input_file_pattern=f\"{GLUE_DIR}/{TASK}_train.tf_record\",\n",
"    max_seq_length=max_seq_length,\n",
"    global_batch_size=BATCH_SIZE,\n",
"    is_training=True)\n",
"eval_input_fn = get_dataset_fn(\n",
"    input_file_pattern=f\"{GLUE_DIR}/{TASK}_eval.tf_record\",\n",
"    max_seq_length=max_seq_length,\n",
"    global_batch_size=EVAL_BATCH_SIZE,\n",
"    is_training=False)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "XAdN8D8TTKcc",
"colab_type": "code",
"colab": {}
},
"source": [
"# Load the config, tokenizer and pretrained weights. num_labels picks the head:\n",
"# 2+ labels -> classification, 1 label -> regression.\n",
"config = BertConfig.from_pretrained(\"bert-base-cased\", num_labels=num_labels)\n",
"# NOTE(review): the tokenizer is not used by the TFRecord pipeline in this notebook.\n",
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")\n",
"with strategy.scope():\n",
"    training_dataset = train_input_fn()\n",
"    evaluation_dataset = eval_input_fn()\n",
"\n",
"    model = TFBertForSequenceClassification.from_pretrained(\"bert-base-cased\", config=config)\n",
"    # Adam with a constant learning rate (no schedule is applied).\n",
"    opt = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08)\n",
"\n",
"    # Pick a loss AND a metric that match the head: classification accuracy is\n",
"    # meaningless for a single-output regression head.\n",
"    if num_labels == 1:\n",
"        loss = tf.keras.losses.MeanSquaredError()\n",
"        metric = tf.keras.metrics.MeanSquaredError(name=\"mse\")\n",
"        # NOTE(review): create_classifier_dataset parses label_ids as int64, so a\n",
"        # regression task would also need a float label spec there.\n",
"    else:\n",
"        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"        metric = tf.keras.metrics.SparseCategoricalAccuracy(\"accuracy\")\n",
"\n",
"    model.compile(optimizer=opt, loss=loss, metrics=[metric])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "CgVAFq0dTKcf",
"colab_type": "code",
"colab": {}
},
"source": [
"# Training drops the last partial batch, so an epoch is the number of full\n",
"# batches; evaluation keeps the partial batch, so round up.\n",
"train_data_size = input_meta_data['train_data_size']\n",
"steps_per_epoch = train_data_size // BATCH_SIZE\n",
"eval_steps = math.ceil(input_meta_data['eval_data_size'] / EVAL_BATCH_SIZE)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "prFX26vXTKci",
"colab_type": "code",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 139
},
"outputId": "af3b582b-659e-4f22-8c09-a27d2f29608a"
},
"source": [
"# Train and evaluate using tf.keras.Model.fit(). The training dataset repeats\n",
"# indefinitely, so steps_per_epoch defines the length of an epoch;\n",
"# validation_steps is sized to cover the whole (non-repeating) evaluation set.\n",
"history = model.fit(\n",
"    training_dataset,\n",
"    epochs=EPOCHS,\n",
"    steps_per_epoch=steps_per_epoch,\n",
"    validation_data=evaluation_dataset,\n",
"    validation_steps=eval_steps,\n",
")"
],
"execution_count": 13,
"outputs": [
{
"output_type": "stream",
"text": [
"Train for 28 steps, validate for 1 steps\n",
"Epoch 1/3\n",
"28/28 [==============================] - 77s 3s/step - loss: 0.6308 - accuracy: 0.6610 - val_loss: 0.5626 - val_accuracy: 0.6838\n",
"Epoch 2/3\n",
"28/28 [==============================] - 5s 173ms/step - loss: 0.5520 - accuracy: 0.6936 - val_loss: 0.4901 - val_accuracy: 0.7451\n",
"Epoch 3/3\n",
"28/28 [==============================] - 5s 171ms/step - loss: 0.4735 - accuracy: 0.7673 - val_loss: 0.4695 - val_accuracy: 0.7696\n"
],
"name": "stdout"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "2xLtE34nTKcl",
"colab_type": "code",
"colab": {}
},
"source": [
"# Save the fine-tuned TF2 model via transformers' save_pretrained. The target\n",
"# directory is created first (exist_ok=True makes this safe to re-run).\n",
"os.makedirs(\"./save/\", exist_ok=True)\n",
"model.save_pretrained(\"./save/\")"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LVTbpe3KTKco",
"colab_type": "code",
"colab": {}
},
"source": [
""
],
"execution_count": 0,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment