{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/jayant-yadav/af5d4f3106499868a76f7fd0aa169844/fine-tune-bert-for-text-classification-with-tensorflow.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zGCJYkQj_Uu2"
},
"source": [
"<h2 align=center> Fine-Tune BERT for Text Classification with TensorFlow</h2>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4y2m1S6e12il"
},
"source": [
"<div align=\"center\">\n",
" <img width=\"512px\" src='https://drive.google.com/uc?id=1fnJTeJs5HUpz7nix-F9E6EZdgUflqyEu' />\n",
" <p style=\"text-align: center;color:gray\">Figure 1: BERT Classification Model</p>\n",
"</div>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eYYYWqWr_WCC"
},
"source": [
"In this [project](https://www.coursera.org/projects/fine-tune-bert-tensorflow/), you will learn how to fine-tune a BERT model for text classification using TensorFlow and TF-Hub."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5yQG5PCO_WFx"
},
"source": [
"The pretrained BERT model used in this project is [available](https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2) on [TensorFlow Hub](https://tfhub.dev/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7pKNS21u_WJo"
},
"source": [
"### Learning Objectives"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_3NHSMXv_WMv"
},
"source": [
"By the time you complete this project, you will be able to:\n",
"\n",
"- Build TensorFlow Input Pipelines for Text Data with the [`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) API\n",
"- Tokenize and Preprocess Text for BERT\n",
"- Fine-tune BERT for text classification with TensorFlow 2 and [TF Hub](https://tfhub.dev)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o6BEe-3-AVRQ"
},
"source": [
"### Prerequisites"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Sc9f-8rLAVUS"
},
"source": [
"In order to be successful with this project, it is assumed you are:\n",
"\n",
"- Competent in the Python programming language\n",
"- Familiar with deep learning for Natural Language Processing (NLP)\n",
"- Familiar with TensorFlow, and its Keras API"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "MYXXV5n3Ab-4"
},
"source": [
"### Contents"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XhK-SYGyAjxe"
},
"source": [
"This project/notebook consists of several Tasks.\n",
"\n",
"- **[Task 1]()**: Introduction to the Project.\n",
"- **[Task 2]()**: Setup your TensorFlow and Colab Runtime\n",
"- **[Task 3]()**: Download and Import the Quora Insincere Questions Dataset\n",
"- **[Task 4]()**: Create tf.data.Datasets for Training and Evaluation\n",
"- **[Task 5]()**: Download a Pre-trained BERT Model from TensorFlow Hub\n",
"- **[Task 6]()**: Tokenize and Preprocess Text for BERT\n",
"- **[Task 7]()**: Wrap a Python Function into a TensorFlow op for Eager Execution\n",
"- **[Task 8]()**: Create a TensorFlow Input Pipeline with `tf.data`\n",
"- **[Task 9]()**: Add a Classification Head to the BERT `hub.KerasLayer`\n",
"- **[Task 10]()**: Fine-Tune BERT for Text Classification\n",
"- **[Task 11]()**: Evaluate the BERT Text Classification Model"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IaArqXjRAcBa"
},
"source": [
"## Task 2: Setup your TensorFlow and Colab Runtime."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GDDhjzZ5A4Q_"
},
"source": [
"You will only be able to use the Colab Notebook after you save it to your Google Drive folder. Click on the File menu and select “Save a copy in Drive…\n",
"\n",
"![Copy to Drive](https://drive.google.com/uc?id=1CH3eDmuJL8WR0AP1r3UE6sOPuqq8_Wl7)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mpe6GhLuBJWB"
},
"source": [
"### Check GPU Availability\n",
"\n",
"Check if your Colab notebook is configured to use Graphical Processing Units (GPUs). If zero GPUs are available, check if the Colab notebook is configured to use GPUs (Menu > Runtime > Change Runtime Type).\n",
"\n",
"![Hardware Accelerator Settings](https://drive.google.com/uc?id=1qrihuuMtvzXJHiRV8M7RngbxFYipXKQx)\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "8V9c8vzSL3aj",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e6559f4d-b833-46f5-f65c-56112e639f1e"
},
"source": [
"!nvidia-smi"
],
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Thu Aug 31 05:58:11 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 35C P8 9W / 70W | 0MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"| No running processes found |\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Obch3rAuBVf0"
},
"source": [
"### Install TensorFlow and TensorFlow Model Garden"
]
},
{
"cell_type": "code",
"metadata": {
"id": "bUQEY3dFB0jX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "bcf5f40c-1741-493c-fa69-c01fbe5abaca"
},
"source": [
"import tensorflow as tf\n",
"print(tf.version.VERSION)"
],
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.12.0\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "aU3YLZ1TYKUt",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4097c040-2b92-4f76-ff86-9227a10df653"
},
"source": [
"!pip install -q tensorflow==2.3.0"
],
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[31mERROR: Could not find a version that satisfies the requirement tensorflow==2.3.0 (from versions: 2.8.0rc0, 2.8.0rc1, 2.8.0, 2.8.1, 2.8.2, 2.8.3, 2.8.4, 2.9.0rc0, 2.9.0rc1, 2.9.0rc2, 2.9.0, 2.9.1, 2.9.2, 2.9.3, 2.10.0rc0, 2.10.0rc1, 2.10.0rc2, 2.10.0rc3, 2.10.0, 2.10.1, 2.11.0rc0, 2.11.0rc1, 2.11.0rc2, 2.11.0, 2.11.1, 2.12.0rc0, 2.12.0rc1, 2.12.0, 2.12.1, 2.13.0rc0, 2.13.0rc1, 2.13.0rc2, 2.13.0, 2.14.0rc0)\u001b[0m\u001b[31m\n",
"\u001b[0m\u001b[31mERROR: No matching distribution found for tensorflow==2.3.0\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "AFRTC-zwUy6D",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "dff990a1-35c6-4e21-dd77-1ded9f9366de"
},
"source": [
"!git clone --depth 1 -b v2.3.0 https://github.com/tensorflow/models.git"
],
"execution_count": 4,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"fatal: destination path 'models' already exists and is not an empty directory.\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "3H2G0571zLLs",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "416bd7c1-d0ac-4059-e9df-ebd8b63eca07"
},
"source": [
"# install requirements to use tensorflow/models repository\n",
"!pip install -Uqr models/official/requirements.txt\n",
"# you may have to restart the runtime afterwards"
],
"execution_count": 5,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GVjksk4yCXur"
},
"source": [
"## Restart the Runtime\n",
"\n",
"**Note**\n",
"After installing the required Python packages, you'll need to restart the Colab Runtime Engine (Menu > Runtime > Restart runtime...)\n",
"\n",
"![Restart of the Colab Runtime Engine](https://drive.google.com/uc?id=1xnjAy2sxIymKhydkqb0RKzgVK9rh3teH)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "IMsEoT3Fg4Wg"
},
"source": [
"## Task 3: Download and Import the Quora Insincere Questions Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "GmqEylyFYTdP",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7059df3f-a28b-40b2-cef7-c50852c201ca"
},
"source": [
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import sys\n",
"sys.path.append('models')\n",
"from official.nlp.data import classifier_data_lib\n",
"from official.nlp.bert import tokenization\n",
"from official.nlp import optimization"
],
"execution_count": 6,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/tensorflow_addons/utils/tfa_eol_msg.py:23: UserWarning: \n",
"\n",
"TensorFlow Addons (TFA) has ended development and introduction of new features.\n",
"TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.\n",
"Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). \n",
"\n",
"For more information see: https://github.com/tensorflow/addons/issues/2807 \n",
"\n",
" warnings.warn(\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "ZuX1lB8pPJ-W",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b81869b8-3bcb-4fcb-d571-89c1d488ec53"
},
"source": [
"print(\"TF Version: \", tf.__version__)\n",
"print(\"Eager mode: \", tf.executing_eagerly())\n",
"print(\"Hub version: \", hub.__version__)\n",
"print(\"GPU is\", \"available\" if tf.config.experimental.list_physical_devices(\"GPU\") else \"NOT AVAILABLE\")"
],
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"TF Version: 2.12.0\n",
"Eager mode: True\n",
"Hub version: 0.14.0\n",
"GPU is available\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QtbwpWgyEZg7"
},
"source": [
"A downloadable copy of the [Quora Insincere Questions Classification data](https://www.kaggle.com/c/quora-insincere-questions-classification/data) can be found [https://archive.org/download/fine-tune-bert-tensorflow-train.csv/train.csv.zip](https://archive.org/download/fine-tune-bert-tensorflow-train.csv/train.csv.zip). Decompress and read the data into a pandas DataFrame."
]
},
{
"cell_type": "code",
"metadata": {
"id": "0nI-9itVwCCQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4c87512e-50f8-4948-c6f2-5c582d2f79dd"
},
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"df = pd.read_csv('https://archive.org/download/fine-tune-bert-tensorflow-train.csv/train.csv.zip',compression='zip', low_memory=False)\n",
"df.shape"
],
"execution_count": 8,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"(1306122, 3)"
]
},
"metadata": {},
"execution_count": 8
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "yeHE98KiMvDd",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 363
},
"outputId": "0884be22-9281-4728-93ac-723f3b541374"
},
"source": [
"df.tail(10)"
],
"execution_count": 9,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
" qid \\\n",
"1306112 ffffa5b0fa76431c063f \n",
"1306113 ffffae5dbda3dc9e9771 \n",
"1306114 ffffba7c4888798571c1 \n",
"1306115 ffffc0c7158658a06fd9 \n",
"1306116 ffffc404da586ac5a08f \n",
"1306117 ffffcc4e2331aaf1e41e \n",
"1306118 ffffd431801e5a2f4861 \n",
"1306119 ffffd48fb36b63db010c \n",
"1306120 ffffec519fa37cf60c78 \n",
"1306121 ffffed09fedb5088744a \n",
"\n",
" question_text target \n",
"1306112 Are you ashamed of being an Indian? 1 \n",
"1306113 What are the methods to determine fossil ages ... 0 \n",
"1306114 What is your story today? 0 \n",
"1306115 How do I consume 150 gms protein daily both ve... 0 \n",
"1306116 What are the good career options for a msc che... 0 \n",
"1306117 What other technical skills do you need as a c... 0 \n",
"1306118 Does MS in ECE have good job prospects in USA ... 0 \n",
"1306119 Is foam insulation toxic? 0 \n",
"1306120 How can one start a research project based on ... 0 \n",
"1306121 Who wins in a battle between a Wolverine and a... 0 "
],
"text/html": [
"\n",
" <div id=\"df-86e00451-900d-46e8-8d76-7e4bf7293772\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>qid</th>\n",
" <th>question_text</th>\n",
" <th>target</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1306112</th>\n",
" <td>ffffa5b0fa76431c063f</td>\n",
" <td>Are you ashamed of being an Indian?</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306113</th>\n",
" <td>ffffae5dbda3dc9e9771</td>\n",
" <td>What are the methods to determine fossil ages ...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306114</th>\n",
" <td>ffffba7c4888798571c1</td>\n",
" <td>What is your story today?</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306115</th>\n",
" <td>ffffc0c7158658a06fd9</td>\n",
" <td>How do I consume 150 gms protein daily both ve...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306116</th>\n",
" <td>ffffc404da586ac5a08f</td>\n",
" <td>What are the good career options for a msc che...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306117</th>\n",
" <td>ffffcc4e2331aaf1e41e</td>\n",
" <td>What other technical skills do you need as a c...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306118</th>\n",
" <td>ffffd431801e5a2f4861</td>\n",
" <td>Does MS in ECE have good job prospects in USA ...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306119</th>\n",
" <td>ffffd48fb36b63db010c</td>\n",
" <td>Is foam insulation toxic?</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306120</th>\n",
" <td>ffffec519fa37cf60c78</td>\n",
" <td>How can one start a research project based on ...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1306121</th>\n",
" <td>ffffed09fedb5088744a</td>\n",
" <td>Who wins in a battle between a Wolverine and a...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-86e00451-900d-46e8-8d76-7e4bf7293772')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-86e00451-900d-46e8-8d76-7e4bf7293772 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-86e00451-900d-46e8-8d76-7e4bf7293772');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-dc05f8da-6a68-4281-bdfe-accb9daeed14\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-dc05f8da-6a68-4281-bdfe-accb9daeed14')\"\n",
" title=\"Suggest charts.\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-dc05f8da-6a68-4281-bdfe-accb9daeed14 button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
" </div>\n",
" </div>\n"
]
},
"metadata": {},
"execution_count": 9
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "leRFRWJMocVa",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 470
},
"outputId": "148ad642-2c90-4ff8-930a-5ec10e485b94"
},
"source": [
"df.target.plot(kind='hist', title='target distribution')\n",
"#since this is highly skewed, we will make stratified splits."
],
"execution_count": 10,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<Axes: title={'center': 'target distribution'}, ylabel='Frequency'>"
]
},
"metadata": {},
"execution_count": 10
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ELjswHcFHfp3"
},
"source": [
"## Task 4: Create tf.data.Datasets for Training and Evaluation"
]
},
{
"cell_type": "code",
"metadata": {
"id": "fScULIGPwuWk",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "eeca6444-6e19-4105-9643-4aab252bd2be"
},
"source": [
"train_df, remaining_df = train_test_split(df, random_state=42,train_size = 0.0075, stratify=df['target'].values)\n",
"valid_df, _ = train_test_split(remaining_df, random_state=42, train_size = 0.00075, stratify=remaining_df['target'].values)\n",
"train_df.shape, valid_df.shape\n",
"# BERT takes time to train and perform inference, because of the following reasons:\n",
"# 1. 340M params.\n",
"# 2. I/O bottleneck\n",
"# using tf.data to prepare input pipelines is beneficial to reduce I/O bottlenecks."
],
"execution_count": 11,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"((9795, 3), (972, 3))"
]
},
"metadata": {},
"execution_count": 11
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "qQYMGT5_qLPX",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ef3840fb-dc30-4051-b3f1-383ba56fb1ef"
},
"source": [
"# we transfer data to cpu using tf.data.Dataset, which gives a py iter.\n",
"with tf.device('/cpu:0'): # was it in GPU by default?\n",
" train_data = tf.data.Dataset.from_tensor_slices((train_df['question_text'].values, train_df['target'].values))\n",
" valid_data = tf.data.Dataset.from_tensor_slices((valid_df['question_text'].values, valid_df['target'].values))\n",
"\n",
" for text, label in train_data.take(1):\n",
" print(text, label)"
],
"execution_count": 12,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tf.Tensor(b'Why are unhealthy relationships so desirable?', shape=(), dtype=string) tf.Tensor(0, shape=(), dtype=int64)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e2-ReN88Hvy_"
},
"source": [
"## Task 5: Download a Pre-trained BERT Model from TensorFlow Hub"
]
},
{
"cell_type": "code",
"metadata": {
"id": "EMb5M86b4-BU"
},
"source": [
"\"\"\"\n",
"Each line of the dataset is composed of the review text and its label\n",
"- Data preprocessing consists of transforming text to BERT input features:\n",
"input_word_ids, input_mask, segment_ids\n",
"- In the process, tokenizing the text is done with the provided BERT model tokenizer\n",
"\"\"\"\n",
"\n",
"label_list = [0,1] # Label categories\n",
"max_seq_length = 128 # maximum length of (token) input sequences\n",
"batch_size = 32\n",
"\n",
"\n",
"# Get BERT layer and tokenizer:\n",
"# More details here: https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2\n",
"encoder = hub.KerasLayer(\n",
" \"https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/4\",\n",
" trainable=True) # bert_layer in v2\n",
"\n",
"vocab_file = encoder.resolved_object.vocab_file.asset_path.numpy()\n",
"do_lower_case = encoder.resolved_object.do_lower_case.numpy() # since model is uncased\n",
"tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)"
],
"execution_count": 13,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "wEUezMK-zkkI",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a8df2bc5-a78c-486f-8607-92baecca157d"
},
"source": [
"tokenizer.wordpiece_tokenizer.tokenize('hi, how are you doing?')"
],
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['hi', '##,', 'how', 'are', 'you', 'doing', '##?']"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "5AFsmTO5JSmc",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d92d1e7b-a63e-4480-a184-04209a7d2736"
},
"source": [
"tokenizer.convert_tokens_to_ids(tokenizer.wordpiece_tokenizer.tokenize('hi, how are you doing?'))"
],
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"[7632, 29623, 2129, 2024, 2017, 2725, 29632]"
]
},
"metadata": {},
"execution_count": 15
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9QinzNq6OsP1"
},
"source": [
"## Task 6: Tokenize and Preprocess Text for BERT"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3FTqJ698zZ1e"
},
"source": [
"<div align=\"center\">\n",
"\n",
"---\n",
"\n",
"\n",
" <img width=\"512px\" src='https://drive.google.com/uc?id=1-SpKFELnEvBMBqO7h3iypo8q9uUUo96P' />\n",
" <p style=\"text-align: center;color:gray\">Figure 2: BERT Tokenizer</p>\n",
"</div>"
]
},
{
"cell_type": "markdown",
"source": [
"**Token ids:** If token exists in vocab, token id is returned from it. else new token id is generated. We also append padding tokens if text is smaller than 128 and also add SEP and CLS and their corresponding tokenids. \n",
"**Input Mask:** To attend the main text and mask out the padding tokens. \n",
"**Input type ids:** In Next Sentence Prediction (NSP) task, BERT will give 0 to first sentence and 1 to second sentence and perform NSP. But since we are only dealing with 1 sentence at a time and do not have 2nd sentence, we will be providing all the values as 0s."
],
"metadata": {
"id": "xv07eXTe6I8o"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "cWYkggYe6HZc"
},
"source": [
"We'll need to transform our data into a format BERT understands. This involves two steps. First, we create InputExamples using `classifier_data_lib`'s constructor `InputExample` provided in the BERT library."
]
},
{
"cell_type": "code",
"metadata": {
"id": "m-21A5aNJM0W"
},
"source": [
"# This provides a function to convert row to input features and label\n",
"\n",
"def to_feature(text, label, label_list=label_list, max_seq_length=max_seq_length, tokenizer=tokenizer):\n",
" example = classifier_data_lib.InputExample(guid = None, # since we do not have 2 sentences in NSP\n",
" text_a = text.numpy(),\n",
" text_b = None, # since we do not have 2nd sentence in NSP\n",
" label = label.numpy()\n",
" )\n",
" feature = classifier_data_lib.convert_single_example(0,example,label_list, max_seq_length, tokenizer)\n",
"\n",
" return (feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id)\n"
],
"execution_count": 16,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "A_HQSsHwWCsK"
},
"source": [
"You want to use [`Dataset.map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map) to apply this function to each element of the dataset. [`Dataset.map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map) runs in graph mode.\n",
"\n",
"- Graph tensors do not have a value.\n",
"- In graph mode you can only use TensorFlow Ops and functions.\n",
"\n",
"So you can't `.map` this function directly: You need to wrap it in a [`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function). The [`tf.py_function`](https://www.tensorflow.org/api_docs/python/tf/py_function) will pass regular tensors (with a value and a `.numpy()` method to access it), to the wrapped python function."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zaNlkKVfWX0Q"
},
"source": [
"## Task 7: Wrap a Python Function into a TensorFlow op for Eager Execution"
]
},
{
"cell_type": "code",
"metadata": {
"id": "AGACBcfCWC2O"
},
"source": [
"def to_feature_map(text, label):\n",
" input_ids, input_mask, segment_ids, label_id = tf.py_function(to_feature, inp = [text, label],\n",
" Tout = [tf.int32, tf.int32, tf.int32, tf.int32])\n",
" # py_function does not shape it, so we do it explicitly\n",
" input_ids.set_shape([max_seq_length])\n",
" input_mask.set_shape([max_seq_length])\n",
" segment_ids.set_shape([max_seq_length])\n",
" label_id.set_shape([]) # no need to shape it\n",
"\n",
" x = {\n",
" 'input_word_ids': input_ids,\n",
" 'input_mask': input_mask,\n",
" 'input_type_ids': segment_ids\n",
" }\n",
"\n",
" return (x,label_id)\n"
],
"execution_count": 17,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "dhdO6MjTbtn1"
},
"source": [
"## Task 8: Create a TensorFlow Input Pipeline with `tf.data`"
]
},
{
"cell_type": "code",
"metadata": {
"id": "LHRdiO3dnPNr"
},
"source": [
"with tf.device('/cpu:0'):\n",
" # train\n",
" train_data = (train_data.map(to_feature_map,\n",
" num_parallel_calls = tf.data.experimental.AUTOTUNE) # to run parallely and letting tf decide\n",
" .shuffle(1000)\n",
" .batch(32, drop_remainder = True)\n",
" .prefetch(tf.data.experimental.AUTOTUNE)) # for fetching next training data while it is training prev. Data\n",
"\n",
"\n",
"\n",
" # valid\n",
"valid_data = (valid_data.map(to_feature_map,\n",
" num_parallel_calls = tf.data.experimental.AUTOTUNE)\n",
" .batch(32, drop_remainder = True)\n",
" .prefetch(tf.data.experimental.AUTOTUNE))\n"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "KLUWnfx-YDi2"
},
"source": [
"The resulting `tf.data.Datasets` return `(features, labels)` pairs, as expected by [`keras.Model.fit`](https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit):"
]
},
{
"cell_type": "code",
"metadata": {
"id": "B0Z2cy9GHQ8x",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a00d2f22-4933-46ee-afe6-0d83eb7d46db"
},
"source": [
"# train data spec\n",
"train_data.element_spec"
],
"execution_count": 19,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"({'input_word_ids': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None),\n",
" 'input_mask': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None),\n",
" 'input_type_ids': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None)},\n",
" TensorSpec(shape=(32,), dtype=tf.int32, name=None))"
]
},
"metadata": {},
"execution_count": 19
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "DGAH-ycYOmao",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "c2d79741-9620-4626-c4af-44b39926c514"
},
"source": [
"# valid data spec\n",
"valid_data.element_spec"
],
"execution_count": 20,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"({'input_word_ids': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None),\n",
" 'input_mask': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None),\n",
" 'input_type_ids': TensorSpec(shape=(32, 128), dtype=tf.int32, name=None)},\n",
" TensorSpec(shape=(32,), dtype=tf.int32, name=None))"
]
},
"metadata": {},
"execution_count": 20
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GZxe-7yhPyQe"
},
"source": [
"## Task 9: Add a Classification Head to the BERT Layer"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9THH5V0Dw2HO"
},
"source": [
"<div align=\"center\">\n",
" <img width=\"512px\" src='https://drive.google.com/uc?id=1fnJTeJs5HUpz7nix-F9E6EZdgUflqyEu' />\n",
" <p style=\"text-align: center;color:gray\">Figure 3: BERT Layer</p>\n",
"</div>"
]
},
{
"cell_type": "code",
"metadata": {
"id": "G9il4gtlADcp"
},
"source": [
"# Build the model: the pretrained BERT encoder plus a dropout and sigmoid classification head\n",
"def create_model():\n",
"    # The three inputs the BERT encoder expects, each of shape [batch_size, max_seq_length]\n",
"    input_word_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_word_ids')\n",
"    input_mask = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_mask')\n",
"    input_type_ids = tf.keras.layers.Input(shape=(max_seq_length,), dtype=tf.int32, name='input_type_ids')\n",
"\n",
"    encoder_outputs = encoder(inputs={\n",
"        'input_word_ids': input_word_ids,\n",
"        'input_mask': input_mask,\n",
"        'input_type_ids': input_type_ids\n",
"    })\n",
"    pooled_output = encoder_outputs[\"pooled_output\"]  # [batch_size, 768]\n",
"    sequence_output = encoder_outputs[\"sequence_output\"]  # [batch_size, seq_length, 768]; not used by this head\n",
"    drop = tf.keras.layers.Dropout(0.4)(pooled_output)  # dropout regularization to reduce overfitting\n",
"    output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(drop)  # single probability in [0, 1]\n",
"\n",
"    model = tf.keras.Model(\n",
"        inputs={\n",
"            'input_word_ids': input_word_ids,\n",
"            'input_mask': input_mask,\n",
"            'input_type_ids': input_type_ids\n",
"        },\n",
"        outputs=output)\n",
"    return model"
],
"execution_count": 31,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "S6maM-vr7YaJ"
},
"source": [
"## Task 10: Fine-Tune BERT for Text Classification"
]
},
{
"cell_type": "code",
"metadata": {
"id": "ptCtiiONsBgo",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "4f49b2f3-298e-4aa5-9538-c21417d74666"
},
"source": [
"model = create_model()\n",
"model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),\n",
"              loss=tf.losses.BinaryCrossentropy(),  # binary cross-entropy, since there are only two classes\n",
"              metrics=[tf.keras.metrics.BinaryAccuracy()])\n",
"model.summary()"
],
"execution_count": 32,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Model: \"model_1\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" input_mask (InputLayer) [(None, 128)] 0 [] \n",
" \n",
" input_type_ids (InputLayer) [(None, 128)] 0 [] \n",
" \n",
" input_word_ids (InputLayer) [(None, 128)] 0 [] \n",
" \n",
" keras_layer (KerasLayer) {'default': (None, 109482241 ['input_mask[0][0]', \n",
" 768), 'input_type_ids[0][0]', \n",
" 'pooled_output': ( 'input_word_ids[0][0]'] \n",
" None, 768), \n",
" 'sequence_output': \n",
" (None, 128, 768), \n",
" 'encoder_outputs': \n",
" [(None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768), \n",
" (None, 128, 768)]} \n",
" \n",
" dropout_1 (Dropout) (None, 768) 0 ['keras_layer[1][13]'] \n",
" \n",
" output (Dense) (None, 1) 769 ['dropout_1[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 109,483,010\n",
"Trainable params: 109,483,009\n",
"Non-trainable params: 1\n",
"__________________________________________________________________________________________________\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "6GJaFnkbMtPL",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 177
},
"outputId": "64eedd36-a301-4386-a9ff-bd82259eff77"
},
"source": [
"tf.keras.utils.plot_model(model=model, show_shapes=True, dpi=76)"
],
"execution_count": 33,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/png": "\n",
"text/plain": [
"<IPython.core.display.Image object>"
]
},
"metadata": {},
"execution_count": 33
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "OcREcgPUHr9O",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3a5cb213-eb63-403e-87b8-edde454bcf08"
},
"source": [
"# Train model\n",
"epochs = 4\n",
"history = model.fit(train_data,\n",
"                    validation_data=valid_data,\n",
"                    epochs=epochs,\n",
"                    verbose=1)  # verbose=1 prints a progress bar for each epoch"
],
"execution_count": 34,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Epoch 1/4\n",
"306/306 [==============================] - 312s 826ms/step - loss: 0.1701 - binary_accuracy: 0.9409 - val_loss: 0.1255 - val_binary_accuracy: 0.9469\n",
"Epoch 2/4\n",
"306/306 [==============================] - 253s 824ms/step - loss: 0.1014 - binary_accuracy: 0.9597 - val_loss: 0.1516 - val_binary_accuracy: 0.9615\n",
"Epoch 3/4\n",
"306/306 [==============================] - 257s 834ms/step - loss: 0.0557 - binary_accuracy: 0.9793 - val_loss: 0.1824 - val_binary_accuracy: 0.9500\n",
"Epoch 4/4\n",
"306/306 [==============================] - 259s 843ms/step - loss: 0.0248 - binary_accuracy: 0.9918 - val_loss: 0.2473 - val_binary_accuracy: 0.9594\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kNZl1lx_cA5Y"
},
"source": [
"## Task 11: Evaluate the BERT Text Classification Model"
]
},
{
"cell_type": "code",
"metadata": {
"id": "dCjgrUYH_IsE"
},
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_graphs(history, metric):\n",
"    # Plot a training metric and its validation counterpart against the epoch index\n",
"    plt.plot(history.history[metric])\n",
"    plt.plot(history.history['val_' + metric])\n",
"    plt.xlabel(\"Epochs\")\n",
"    plt.ylabel(metric)\n",
"    plt.legend([metric, 'val_' + metric])\n",
"    plt.show()"
],
"execution_count": 35,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "v6lrFRra_KmA",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
},
"outputId": "27ebb7d4-d1e5-4c40-8124-d183bed4d3b8"
},
"source": [
"plot_graphs(history,'loss')"
],
"execution_count": 36,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "opu9neBA_98R",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 449
},
"outputId": "d97af933-dd92-4406-ce11-56fd9a8fe374"
},
"source": [
"plot_graphs(history,'binary_accuracy')"
],
"execution_count": 37,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "hkhtCCgnUbY6",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7def0ffd-c786-43a5-ff9a-1e5137697e61"
},
"source": [
"sample_example = [\"May I have your email id?\", \"Elon Mush is not a human.\",\"I want icecream\"]\n",
"# Pair each text with a dummy label of 0 so the examples can go through to_feature_map\n",
"test_data = tf.data.Dataset.from_tensor_slices((sample_example, [0] * len(sample_example)))\n",
"test_data = test_data.map(to_feature_map).batch(1)\n",
"preds = model.predict(test_data)  # sigmoid outputs, one row per example\n",
"threshold = 0.5  # decision threshold, between 0 and 1\n",
"['Insincere' if pred > threshold else 'Sincere' for pred in preds.flatten()]"
],
"execution_count": 41,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"3/3 [==============================] - 0s 22ms/step\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"['Sincere', 'Sincere', 'Sincere']"
]
},
"metadata": {},
"execution_count": 41
}
]
}
]
}