Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save AwsafAlam/10a70361382dfcc72f1b62e5fd50439a to your computer and use it in GitHub Desktop.
Save AwsafAlam/10a70361382dfcc72f1b62e5fd50439a to your computer and use it in GitHub Desktop.
Contrastive_Learning_ToxiCR_BERT-Basic.ipynb
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"collapsed_sections": [
"4JrUHXms16cn",
"lgmMzEm_ShK1",
"ORnn8VPz9uLC",
"so-0GYfcLuaf"
],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/AwsafAlam/10a70361382dfcc72f1b62e5fd50439a/preprocessing_contrastive_learning_toxicr_bert-v2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EKOTlwcmxmej"
},
"source": [
"# ToxiCR Dataset BERT Fine-Tuning using PyTorch\n",
"\n",
"~ *Modified: Md Awsaf Alam*"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nSU7yERLP_66"
},
"source": [
"### Step 1: Setup\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "GI0iOY8zvZzL"
},
"source": [
"GPU Detection"
]
},
{
"cell_type": "code",
"metadata": {
"id": "DEfSbAA4QHas",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "1d792c59-b50e-4cc5-9c18-ff6bf748acbf"
},
"source": [
"import tensorflow as tf\n",
"\n",
"# Ask TensorFlow which GPU device name it reports.\n",
"device_name = tf.test.gpu_device_name()\n",
"\n",
"# Abort unless the reported name is exactly the first GPU's; otherwise announce it.\n",
"if device_name != '/device:GPU:0':\n",
"    raise SystemError('GPU device not found')\n",
"print('Found GPU at: {}'.format(device_name))"
],
"execution_count": 36,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Found GPU at: /device:GPU:0\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cqG7FzRVFEIv"
},
"source": [
"In order for torch to use the GPU, we need to identify and specify the GPU as the device. Later, in our training loop, we will load data onto the device."
]
},
{
"cell_type": "code",
"metadata": {
"id": "oYsV4H8fCpZ-",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "db306625-7003-497e-e93c-fd368e5a3f71"
},
"source": [
"import torch\n",
"\n",
"# Choose the device used by the rest of the notebook: CUDA when PyTorch can\n",
"# see a GPU, otherwise fall back to the CPU.\n",
"if not torch.cuda.is_available():\n",
"    print('No GPU available, using the CPU instead.')\n",
"    device = torch.device(\"cpu\")\n",
"else:\n",
"    device = torch.device(\"cuda\")\n",
"\n",
"    print('There are %d GPU(s) available.' % torch.cuda.device_count())\n",
"\n",
"    print('We will use the GPU:', torch.cuda.get_device_name(0))"
],
"execution_count": 37,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"There are 1 GPU(s) available.\n",
"We will use the GPU: Tesla T4\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Print the NVIDIA driver/CUDA versions and current GPU memory usage.\n",
"!nvidia-smi"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9gFnA5Ep8TOF",
"outputId": "5c90f2bf-b42f-44f7-c0cd-b1cc84bfef3b"
},
"execution_count": 38,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Sat Aug 19 06:08:22 2023 \n",
"+-----------------------------------------------------------------------------+\n",
"| NVIDIA-SMI 525.105.17 Driver Version: 525.105.17 CUDA Version: 12.0 |\n",
"|-------------------------------+----------------------+----------------------+\n",
"| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
"| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |\n",
"| | | MIG M. |\n",
"|===============================+======================+======================|\n",
"| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |\n",
"| N/A 75C P0 31W / 70W | 1353MiB / 15360MiB | 0% Default |\n",
"| | | N/A |\n",
"+-------------------------------+----------------------+----------------------+\n",
" \n",
"+-----------------------------------------------------------------------------+\n",
"| Processes: |\n",
"| GPU GI CI PID Type Process name GPU Memory |\n",
"| ID ID Usage |\n",
"|=============================================================================|\n",
"+-----------------------------------------------------------------------------+\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"## Clearing cache\n",
"import gc\n",
"\n",
"# Run Python's garbage collector before asking PyTorch to release cached memory.\n",
"gc.collect()\n",
"\n",
"# Only touch the CUDA allocator when a GPU is present (the device-selection\n",
"# cell above also supports CPU-only runs).\n",
"if torch.cuda.is_available():\n",
"    torch.cuda.empty_cache()\n",
"    # memory_summary() only returns a string; print it so the report is visible.\n",
"    print(torch.cuda.memory_summary(device=None, abbreviated=False))\n"
],
"metadata": {
"id": "bUmPt74q78kd"
},
"execution_count": 39,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "2ElsnSNUridI"
},
"source": [
"Installing the Hugging Face Library\n"
]
},
{
"cell_type": "code",
"metadata": {
"id": "0NmMdkZO8R6q",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e7edb376-e1b0-44ab-bf95-c7ad100bd25f"
},
"source": [
"# Versions are pinned to the ones recorded in this cell's last run so the\n",
"# notebook keeps working as the libraries evolve. `%pip` (rather than `!pip`)\n",
"# installs into the environment of the running kernel.\n",
"%pip install transformers==4.31.0 pytorch_metric_learning==2.3.0"
],
"execution_count": 40,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.31.0)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
"Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)\n",
"Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
"Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
"Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.2)\n",
"Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (2023.6.0)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.2.0)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.4)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
"Requirement already satisfied: pytorch_metric_learning in /usr/local/lib/python3.10/dist-packages (2.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from pytorch_metric_learning) (1.23.5)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from pytorch_metric_learning) (1.2.2)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from pytorch_metric_learning) (4.66.1)\n",
"Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.10/dist-packages (from pytorch_metric_learning) (2.0.1+cu118)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->pytorch_metric_learning) (3.12.2)\n",
"Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->pytorch_metric_learning) (4.7.1)\n",
"Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->pytorch_metric_learning) (1.12)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->pytorch_metric_learning) (3.1)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->pytorch_metric_learning) (3.1.2)\n",
"Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.6.0->pytorch_metric_learning) (2.0.0)\n",
"Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->pytorch_metric_learning) (3.27.2)\n",
"Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch>=1.6.0->pytorch_metric_learning) (16.0.6)\n",
"Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->pytorch_metric_learning) (1.10.1)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->pytorch_metric_learning) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->pytorch_metric_learning) (3.2.0)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.6.0->pytorch_metric_learning) (2.1.3)\n",
"Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.6.0->pytorch_metric_learning) (1.3.0)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import os\n",
"import random\n",
"import time\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from pytorch_metric_learning import losses as loss_fun\n",
"from pytorch_metric_learning.distances import CosineSimilarity\n",
"from pytorch_metric_learning.reducers import ThresholdReducer\n",
"from pytorch_metric_learning.regularizers import LpRegularizer\n",
"\n",
"from transformers import BertTokenizer, ElectraTokenizer, ElectraForSequenceClassification, ElectraModel\n",
"from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup, BertModel\n",
"\n",
"import torch.nn.functional as F\n",
"import torch.nn as nn\n",
"\n",
"from torch.utils.data import Dataset\n",
"from torch.utils.data.dataloader import DataLoader\n",
"from torch import nn, optim\n",
"\n",
"# Training configuration.\n",
"# NOTE(review): how `undersample` and `learn_rate_list` are consumed is not visible\n",
"# in this part of the notebook - presumably they are swept over; confirm.\n",
"EPOCHS = 4\n",
"BATCH_SIZE = 16\n",
"# MAX_LEN_LIST = [200,160]\n",
"undersample = [False]\n",
"learn_rate_list = [5e-5,3e-5,1e-5,8e-6]\n",
"\n",
"# Set the seed value all over the place to make this reproducible.\n",
"seed_val = 42\n",
"\n",
"# Seed Python's `random`, NumPy, and PyTorch (CPU plus every CUDA device).\n",
"random.seed(seed_val)\n",
"np.random.seed(seed_val)\n",
"torch.manual_seed(seed_val)\n",
"torch.cuda.manual_seed_all(seed_val)\n",
"\n"
],
"metadata": {
"id": "BP3_FiOcesNX"
},
"execution_count": 41,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"Helper functions"
],
"metadata": {
"id": "e5CN-SsEzL2d"
}
},
{
"cell_type": "code",
"metadata": {
"id": "gpt6tR83keZD"
},
"source": [
"import time\n",
"import datetime\n",
"import numpy as np\n",
"\n",
"# Function to calculate the accuracy of our predictions vs labels\n",
"def flat_accuracy(preds, labels):\n",
"    \"\"\"Return the fraction of rows in ``preds`` whose arg-max column equals the label.\n",
"\n",
"    ``preds`` is reduced with ``argmax`` along axis 1 and ``labels`` is flattened\n",
"    before the element-wise comparison.\n",
"    \"\"\"\n",
"    predicted = np.argmax(preds, axis=1).flatten()\n",
"    expected = labels.flatten()\n",
"    hits = predicted == expected\n",
"    return np.sum(hits) / len(expected)\n",
"\n",
"def format_time(elapsed):\n",
"    \"\"\"Format a duration given in seconds as an ``h:mm:ss`` string.\n",
"\n",
"    The value is rounded to the nearest whole second before formatting.\n",
"    \"\"\"\n",
"    whole_seconds = int(round(elapsed))\n",
"    duration = datetime.timedelta(seconds=whole_seconds)\n",
"    return str(duration)\n"
],
"execution_count": 42,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "4JrUHXms16cn"
},
"source": [
"### Step 2: Preprocessing"
]
},
{
"cell_type": "markdown",
"source": [
"##### Init Preprocessing Functions"
],
"metadata": {
"id": "lgmMzEm_ShK1"
}
},
{
"cell_type": "code",
"source": [
"import re\n",
"\n",
"# Raw string so the regex escapes (`\\(`, `\\)`) reach `re` untouched instead of\n",
"# being parsed as invalid Python string escapes. The pattern itself is unchanged.\n",
"url_regex = re.compile(r'http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+')\n",
"\n",
"\n",
"def remove_url(text):\n",
"    \"\"\"Replace every http/https URL matched by ``url_regex`` in ``text`` with one space.\"\"\"\n",
"    return url_regex.sub(\" \", text)\n",
"\n",
"\n",
"# Compiled once (not on every call) from a raw string: runs of characters that\n",
"# are neither whitespace nor word characters, plus underscores.\n",
"_SPECIAL_SYM_RE = re.compile(r'([^\\s\\w]|_)+')\n",
"\n",
"\n",
"def rem_special_sym(text):\n",
"    \"\"\"Replace each run of symbol characters (including ``_``) in ``text`` with one space.\"\"\"\n",
"    return _SPECIAL_SYM_RE.sub(' ', text)\n"
],
"metadata": {
"id": "-BxnpOeZSlfD"
},
"execution_count": 45,
"outputs": []
},
{
"cell_type": "code",
"source": [
"contraction_mapping = {\"ain't\": \"is not\", \"aren't\": \"are not\",\n",
" \"can't\": \"cannot\", \"'cause\": \"because\",\n",
" \"could've\": \"could have\", \"couldn't\": \"could not\",\n",
" \"didn't\": \"did not\", \"doesn't\": \"does not\",\n",
" \"don't\": \"do not\", \"hadn't\": \"had not\", \"hasn't\": \"has not\",\n",
" \"haven't\": \"have not\", \"he'd\": \"he would\", \"he'll\": \"he will\",\n",
" \"he's\": \"he is\", \"how'd\": \"how did\", \"how'd'y\": \"how do you\",\n",
" \"how'll\": \"how will\", \"how's\": \"how is\", \"I'd\": \"I would\",\n",
" \"I'd've\": \"I would have\", \"I'll\": \"I will\", \"I'll've\": \"I will have\",\n",
" \"I'm\": \"I am\", \"I've\": \"I have\", \"i'd\": \"i would\", \"i'd've\": \"i would have\",\n",
" \"i'll\": \"i will\", \"i'll've\": \"i will have\", \"i'm\": \"i am\",\n",
" \"i've\": \"i have\", \"isn't\": \"is not\", \"it'd\": \"it would\",\n",
" \"it'd've\": \"it would have\", \"it'll\": \"it will\", \"it'll've\": \"it will have\",\n",
" \"it's\": \"it is\", \"let's\": \"let us\", \"ma'am\": \"madam\", \"mayn't\": \"may not\",\n",
" \"might've\": \"might have\", \"mightn't\": \"might not\",\n",
" \"mightn't've\": \"might not have\", \"must've\": \"must have\",\n",
" \"mustn't\": \"must not\", \"mustn't've\": \"must not have\",\n",
" \"needn't\": \"need not\", \"needn't've\": \"need not have\",\n",
" \"o'clock\": \"of the clock\", \"oughtn't\": \"ought not\",\n",
" \"oughtn't've\": \"ought not have\", \"shan't\": \"shall not\",\n",
" \"sha'n't\": \"shall not\", \"shan't've\": \"shall not have\",\n",
" \"she'd\": \"she would\", \"she'd've\": \"she would have\",\n",
" \"she'll\": \"she will\", \"she'll've\": \"she will have\",\n",
" \"she's\": \"she is\", \"should've\": \"should have\", \"shouldn't\": \"should not\",\n",
" \"shouldn't've\": \"should not have\", \"so've\": \"so have\", \"so's\": \"so as\",\n",
" \"this's\": \"this is\", \"that'd\": \"that would\", \"that'd've\": \"that would have\",\n",
" \"that's\": \"that is\", \"there'd\": \"there would\",\n",
" \"there'd've\": \"there would have\", \"there's\": \"there is\",\n",
" \"here's\": \"here is\", \"they'd\": \"they would\", \"they'd've\": \"they would have\",\n",
" \"they'll\": \"they will\", \"they'll've\": \"they will have\", \"they're\": \"they are\",\n",
" \"they've\": \"they have\", \"to've\": \"to have\", \"wasn't\": \"was not\", \"we'd\": \"we would\",\n",
" \"we'd've\": \"we would have\", \"we'll\": \"we will\", \"we'll've\": \"we will have\",\n",
" \"we're\": \"we are\", \"we've\": \"we have\", \"weren't\": \"were not\",\n",
" \"what'll\": \"what will\",\n",
" \"what'll've\": \"what will have\", \"what're\": \"what are\", \"what's\": \"what is\",\n",
" \"what've\": \"what have\", \"when's\": \"when is\", \"when've\": \"when have\",\n",
" \"where'd\": \"where did\", \"where's\": \"where is\", \"where've\": \"where have\",\n",
" \"who'll\": \"who will\", \"who'll've\": \"who will have\", \"who's\": \"who is\",\n",
" \"who've\": \"who have\", \"why's\": \"why is\", \"why've\": \"why have\",\n",
" \"will've\": \"will have\", \"won't\": \"will not\", \"won't've\": \"will not have\",\n",
" \"would've\": \"would have\", \"wouldn't\": \"would not\", \"wouldn't've\": \"would not have\",\n",
" \"y'all\": \"you all\", \"y'all'd\": \"you all would\", \"y'all'd've\": \"you all would have\",\n",
" \"y'all're\": \"you all are\", \"y'all've\": \"you all have\", \"you'd\": \"you would\",\n",
" \"you'd've\": \"you would have\", \"you'll\": \"you will\", \"you'll've\": \"you will have\",\n",
" \"you're\": \"you are\", \"you've\": \"you have\", \"aint\": \"is not\", \"arent\": \"are not\",\n",
" \"cant\": \"cannot\", \"cause\": \"because\",\n",
" \"couldve\": \"could have\", \"couldnt\": \"could not\",\n",
" \"didnt\": \"did not\", \"doesnt\": \"does not\",\n",
" \"dont\": \"do not\", \"hadnt\": \"had not\", \"hasnt\": \"has not\",\n",
" \"havent\": \"have not\", \"howdy\": \"how do you\",\n",
" \"its\": \"it is\", \"lets\": \"let us\", \"maam\": \"madam\", \"maynt\": \"may not\",\n",
" \"mightve\": \"might have\", \"mightnt\": \"might not\",\n",
" \"mightntve\": \"might not have\", \"mustve\": \"must have\",\n",
" \"mustnt\": \"must not\", \"mustntve\": \"must not have\",\n",
" \"neednt\": \"need not\", \"needntve\": \"need not have\",\n",
" \"oclock\": \"of the clock\", \"oughtnt\": \"ought not\",\n",
" \"shouldve\": \"should have\", \"shouldnt\": \"should not\",\n",
" \"werent\": \"were not\", \"yall\": \"you all\", \"youre\": \"you are\",\n",
" \"youve\": \"you have\"}\n",
"\n",
"def expand_contraction(text):\n",
"    \"\"\"Normalise apostrophe look-alikes in ``text`` and expand known contractions.\n",
"\n",
"    ``text`` is split on single spaces; every piece that is a key of\n",
"    ``contraction_mapping`` is swapped for its expansion, and the pieces are\n",
"    joined back together with single spaces.\n",
"    \"\"\"\n",
"    # Curly quotes, the acute accent and the backtick all become a plain apostrophe.\n",
"    for variant in (\"’\", \"‘\", \"´\", \"`\"):\n",
"        text = text.replace(variant, \"'\")\n",
"\n",
"    expanded_tokens = []\n",
"    for token in text.split(\" \"):\n",
"        expanded_tokens.append(contraction_mapping.get(token, token))\n",
"    return ' '.join(expanded_tokens)\n"
],
"metadata": {
"id": "48Z_Ajo8SwPS"
},
"execution_count": 46,
"outputs": []
},
{
"cell_type": "code",
"source": [
"RE_PATTERNS = {\n",
"\n",
" ' fuck ':\n",
" [\n",
" '(f)(u|[^a-z0-9 ])(c|[^a-z0-9 ])(k|[^a-z0-9 ])([^ ])*',\n",
" '(f)([^a-z]*)(u)([^a-z]*)(c)([^a-z]*)(k)',\n",
" ' f[!@#\\$%\\^\\&\\*]*u[!@#\\$%\\^&\\*]*k', 'f u u c',\n",
" '(f)(c|[^a-z ])(u|[^a-z ])(k)', r'f\\*',\n",
" 'feck ', ' fux ', 'f\\*\\*',\n",
" 'f\\-ing', 'f\\.u\\.', 'f###', ' fu ', 'f@ck', 'f u c k', 'f uck', 'f ck'\n",
"\n",
" ],\n",
"\n",
" ' crap ':\n",
" [\n",
" ' (c)(r|[^a-z0-9 ])(a|[^a-z0-9 ])(p|[^a-z0-9 ])([^ ])*',\n",
" ' (c)([^a-z]*)(r)([^a-z]*)(a)([^a-z]*)(p)',\n",
" ' c[!@#\\$%\\^\\&\\*]*r[!@#\\$%\\^&\\*]*p', 'cr@p', ' c r a p',\n",
"\n",
" ],\n",
"\n",
"    ' ass ':\n",
"        [\n",
"            # The comma after '@\\$\\$' matters: without it Python fuses that literal with\n",
"            # the next one into the single pattern '@\\$\\$[^a-z]anus'.\n",
"            '[^a-z]ass ', '[^a-z]azz ', 'arrse', ' arse ', '@\\$\\$',\n",
"            '[^a-z]anus', ' a\\*s\\*s', '[^a-z]ass[^a-z ]',\n",
"            'a[@#\\$%\\^&\\*][@#\\$%\\^&\\*]', '[^a-z]anal ', 'a s s'\n",
"        ],\n",
"\n",
" ' ass hole ':\n",
" [\n",
" ' a[s|z]*wipe', 'a[s|z]*[w]*h[o|0]+[l]*e', '@\\$\\$hole'\n",
" ],\n",
"\n",
" ' bitch ':\n",
" [\n",
" 'bitches', ' b[w]*i[t]*ch', ' b!tch',\n",
" ' bi\\+ch', ' b!\\+ch', ' (b)([^a-z]*)(i)([^a-z]*)(t)([^a-z]*)(c)([^a-z]*)(h)',\n",
" ' biatch', ' bi\\*\\*h', ' bytch', 'b i t c h'\n",
" ],\n",
"\n",
" ' bastard ':\n",
" [\n",
" 'ba[s|z]+t[e|a]+rd'\n",
" ],\n",
"\n",
" ' transgender':\n",
" [\n",
" 'transgender'\n",
" ],\n",
"\n",
" ' gay ':\n",
" [\n",
" 'gay', 'homo'\n",
" ],\n",
"\n",
" ' cock ':\n",
" [\n",
" '[^a-z]cock', 'c0ck', '[^a-z]cok ', 'c0k', '[^a-z]cok[^aeiou]', ' cawk',\n",
" '(c)([^a-z ])(o)([^a-z ]*)(c)([^a-z ]*)(k)', 'c o c k'\n",
" ],\n",
"\n",
" ' dick ':\n",
" [\n",
" ' dick[^aeiou]', 'd i c k'\n",
" ],\n",
"\n",
" ' suck ':\n",
" [\n",
" 'sucker', '(s)([^a-z ]*)(u)([^a-z ]*)(c)([^a-z ]*)(k)', 'sucks', '5uck', 's u c k'\n",
" ],\n",
"\n",
" ' cunt ':\n",
" [\n",
" 'cunt', 'c u n t'\n",
" ],\n",
"\n",
" ' bull shit ':\n",
" [\n",
" 'bullsh\\*t', 'bull\\$hit', 'bull sh.t'\n",
" ],\n",
"\n",
" ' jerk ':\n",
" [\n",
" 'jerk'\n",
" ],\n",
"\n",
"    ' idiot ':\n",
"        [\n",
"            # 'idiots' and 'i d i o t' are separate patterns; the comma between them keeps\n",
"            # Python from concatenating the two adjacent literals into one.\n",
"            'i[d]+io[t]+', '(i)([^a-z ]*)(d)([^a-z ]*)(i)([^a-z ]*)(o)([^a-z ]*)(t)', 'idiots', 'i d i o t'\n",
"        ],\n",
"\n",
" ' dumb ':\n",
" [\n",
" '(d)([^a-z ]*)(u)([^a-z ]*)(m)([^a-z ]*)(b)'\n",
" ],\n",
"\n",
" ' shit ':\n",
" [\n",
" 'shitty', '(s)([^a-z ]*)(h)([^a-z ]*)(i)([^a-z ]*)(t)', 'shite', '\\$hit', 's h i t', 'sh\\*tty',\n",
" 'sh\\*ty', 'sh\\*t'\n",
" ],\n",
"\n",
" ' shit hole ':\n",
" [\n",
" 'shythole', 'sh.thole'\n",
" ],\n",
"\n",
" ' retard ':\n",
" [\n",
" 'returd', 'retad', 'retard', 'wiktard', 'wikitud'\n",
" ],\n",
"\n",
" ' rape ':\n",
" [\n",
" 'raped'\n",
" ],\n",
"\n",
" ' dumb ass':\n",
" [\n",
" 'dumbass', 'dubass'\n",
" ],\n",
"\n",
" ' ass head':\n",
" [\n",
" 'butthead'\n",
" ],\n",
"\n",
" ' sex ':\n",
" [\n",
" 'sexy', 's3x', 'sexuality'\n",
" ],\n",
"\n",
" ' nigger ':\n",
" [\n",
" 'nigger', 'ni[g]+a', ' nigr ', 'negrito', 'niguh', 'n3gr', 'n i g g e r'\n",
" ],\n",
"\n",
"    ' shut the fuck up':\n",
"        [\n",
"            # Two patterns (mid-text and start-of-text). Without the comma they merge\n",
"            # into ' stfu^stfu', which can never match.\n",
"            ' stfu', '^stfu'\n",
"        ],\n",
"\n",
" ' for your fucking information':\n",
" [\n",
" ' fyfi', '^fyfi'\n",
" ],\n",
" ' get the fuck off':\n",
" [\n",
" 'gtfo', '^gtfo'\n",
" ],\n",
"\n",
" ' oh my fucking god ':\n",
" [\n",
" ' omfg', '^omfg'\n",
" ],\n",
"\n",
" ' what the hell ':\n",
" [\n",
" ' wth', '^wth'\n",
" ],\n",
"\n",
" ' what the fuck ':\n",
" [\n",
" ' wtf', '^wtf'\n",
" ],\n",
" ' son of bitch ':\n",
" [\n",
" ' sob ', '^sob '\n",
" ],\n",
"\n",
" ' pussy ':\n",
" [\n",
" 'pussy[^c]', 'pusy', 'pussi[^l]', 'pusses', '(p)(u|[^a-z0-9 ])(s|[^a-z0-9 ])(s|[^a-z0-9 ])(y)',\n",
" ],\n",
"\n",
" ' faggot ':\n",
" [\n",
" 'faggot', ' fa[g]+[s]*[^a-z ]', 'fagot', 'f a g g o t', 'faggit',\n",
" '(f)([^a-z ]*)(a)([^a-z ]*)([g]+)([^a-z ]*)(o)([^a-z ]*)(t)', 'fau[g]+ot', 'fae[g]+ot',\n",
" ],\n",
"\n",
" ' mother fucker':\n",
" [\n",
" ' motha f', ' mother f', 'motherucker', ' mofo', ' mf ',\n",
" ],\n",
"\n",
" ' whore ':\n",
" [\n",
" 'wh\\*\\*\\*', 'w h o r e'\n",
" ],\n",
"\n",
" ' haha ':\n",
" [\n",
" 'ha\\*\\*\\*ha',\n",
" ],\n",
" # ' what the fuck ':\n",
" # [\n",
" # ' wtf',\n",
" # ],\n",
"}\n",
"\n",
"# Module-level settings read by process_profanity().\n",
"# Replacement table: canonical word -> list of regexes that are rewritten to it.\n",
"patterns = RE_PATTERNS\n",
"# Characters matching this regex are replaced by a space before pattern matching\n",
"# (everything except lowercase letters, digits, the listed punctuation and space).\n",
"initial_filters=r\"[^a-z0-9!@#\\$%\\^\\*\\+\\?\\&\\_\\-,\\.' ]\"\n",
"# Lower-case the text first.\n",
"lower=True\n",
"# Collapse characters repeated three or more times in a row to a single one.\n",
"remove_repetitions=True"
],
"metadata": {
"id": "y3SY7-8kUPEf"
},
"execution_count": 47,
"outputs": []
},
{
"cell_type": "code",
"source": [
"def process_profanity(text):\n",
"    \"\"\"Rewrite obfuscated profanity in ``text`` to the canonical words in ``patterns``.\n",
"\n",
"    Behaviour is controlled by the module-level settings ``lower``,\n",
"    ``initial_filters``, ``remove_repetitions`` and ``patterns``.\n",
"    \"\"\"\n",
"    # Optional case folding.\n",
"    cleaned = text.lower() if lower else text\n",
"\n",
"    # Optional first pass: blank out every character matched by initial_filters.\n",
"    if initial_filters is not None:\n",
"        cleaned = re.sub(initial_filters, ' ', cleaned)\n",
"\n",
"    # neeeeeeeeeerd => nerd (any character repeated 3+ times collapses to one)\n",
"    if remove_repetitions:\n",
"        cleaned = re.sub(r\"(.)\\1{2,}\", r\"\\1\", cleaned, flags=re.DOTALL)\n",
"\n",
"    # Apply every variant regex, in dictionary order, rewriting matches to the key.\n",
"    for canonical, variants in patterns.items():\n",
"        for variant in variants:\n",
"            cleaned = re.sub(variant, canonical, cleaned)\n",
"\n",
"    # NOTE(review): the notebook's indentation was lost in extraction; this final\n",
"    # cleanup is assumed to run once, after all patterns were applied - confirm.\n",
"    # Anything that is not a lowercase letter, apostrophe or space becomes a space.\n",
"    return re.sub(r\"[^a-z' ]\", ' ', cleaned)\n",
"\n"
],
"metadata": {
"id": "_3KTYbdoUhdr"
},
"execution_count": 48,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Loading Data"
],
"metadata": {
"id": "0T1zN5BboiBp"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "3ZNVW6xd0T0X"
},
"source": [
"We'll use the `wget` package to download the dataset to the Colab instance's file system."
]
},
{
"cell_type": "code",
"metadata": {
"id": "5m6AnuFv0QXQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "291dbae1-227b-4e47-909d-9dd31126f8bc"
},
"source": [
"# Pinned to the version recorded in this cell's last run; `%pip` installs into\n",
"# the environment of the running kernel.\n",
"%pip install wget==3.2"
],
"execution_count": 43,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Requirement already satisfied: wget in /usr/local/lib/python3.10/dist-packages (3.2)\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "08pO03Ff1BjI"
},
"source": [
"Downloading ToxiCR dataset: https://github.com/WSU-SEAL/ToxiCR"
]
},
{
"cell_type": "code",
"metadata": {
"id": "pMtmPMkBzrvs",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "642b461f-2e0b-4937-bdcf-f8c1167af2e6"
},
"source": [
"import wget\n",
"import os\n",
"\n",
"# Raw-download URL of the ToxiCR code-review dataset (an Excel workbook).\n",
"url = 'https://github.com/WSU-SEAL/ToxiCR/blob/master/models/code-review-dataset-full.xlsx?raw=true'\n",
"dataset_path = './code-review-dataset-full.xlsx'\n",
"\n",
"# Download only once; later runs reuse the local copy and say so instead of\n",
"# claiming a download is in progress.\n",
"if os.path.exists(dataset_path):\n",
"    print('Dataset already downloaded: {}'.format(dataset_path))\n",
"else:\n",
"    print('Downloading dataset...')\n",
"    wget.download(url, dataset_path)\n",
"\n"
],
"execution_count": 44,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading dataset...\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "0Yv-tNv20dnH",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 417
},
"outputId": "11558882-385d-43ab-8e0c-12f19539ea12"
},
"source": [
"import pandas as pd\n",
"\n",
"def read_dataframe_from_excel(file):\n",
"    \"\"\"Load the Excel workbook at ``file`` with pandas and return the DataFrame.\n",
"\n",
"    Prints a short progress message before reading.\n",
"    \"\"\"\n",
"    print(\"Reading dataframe\")\n",
"    return pd.read_excel(file)\n",
"\n",
"\n",
"# training_data_df = read_dataframe_from_excel('./code-review-dataset-full.xlsx')\n",
"# Load the full dataset; the train/test split is made in a later cell.\n",
"df = read_dataframe_from_excel('./code-review-dataset-full.xlsx')\n",
"\n",
"\n",
"# Report the number of sentences.\n",
"# NOTE: this counts every row of `df`, i.e. the dataset before it is split.\n",
"print('Number of training sentences: {:,}\\n'.format(df.shape[0]))\n",
"\n",
"\n",
"# Display 10 random rows from the data.\n",
"df.sample(10)"
],
"execution_count": 49,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Reading dataframe\n",
"Number of training sentences: 19,651\n",
"\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" message is_toxic\n",
"10504 While you're here fixing, drop these whitespac... 0\n",
"10882 Put it inside #ifdef SUPPORT_CHECKSUM 0\n",
"18463 This metdata_signature_size should not be incl... 0\n",
"7567 Remove linebreak 0\n",
"8692 lol, \"sucks\" is probably not the right word fo... 1\n",
"12395 If there is one by SLO segment, will this effe... 1\n",
"1535 the !sURL.isEmpty() part is redundant now and ... 0\n",
"19285 '==' is not case insensitive, so that is not a... 0\n",
"10455 I'm trying to structure data and you ask me to... 1\n",
"1815 You probably don't want two Change-Id lines in... 0"
],
"text/html": [
"\n",
" <div id=\"df-4336a41f-af51-481d-8879-d2900844b854\" class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>message</th>\n",
" <th>is_toxic</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>10504</th>\n",
" <td>While you're here fixing, drop these whitespac...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10882</th>\n",
" <td>Put it inside #ifdef SUPPORT_CHECKSUM</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18463</th>\n",
" <td>This metdata_signature_size should not be incl...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7567</th>\n",
" <td>Remove linebreak</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8692</th>\n",
" <td>lol, \"sucks\" is probably not the right word fo...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12395</th>\n",
" <td>If there is one by SLO segment, will this effe...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1535</th>\n",
" <td>the !sURL.isEmpty() part is redundant now and ...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19285</th>\n",
" <td>'==' is not case insensitive, so that is not a...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10455</th>\n",
" <td>I'm trying to structure data and you ask me to...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1815</th>\n",
" <td>You probably don't want two Change-Id lines in...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <div class=\"colab-df-buttons\">\n",
"\n",
" <div class=\"colab-df-container\">\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-4336a41f-af51-481d-8879-d2900844b854')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
" <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" .colab-df-buttons div {\n",
" margin-bottom: 4px;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-4336a41f-af51-481d-8879-d2900844b854 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-4336a41f-af51-481d-8879-d2900844b854');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
"\n",
"\n",
"<div id=\"df-ff583c6f-e3bd-49a2-a256-8eb78524aaf7\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-ff583c6f-e3bd-49a2-a256-8eb78524aaf7')\"\n",
" title=\"Suggest charts.\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" }\n",
" (() => {\n",
" let quickchartButtonEl =\n",
" document.querySelector('#df-ff583c6f-e3bd-49a2-a256-8eb78524aaf7 button');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
" })();\n",
" </script>\n",
"</div>\n",
" </div>\n",
" </div>\n"
]
},
"metadata": {},
"execution_count": 49
}
]
},
{
"cell_type": "code",
"source": [
"import torch.utils.data as data\n",
"\n",
"train_size = int(0.99 * df.shape[0])\n",
"test_size = df.shape[0] - train_size\n",
"\n",
"# Divide the dataset by randomly selecting samples.\n",
"train_subset, test_subset = data.random_split(df, [train_size, test_size])\n",
"\n",
"train_df = df.iloc[train_subset.indices]\n",
"test_df = df.iloc[test_subset.indices]\n",
"\n",
"\n",
"print('{:>5,} training samples'.format(train_size))\n",
"print('{:>5,} test samples'.format(test_size))\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mbBoLYUpowPy",
"outputId": "a7cc8a89-9bb5-4c29-c3b9-8c249525695b"
},
"execution_count": 50,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"19,454 training samples\n",
" 197 test samples\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# writing to Excel\n",
"# determining the name of the file\n",
"file_name = 'test-set.xlsx'\n",
"datatoexcel = pd.ExcelWriter(file_name)\n",
"\n",
"\n",
"# saving the excel\n",
"# test_df.to_excel(file_name)\n",
"\n",
"# write DataFrame to excel\n",
"test_df.to_excel(datatoexcel)\n",
"\n",
"# save the excel\n",
"datatoexcel.save()\n",
"print('DataFrame is written to Excel File successfully.')"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19ckndF9YeZQ",
"outputId": "344624fd-e871-43b1-bd5a-a681a23e6200"
},
"execution_count": 51,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"DataFrame is written to Excel File successfully.\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-51-50224c78287f>:14: FutureWarning: save is not part of the public API, usage can give unexpected results and will be removed in a future version\n",
" datatoexcel.save()\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Mount Google Drive to this Notebook instance.\n",
"from google.colab import drive\n",
"\n",
"drive.mount('/content/drive')\n",
"\n",
"# Copy the model files to a directory in your Google Drive.\n",
"!cp -r ./test-set.xlsx ./drive/MyDrive"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "qtkVPM3uZraA",
"outputId": "1fd5bb83-1e31-4187-999a-4a6974e1dd98"
},
"execution_count": 52,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"#### Preprocess training data"
],
"metadata": {
"id": "ob5PXap-eK2E"
}
},
{
"cell_type": "code",
"source": [
"def process_text(text):\n",
" processed_text = remove_url(text)\n",
" processed_text = expand_contraction(processed_text)\n",
" processed_text = process_profanity(processed_text)\n",
" processed_text = rem_special_sym(processed_text)\n",
" return processed_text\n",
"\n",
"\n",
"train_df['modified_message'] = train_df.astype(str).apply(lambda row : process_text(row['message']), axis = 1)\n",
"print(train_df.sample(10))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Qf2NiBdzSzm7",
"outputId": "505f5fcc-befd-456a-f8b9-9fa9e1294157"
},
"execution_count": 53,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" message is_toxic \\\n",
"4250 These CREVIEW tags should not be here 0 \n",
"830 Yes we can LOL. *Smacks head*. That piece of c... 0 \n",
"1915 Should these be nodebugwaitattach instead? You... 0 \n",
"6646 Yeah, but it affects the outcome of the method... 0 \n",
"3133 Unfortunately yes.\\n(1) I want to time just Ru... 0 \n",
"16402 nit: alignment here is off, project_id should ... 0 \n",
"12308 This should not be needed, get_project raises ... 0 \n",
"1029 Should be a BoolOpt instead of StrOpt. 0 \n",
"2905 Should be a VolumeDriverException rather than ... 0 \n",
"988 amqp_names = $amqp_ipaddresses,\\n \\n amqp_ipad... 0 \n",
"\n",
" modified_message \n",
"4250 these creview tags should not be here \n",
"830 yes we can lol smacks head that piece of c... \n",
"1915 should these be nodebugwaitattach instead you... \n",
"6646 yeah but it affects the outcome of the method... \n",
"3133 unfortunately yes i want to time just run... \n",
"16402 nit alignment here is off project id should ... \n",
"12308 this should not be needed get project raises ... \n",
"1029 should be a boolopt instead of stropt \n",
"2905 should be a volumedriverexception rather than ... \n",
"988 amqp names amqp ipaddresses amqp ipaddresses... \n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-53-bb8a7ebbb9b2>:9: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" train_df['modified_message'] = train_df.astype(str).apply(lambda row : process_text(row['message']), axis = 1)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Get the lists of sentences and their labels.\n",
"# sample = train_df.sample(1000)\n",
"sentences = train_df.modified_message.values\n",
"labels = train_df.is_toxic.values\n",
"\n",
"# sentences = sample.message.values\n",
"# labels = sample.is_toxic.values"
],
"metadata": {
"id": "01plXrb_7awh"
},
"execution_count": 54,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "oQUy9Tat2EF_"
},
"source": [
"### Step 3: Contrastive Learning"
]
},
{
"cell_type": "markdown",
"source": [
"#### Initialize BERT"
],
"metadata": {
"id": "AM28Zk7RcB_t"
}
},
{
"cell_type": "code",
"metadata": {
"id": "nskPzUM084zL",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "683e411f-91dd-4a5c-92c2-dfb140db9ea2"
},
"source": [
"# Load a trained model and vocabulary that you have fine-tuned\n",
"from transformers import BertTokenizer\n",
"from transformers import BertForSequenceClassification, AdamW, BertConfig\n",
"\n",
"# Load the BERT tokenizer.\n",
"print('Loading BERT tokenizer...')\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)\n",
"# tokenizer = BertTokenizer.from_pretrained(output_dir)\n",
"\n",
"# Load BertForSequenceClassification, the pretrained BERT model with a single\n",
"# linear classification layer on top.\n",
"# Load BertForSequenceClassification, the pretrained BERT model with a single\n",
"# linear classification layer on top.\n",
"model = BertForSequenceClassification.from_pretrained(\n",
" \"bert-base-uncased\", # Use the 12-layer BERT model, with an uncased vocab.\n",
" num_labels = 2, # The number of output labels--2 for binary classification.\n",
" # You can increase this for multi-class tasks.\n",
" output_attentions = False, # Whether the model returns attentions weights.\n",
" output_hidden_states = True, # Whether the model returns all hidden-states.\n",
")\n",
"\n",
"\n",
"# Copy the model to the GPU.\n",
"model.to(device)"
],
"execution_count": 55,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loading BERT tokenizer...\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"BertForSequenceClassification(\n",
" (bert): BertModel(\n",
" (embeddings): BertEmbeddings(\n",
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
" (position_embeddings): Embedding(512, 768)\n",
" (token_type_embeddings): Embedding(2, 768)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (encoder): BertEncoder(\n",
" (layer): ModuleList(\n",
" (0-11): 12 x BertLayer(\n",
" (attention): BertAttention(\n",
" (self): BertSelfAttention(\n",
" (query): Linear(in_features=768, out_features=768, bias=True)\n",
" (key): Linear(in_features=768, out_features=768, bias=True)\n",
" (value): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" (output): BertSelfOutput(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" (intermediate): BertIntermediate(\n",
" (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
" (intermediate_act_fn): GELUActivation()\n",
" )\n",
" (output): BertOutput(\n",
" (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" )\n",
" )\n",
" )\n",
" )\n",
" (pooler): BertPooler(\n",
" (dense): Linear(in_features=768, out_features=768, bias=True)\n",
" (activation): Tanh()\n",
" )\n",
" )\n",
" (dropout): Dropout(p=0.1, inplace=False)\n",
" (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
")"
]
},
"metadata": {},
"execution_count": 55
}
]
},
{
"cell_type": "markdown",
"source": [
"We define model parameters"
],
"metadata": {
"id": "IW_IKAxudnZd"
}
},
{
"cell_type": "code",
"source": [
"param_optimizer = list(model.named_parameters())\n",
"no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
"optimizer_grouped_parameters = [\n",
" {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
" {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay':0.0}]\n",
"optimizer = AdamW(optimizer_grouped_parameters,\n",
" lr=learn_rate_list[0], # args.learning_rate - default is 5e-5, our notebook had 2e-5\n",
" eps = 1e-8 # args.adam_epsilon - default is 1e-8.\n",
" )\n",
"\n",
"\n",
"# optimizer = AdamW(model.parameters(),\n",
"# lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5\n",
"# eps = 1e-8 # args.adam_epsilon - default is 1e-8.\n",
"# )\n"
],
"metadata": {
"id": "5QUzeWzZajYh",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3932f7e0-8b62-4173-896f-51831e986fd9"
},
"execution_count": 56,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" warnings.warn(\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Model parameters"
],
"metadata": {
"id": "sCvgmC7C1Deq"
}
},
{
"cell_type": "code",
"metadata": {
"id": "8PIiVlDYCtSq",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "ab305f08-11c5-4d98-e0d9-381812372a65"
},
"source": [
"# Get all of the model's parameters as a list of tuples.\n",
"params = list(model.named_parameters())\n",
"\n",
"print('The BERT model has {:} different named parameters.\\n'.format(len(params)))\n",
"\n",
"print('==== Embedding Layer ====\\n')\n",
"\n",
"for p in params[0:5]:\n",
" print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))\n",
"\n",
"print('\\n==== First Transformer ====\\n')\n",
"\n",
"for p in params[5:21]:\n",
" print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))\n",
"\n",
"print('\\n==== Output Layer ====\\n')\n",
"\n",
"for p in params[-4:]:\n",
" print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))"
],
"execution_count": 57,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The BERT model has 201 different named parameters.\n",
"\n",
"==== Embedding Layer ====\n",
"\n",
"bert.embeddings.word_embeddings.weight (30522, 768)\n",
"bert.embeddings.position_embeddings.weight (512, 768)\n",
"bert.embeddings.token_type_embeddings.weight (2, 768)\n",
"bert.embeddings.LayerNorm.weight (768,)\n",
"bert.embeddings.LayerNorm.bias (768,)\n",
"\n",
"==== First Transformer ====\n",
"\n",
"bert.encoder.layer.0.attention.self.query.weight (768, 768)\n",
"bert.encoder.layer.0.attention.self.query.bias (768,)\n",
"bert.encoder.layer.0.attention.self.key.weight (768, 768)\n",
"bert.encoder.layer.0.attention.self.key.bias (768,)\n",
"bert.encoder.layer.0.attention.self.value.weight (768, 768)\n",
"bert.encoder.layer.0.attention.self.value.bias (768,)\n",
"bert.encoder.layer.0.attention.output.dense.weight (768, 768)\n",
"bert.encoder.layer.0.attention.output.dense.bias (768,)\n",
"bert.encoder.layer.0.attention.output.LayerNorm.weight (768,)\n",
"bert.encoder.layer.0.attention.output.LayerNorm.bias (768,)\n",
"bert.encoder.layer.0.intermediate.dense.weight (3072, 768)\n",
"bert.encoder.layer.0.intermediate.dense.bias (3072,)\n",
"bert.encoder.layer.0.output.dense.weight (768, 3072)\n",
"bert.encoder.layer.0.output.dense.bias (768,)\n",
"bert.encoder.layer.0.output.LayerNorm.weight (768,)\n",
"bert.encoder.layer.0.output.LayerNorm.bias (768,)\n",
"\n",
"==== Output Layer ====\n",
"\n",
"bert.pooler.dense.weight (768, 768)\n",
"bert.pooler.dense.bias (768,)\n",
"classifier.weight (2, 768)\n",
"classifier.bias (2,)\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Define constants and imports"
],
"metadata": {
"id": "1II6BA7KepwN"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "l6w8elb-58GJ"
},
"source": [
"#### Tokenize Dataset"
]
},
{
"cell_type": "code",
"metadata": {
"id": "dLIbudgfh6F0",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "fc67b442-d2bd-45db-c00e-bbfeb972347f"
},
"source": [
"# Print the original sentence.\n",
"print(' Original: ', sentences[1])\n",
"\n",
"# Print the sentence split into tokens.\n",
"print('Tokenized: ', tokenizer.tokenize(sentences[1]))\n",
"\n",
"# Print the sentence mapped to token ids.\n",
"print('Token IDs: ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(sentences[1])))"
],
"execution_count": 58,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Original: these paren make no sense delete them\n",
"Tokenized: ['these', 'par', '##en', 'make', 'no', 'sense', 'del', '##ete', 'them']\n",
"Token IDs: [2122, 11968, 2368, 2191, 2053, 3168, 3972, 12870, 2068]\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cKsH2sU0OCQA"
},
"source": [
"## --- Testing tokenization\n",
"## --- We apply the tokenizer to just one sentence to see the output\n",
"\n",
"# max_len = 0\n",
"# For every sentence...\n",
"# for sent in sentences:\n",
"# # Tokenize the text and add `[CLS]` and `[SEP]` tokens.\n",
"# try:\n",
"# # Floor Division : Gives only Fractional Part as Answer\n",
"# input_ids = tokenizer.encode(str(sent), add_special_tokens=True)\n",
"# # input_ids = tokenizer.encode(str(sent), add_special_tokens=True, max_length=512, truncation=True)\n",
"\n",
"# except Exception as e:\n",
"# # By this way we can know about the type of error occurring\n",
"# print(sent)\n",
"# print(\"The error is: \",e)\n",
"# break;\n",
"\n",
"# # Update the maximum sentence length.\n",
"# max_len = max(max_len, len(input_ids))\n",
"\n",
"# print('Max sentence length: ', max_len)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2bBdb3pt8LuQ",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "26d415f3-e23e-4992-d1b2-dc6a57921375"
},
"source": [
"# Tokenize all of the sentences and map the tokens to thier word IDs.\n",
"input_ids = []\n",
"attention_masks = []\n",
"i = 0\n",
"# For every sentence...\n",
"for sent in sentences:\n",
" # `encode_plus` will:\n",
" # (1) Tokenize the sentence.\n",
" # (2) Prepend the `[CLS]` token to the start.\n",
" # (3) Append the `[SEP]` token to the end.\n",
" # (4) Map tokens to their IDs.\n",
" # (5) Pad or truncate the sentence to `max_length`\n",
" # (6) Create attention masks for [PAD] tokens.\n",
" encoded_dict = tokenizer.encode_plus(\n",
" str(sent), # Sentence to encode.\n",
" add_special_tokens = True, # Add '[CLS]' and '[SEP]'\n",
" max_length = 512, # Pad & truncate all sentences.\n",
" truncation=True,\n",
" pad_to_max_length = True,\n",
" return_attention_mask = True, # Construct attn. masks.\n",
" return_tensors = 'pt', # Return pytorch tensors.\n",
" )\n",
"\n",
" # Add the encoded sentence to the list.\n",
" input_ids.append(encoded_dict['input_ids'])\n",
"\n",
" # And its attention mask (simply differentiates padding from non-padding).\n",
" attention_masks.append(encoded_dict['attention_mask'])\n",
"\n",
"\n",
"# Convert the lists into tensors.\n",
"input_ids = torch.cat(input_ids, dim=0)\n",
"attention_masks = torch.cat(attention_masks, dim=0)\n",
"labels = torch.tensor(labels)\n",
"\n",
"# Print sentence 0, now as a list of IDs.\n",
"print('Original: ', sentences[0])\n",
"print('Token IDs:', input_ids[0])"
],
"execution_count": 59,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2393: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
" warnings.warn(\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Original: you re not missing anything i was just being stupid sorry about that \n",
"Token IDs: tensor([ 101, 2017, 2128, 2025, 4394, 2505, 1045, 2001, 2074, 2108, 5236, 3374,\n",
" 2055, 2008, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0])\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aRp4O7D295d_"
},
"source": [
"#### Training & Validation Split\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qu0ao7p8rb06"
},
"source": [
"Divide up our training set to use 90% for training and 10% for validation."
]
},
{
"cell_type": "code",
"metadata": {
"id": "GEgLpFVlo1Z-",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3c12bbc7-6a61-4023-f473-5990892b8739"
},
"source": [
"from torch.utils.data import TensorDataset, random_split\n",
"\n",
"# Combine the training inputs into a TensorDataset.\n",
"dataset = TensorDataset(input_ids, attention_masks, labels)\n",
"\n",
"# Create a 90-10 train-validation split.\n",
"\n",
"# Calculate the number of samples to include in each set.\n",
"train_size = int(0.98 * len(dataset))\n",
"val_size = len(dataset) - train_size\n",
"\n",
"# Divide the dataset by randomly selecting samples.\n",
"train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n",
"\n",
"print('{:>5,} training samples'.format(train_size))\n",
"print('{:>5,} validation samples'.format(val_size))"
],
"execution_count": 60,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"19,064 training samples\n",
" 390 validation samples\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dD9i6Z2pG-sN"
},
"source": [
"We'll also create an iterator for our dataset using the torch DataLoader class. This helps save on memory during training because, unlike a for loop, with an iterator the entire dataset does not need to be loaded into memory."
]
},
{
"cell_type": "code",
"metadata": {
"id": "XGUqOCtgqGhP"
},
"source": [
"from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
"\n",
"# The DataLoader needs to know our batch size for training, so we specify it\n",
"# here. For fine-tuning BERT on a specific task, the authors recommend a batch\n",
"# size of 16 or 32.\n",
"\n",
"# Create the DataLoaders for our training and validation sets.\n",
"# We'll take training samples in random order.\n",
"train_dataloader = DataLoader(\n",
" train_dataset, # The training samples.\n",
" sampler = RandomSampler(train_dataset), # Select batches randomly\n",
" batch_size = BATCH_SIZE # Trains with this batch size.\n",
" )\n",
"\n",
"# For validation the order doesn't matter, so we'll just read them sequentially.\n",
"validation_dataloader = DataLoader(\n",
" val_dataset, # The validation samples.\n",
" sampler = SequentialSampler(val_dataset), # Pull out batches sequentially.\n",
" batch_size = BATCH_SIZE # Evaluate with this batch size.\n",
" )"
],
"execution_count": 61,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from transformers import get_linear_schedule_with_warmup\n",
"\n",
"epochs = 8\n",
"\n",
"# Total number of training steps is [number of batches] x [number of epochs].\n",
"# (Note that this is not the same as the number of training samples).\n",
"total_steps = len(train_dataloader) * epochs\n",
"\n",
"# Create the learning rate scheduler.\n",
"scheduler = get_linear_schedule_with_warmup(optimizer,\n",
" num_warmup_steps = 0, # Default value in run_glue.py\n",
" num_training_steps = total_steps)\n"
],
"metadata": {
"id": "9fbomX_Oaqv8"
},
"execution_count": 62,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Training Loop"
],
"metadata": {
"id": "_bTg5tTB9fhX"
}
},
{
"cell_type": "code",
"source": [
"training_stats = []\n",
"\n",
"def supervisedContrativeTraining(model, train_dataloader, epochs, device, optimizer, scheduler):\n",
"\n",
" #change here for using different loss function from https://kevinmusgrave.github.io/pytorch-metric-learning/\n",
" loss_function = loss_fun.SupConLoss(temperature=0.1,embedding_regularizer = LpRegularizer())\n",
"\n",
" # We'll store a number of quantities such as training and validation loss,\n",
" # validation accuracy, and timings.\n",
"\n",
" # Measure the total training time for the whole run.\n",
" total_t0 = time.time()\n",
"\n",
" # For each epoch...\n",
" for epoch in range(epochs):\n",
" # ========================================\n",
" # Training\n",
" # ========================================\n",
"\n",
" # Perform one full pass over the training set.\n",
"\n",
" print(\"\")\n",
" print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))\n",
" print('Training...')\n",
"\n",
" # Measure how long the training epoch takes.\n",
" t0 = time.time()\n",
"\n",
" # Reset the total loss for this epoch.\n",
" total_train_loss = 0\n",
"\n",
" # Put the model into training mode. Don't be mislead--the call to\n",
" # `train` just changes the *mode*, it doesn't *perform* the training.\n",
" # `dropout` and `batchnorm` layers behave differently during training\n",
" # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)\n",
" losses = []\n",
" model.zero_grad()\n",
" model.train()\n",
"\n",
" for step, batch in enumerate(train_dataloader):\n",
"\n",
" # Progress update every 40 batches.\n",
" if step % 40 == 0 and not step == 0:\n",
" # Calculate elapsed time in minutes.\n",
" elapsed = format_time(time.time() - t0)\n",
"\n",
" # Report progress.\n",
" print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))\n",
"\n",
" # for d in train_dataloader:\n",
" input_ids = batch[0].to(device)\n",
" attention_mask = batch[1].to(device)\n",
" targets = batch[2].to(device)\n",
"\n",
" # input_ids = d[\"input_ids\"].reshape(32, 160).to(device)\n",
" # attention_mask = d[\"attention_mask\"].to(device)\n",
" # targets = d[\"targets\"].to(device)\n",
"\n",
" outputs = model(input_ids=input_ids, token_type_ids=None, attention_mask=attention_mask)\n",
"\n",
" # hidden_states = outputs[0]\n",
" hidden_states = outputs.hidden_states\n",
" # print(hidden_states)\n",
" # supcon_fea_cls = F.normalize(hidden_states[:,0,:],dim=1)\n",
" # supcon_fea_cls = F.normalize(hidden_states[::1],dim=1)\n",
" supcon_fea_cls = F.normalize(hidden_states[-1][:,0,:],dim=1)\n",
"\n",
" loss = loss_function(supcon_fea_cls, targets)\n",
"\n",
" # loss = outputs.loss\n",
" # logits = outputs.logits\n",
"\n",
"\n",
" if not torch.isnan(loss):\n",
" losses.append(loss.item())\n",
"\n",
" # Accumulate the training loss over all of the batches so that we can\n",
" # calculate the average loss at the end. `loss` is a Tensor containing a\n",
" # single value; the `.item()` function just returns the Python value\n",
" # from the tensor.\n",
" total_train_loss += loss.item()\n",
" # print(loss)\n",
"\n",
" # Perform a backward pass to calculate the gradients.\n",
" loss.backward()\n",
"\n",
" # Clip the norm of the gradients to 1.0.\n",
" # This is to help prevent the \"exploding gradients\" problem.\n",
" nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
"\n",
" # Update parameters and take a step using the computed gradient.\n",
" # The optimizer dictates the \"update rule\"--how the parameters are\n",
" # modified based on their gradients, the learning rate, etc.\n",
" optimizer.step()\n",
"\n",
" # Update the learning rate.\n",
" scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" print('Contrastive Loss Mean: ', np.mean(losses))\n",
"\n",
" # Calculate the average loss over all of the batches.\n",
" avg_train_loss = total_train_loss / len(train_dataloader)\n",
"\n",
" # Measure how long this epoch took.\n",
" training_time = format_time(time.time() - t0)\n",
" training_stats.append(\n",
" {\n",
" 'epoch': epoch + 1,\n",
" 'Training Loss': avg_train_loss,\n",
" # 'Valid. Loss': avg_val_loss,\n",
" # 'Valid. Accur.': avg_val_accuracy,\n",
" 'Training Time': training_time,\n",
" # 'Validation Time': validation_time\n",
" }\n",
" )\n",
"\n",
" print(\"\")\n",
" print(\" Average training loss: {0:.2f}\".format(avg_train_loss))\n",
" print(\" Training epcoh took: {:}\".format(training_time))\n",
"\n",
"\n",
" print(\"\")\n",
" print(\"Training complete!\")\n",
"\n",
" print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()- total_t0)))\n"
],
"metadata": {
"id": "jrcelmju9iGj"
},
"execution_count": 63,
"outputs": []
},
{
"cell_type": "code",
"source": [
"supervisedContrativeTraining(model, train_dataloader, epochs, device, optimizer, scheduler)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Iqr6ufbG84ZO",
"outputId": "05ab974e-9a4d-47bc-e946-2e544d8fff7a"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"======== Epoch 1 / 8 ========\n",
"Training...\n",
" Batch 40 of 1,192. Elapsed: 0:01:00.\n",
" Batch 80 of 1,192. Elapsed: 0:01:56.\n",
" Batch 120 of 1,192. Elapsed: 0:02:51.\n",
" Batch 160 of 1,192. Elapsed: 0:03:47.\n",
" Batch 200 of 1,192. Elapsed: 0:04:43.\n",
" Batch 240 of 1,192. Elapsed: 0:05:38.\n",
" Batch 280 of 1,192. Elapsed: 0:06:34.\n",
" Batch 320 of 1,192. Elapsed: 0:07:30.\n",
" Batch 360 of 1,192. Elapsed: 0:08:26.\n",
" Batch 400 of 1,192. Elapsed: 0:09:21.\n",
" Batch 440 of 1,192. Elapsed: 0:10:17.\n",
" Batch 480 of 1,192. Elapsed: 0:11:13.\n",
" Batch 520 of 1,192. Elapsed: 0:12:08.\n",
" Batch 560 of 1,192. Elapsed: 0:13:04.\n",
" Batch 600 of 1,192. Elapsed: 0:14:00.\n",
" Batch 640 of 1,192. Elapsed: 0:14:56.\n",
" Batch 680 of 1,192. Elapsed: 0:15:52.\n",
" Batch 720 of 1,192. Elapsed: 0:16:48.\n",
" Batch 760 of 1,192. Elapsed: 0:17:44.\n",
" Batch 800 of 1,192. Elapsed: 0:18:40.\n",
" Batch 840 of 1,192. Elapsed: 0:19:36.\n",
" Batch 880 of 1,192. Elapsed: 0:20:32.\n",
" Batch 920 of 1,192. Elapsed: 0:21:27.\n",
" Batch 960 of 1,192. Elapsed: 0:22:24.\n",
" Batch 1,000 of 1,192. Elapsed: 0:23:20.\n",
" Batch 1,040 of 1,192. Elapsed: 0:24:15.\n",
" Batch 1,080 of 1,192. Elapsed: 0:25:11.\n",
" Batch 1,120 of 1,192. Elapsed: 0:26:07.\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import gc\n",
"# del variables\n",
"gc.collect()\n",
"\n",
"torch.cuda.memory_summary(device=None, abbreviated=False)\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"id": "ceqbZU1XhOQk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "RqfmWwUR_Sox"
},
"source": [
"#### Training Results"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VQTvJ1vRP7u4"
},
"source": [
"Let's view the summary of the training process."
]
},
{
"cell_type": "code",
"metadata": {
"id": "6O_NbXFGMukX",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 244
},
"outputId": "42013514-c334-4aa1-d165-e37f3bf4cdd1"
},
"source": [
"import pandas as pd\n",
"\n",
"# Display floats with two decimal places.\n",
"# pd.set_option('precision', 2)\n",
"pd.options.display.max_rows = 5\n",
"# pd.options.precision = 2\n",
"print(training_stats)\n",
"# Create a DataFrame from our training statistics.\n",
"df_stats = pd.DataFrame(data=training_stats)\n",
"\n",
"# Use the 'epoch' as the row index.\n",
"df_stats = df_stats.set_index('epoch')\n",
"\n",
"# A hack to force the column headers to wrap.\n",
"#df = df.style.set_table_styles([dict(selector=\"th\",props=[('max-width', '70px')])])\n",
"\n",
"# Display the table.\n",
"df_stats"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[{'epoch': 1, 'Training Loss': 3.426470624450037, 'Training Time': '0:29:18'}, {'epoch': 2, 'Training Loss': 3.328372689180596, 'Training Time': '0:29:33'}, {'epoch': 3, 'Training Loss': 3.2352232127094585, 'Training Time': '0:29:34'}, {'epoch': 4, 'Training Loss': 3.1862575237537145, 'Training Time': '0:29:42'}]\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Training Loss Training Time\n",
"epoch \n",
"1 3.426471 0:29:18\n",
"2 3.328373 0:29:33\n",
"3 3.235223 0:29:34\n",
"4 3.186258 0:29:42"
],
"text/html": [
"\n",
"\n",
" <div id=\"df-28ab133e-970f-4b16-93c0-f537c437cf39\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Training Loss</th>\n",
" <th>Training Time</th>\n",
" </tr>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>3.426471</td>\n",
" <td>0:29:18</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3.328373</td>\n",
" <td>0:29:33</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>3.235223</td>\n",
" <td>0:29:34</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>3.186258</td>\n",
" <td>0:29:42</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-28ab133e-970f-4b16-93c0-f537c437cf39')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
"\n",
"\n",
" <div id=\"df-dd6eab2d-3d43-40f4-b8c5-3f6ed52dad37\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-dd6eab2d-3d43-40f4-b8c5-3f6ed52dad37')\"\n",
" title=\"Suggest charts.\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
" </div>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const containerElement = document.querySelector('#' + key);\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" }\n",
" </script>\n",
"\n",
" <script>\n",
"\n",
"function displayQuickchartButton(domScope) {\n",
" let quickchartButtonEl =\n",
" domScope.querySelector('#df-dd6eab2d-3d43-40f4-b8c5-3f6ed52dad37 button.colab-df-quickchart');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"}\n",
"\n",
" displayQuickchartButton(document);\n",
" </script>\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-28ab133e-970f-4b16-93c0-f537c437cf39 button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-28ab133e-970f-4b16-93c0-f537c437cf39');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n"
]
},
"metadata": {},
"execution_count": 29
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1-G03mmwH3aI"
},
"source": [
"Notice that, while the training loss is going down with each epoch, the validation loss is increasing! This suggests that we are training our model too long, and it's over-fitting on the training data.\n",
"\n",
"(For reference, we are using 7,695 training samples and 856 validation samples).\n",
"\n",
"Validation Loss is a more precise measure than accuracy, because with accuracy we don't care about the exact output value, but just which side of a threshold it falls on.\n",
"\n",
"If we are predicting the correct answer, but with less confidence, then validation loss will catch this, while accuracy will not."
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt"
],
"metadata": {
"id": "I8B0VgTYzJLz"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"%matplotlib inline"
],
"metadata": {
"id": "6bFcIFLtzLmi"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "68xreA9JAmG5",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 431
},
"outputId": "69343855-725d-4a5f-e30f-81aaf59dc044"
},
"source": [
"import seaborn as sns\n",
"\n",
"# Use seaborn's darkgrid styling with a larger font, on a wide figure.\n",
"sns.set(style='darkgrid', font_scale=1.5)\n",
"plt.rcParams[\"figure.figsize\"] = (12,6)\n",
"\n",
"# Plot the learning curve. Only the training loss is recorded in df_stats.\n",
"plt.plot(df_stats['Training Loss'], 'b-o', label=\"Training\")\n",
"\n",
"# Label the plot.\n",
"plt.title(\"Training Loss - COntrastive learning\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.legend()\n",
"\n",
"# One tick per recorded epoch, taken from the table's index so that the final\n",
"# epoch is included (range(1, epochs) stopped one tick short).\n",
"plt.xticks(list(df_stats.index))\n",
"\n",
"plt.show()"
],
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Tg42jJqqM68F"
},
"source": [
"#### Test Data Preparation\n",
"\n",
"We'll need to apply all of the same steps that we did for the training data to prepare our test data set."
]
},
{
"cell_type": "code",
"source": [
"# Clean every test message with `process_text` and keep the result in a new\n",
"# 'modified_message' column.\n",
"test_df['modified_message'] = test_df['message'].astype(str).map(process_text)\n",
"print(test_df.sample(10))\n",
"\n",
"# Model inputs (cleaned text) and ground-truth labels as arrays.\n",
"test_sentences = test_df['modified_message'].values\n",
"test_labels = test_df['is_toxic'].values\n"
],
"metadata": {
"id": "eDuXSdtmfT-i"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "mAN0LZBOOPVh",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b026c478-92d6-4917-821d-654bb9eee71b"
},
"source": [
"import pandas as pd\n",
"\n",
"# Report the number of sentences.\n",
"print('Number of test sentences: {:,}\\n'.format(test_df.shape[0]))\n",
"\n",
"# Tokenize all of the sentences and map the tokens to their word IDs.\n",
"input_ids = []\n",
"attention_masks = []\n",
"\n",
"# For every sentence...\n",
"for sent in test_sentences:\n",
"    # `encode_plus` will:\n",
"    #   (1) Tokenize the sentence.\n",
"    #   (2) Prepend the `[CLS]` token to the start.\n",
"    #   (3) Append the `[SEP]` token to the end.\n",
"    #   (4) Map tokens to their IDs.\n",
"    #   (5) Pad or truncate the sentence to `max_length`\n",
"    #   (6) Create attention masks for [PAD] tokens.\n",
"    encoded_dict = tokenizer.encode_plus(\n",
"                        str(sent),                     # Sentence to encode.\n",
"                        add_special_tokens = True,     # Add '[CLS]' and '[SEP]'\n",
"                        max_length = 512,              # Pad & truncate all sentences.\n",
"                        truncation = True,             # Truncate explicitly (avoids the tokenizer warning).\n",
"                        padding = 'max_length',        # Replaces the deprecated `pad_to_max_length=True`.\n",
"                        return_attention_mask = True,  # Construct attn. masks.\n",
"                        return_tensors = 'pt',         # Return pytorch tensors.\n",
"                   )\n",
"\n",
"    # Add the encoded sentence to the list.\n",
"    input_ids.append(encoded_dict['input_ids'])\n",
"\n",
"    # And its attention mask (simply differentiates padding from non-padding).\n",
"    attention_masks.append(encoded_dict['attention_mask'])\n",
"\n",
"# Convert the lists into tensors.\n",
"input_ids = torch.cat(input_ids, dim=0)\n",
"attention_masks = torch.cat(attention_masks, dim=0)\n",
"test_labels = torch.tensor(test_labels)\n",
"\n",
"# Set the batch size.\n",
"batch_size = 16\n",
"\n",
"# Create the DataLoader. Sequential sampling keeps predictions in the same\n",
"# order as the rows of the test set.\n",
"prediction_data = TensorDataset(input_ids, attention_masks, test_labels)\n",
"prediction_sampler = SequentialSampler(prediction_data)\n",
"prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)\n",
"\n",
"print('Predicting labels for {:,} test sentences...'.format(len(input_ids)))\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Predicting labels for 15 test sentences...\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2393: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
" warnings.warn(\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "16lctEOyNFik"
},
"source": [
"#### Evaluate on Test Set\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rhR99IISNMg9"
},
"source": [
"\n",
"With the test set prepared, we can apply our fine-tuned model to generate predictions on the test set."
]
},
{
"cell_type": "code",
"metadata": {
"id": "Hba10sXR7Xi6",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "32cfa74f-f116-43b7-bb83-71da0fbf537f"
},
"source": [
"# Prediction on test set\n",
"print(\"Running Test set...\")\n",
"\n",
"# Put the model in evaluation mode--the dropout layers behave differently\n",
"# during evaluation.\n",
"model.eval()\n",
"\n",
"# Per-batch logits and ground-truth labels; the metric cells below read these.\n",
"predictions , true_labels = [], []\n",
"\n",
"# Predict\n",
"for batch in prediction_dataloader:\n",
"    # Add batch to GPU\n",
"    batch = tuple(t.to(device) for t in batch)\n",
"\n",
"    # Unpack the inputs from our dataloader\n",
"    b_input_ids, b_input_mask, b_labels = batch\n",
"\n",
"    # Telling the model not to compute or store gradients, saving memory and\n",
"    # speeding up prediction\n",
"    with torch.no_grad():\n",
"        # Forward pass without labels; only the logits are used here.\n",
"        result = model(b_input_ids,\n",
"                       token_type_ids=None,\n",
"                       attention_mask=b_input_mask,\n",
"                       return_dict=True)\n",
"\n",
"    # The \"logits\" are the output values prior to applying an activation\n",
"    # function like the softmax. Move them and the labels to CPU as numpy arrays.\n",
"    logits = result.logits.detach().cpu().numpy()\n",
"    label_ids = b_labels.to('cpu').numpy()\n",
"\n",
"    # Store predictions and true labels\n",
"    predictions.append(logits)\n",
"    true_labels.append(label_ids)\n",
"\n",
"print(' DONE.')\n",
"\n",
"# Report accuracy over all test samples at once. Averaging per-batch\n",
"# accuracies would over-weight a final batch that is smaller than the rest.\n",
"# NOTE: relies on `np` (numpy) having been imported by an earlier cell.\n",
"test_logits = np.concatenate(predictions, axis=0)\n",
"test_label_ids = np.concatenate(true_labels, axis=0)\n",
"avg_val_accuracy = np.mean(np.argmax(test_logits, axis=1) == test_label_ids) * 100.0\n",
"print(\" Accuracy: {0:.2f}\".format(avg_val_accuracy))\n"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Running Test set...\n",
" DONE.\n",
" Accuracy: 73.33\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "cRaZQ4XC7kLs",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "00271fac-c8bd-4c5d-a185-a51a738d5d0a"
},
"source": [
"from sklearn.metrics import matthews_corrcoef\n",
"\n",
"# Per-batch MCC scores and per-batch predicted labels.\n",
"matthews_set = []\n",
"final_labels = []\n",
"\n",
"# Evaluate each test batch using Matthew's correlation coefficient\n",
"print('Calculating Matthews Corr. Coef. for each batch...')\n",
"\n",
"for i, batch_logits in enumerate(predictions):\n",
"    # Each row of logits holds one score per class (\"0\" and \"1\"); the index of\n",
"    # the larger score is the predicted label for that sample.\n",
"    pred_labels_i = np.argmax(batch_logits, axis=1).flatten()\n",
"    final_labels.append(pred_labels_i)\n",
"\n",
"    # Score this batch against its ground-truth labels and record the result.\n",
"    matthews = matthews_corrcoef(true_labels[i], pred_labels_i)\n",
"    matthews_set.append(matthews)\n",
"\n",
"print(final_labels)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Calculating Matthews Corr. Coef. for each batch...\n",
"[array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"#### Calculate Accuracy & F1 Score"
],
"metadata": {
"id": "ORnn8VPz9uLC"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.metrics import f1_score\n",
"\n",
"# Score against every test batch, not just the first one.\n",
"all_true_labels = np.concatenate(true_labels, axis=0)\n",
"all_pred_labels = np.concatenate(final_labels, axis=0)\n",
"\n",
"# Calculate accuracy\n",
"count = int(np.sum(all_true_labels == all_pred_labels))\n",
"print(\" Accuracy: {0:.2f}\".format(count / len(all_true_labels) * 100.0))\n",
"\n",
"\n",
"def f1(actual, predicted, label):\n",
"    \"\"\"Return the F1 score for `label`, treated as the positive class.\n",
"\n",
"    `actual` and `predicted` are equal-length numpy arrays of labels.\n",
"    Returns 0.0 when there are no true positives, false positives or false\n",
"    negatives for `label`, instead of dividing by zero.\n",
"    \"\"\"\n",
"    tp = np.sum((actual == label) & (predicted == label))\n",
"    fp = np.sum((actual != label) & (predicted == label))\n",
"    fn = np.sum((predicted != label) & (actual == label))\n",
"\n",
"    # F1 = 2 * (precision * recall) / (precision + recall), which simplifies to\n",
"    # 2*tp / (2*tp + fp + fn). This form stays defined when the label is never\n",
"    # predicted (tp + fp == 0), where precision alone would be 0/0 = nan.\n",
"    denominator = 2 * tp + fp + fn\n",
"    if denominator == 0:\n",
"        return 0.0\n",
"    return 2 * tp / denominator\n",
"\n",
"\n",
"def f1_macro(actual, predicted):\n",
"    \"\"\"Return the macro F1: the unweighted mean of `f1` over each label in `actual`.\"\"\"\n",
"    return np.mean([f1(actual, predicted, label)\n",
"                    for label in np.unique(actual)])\n",
"\n",
"\n",
"print(\" F1 Score: {0:.2f}\".format(f1_macro(all_true_labels, all_pred_labels) * 100.0))"
],
"metadata": {
"id": "_0dxa1JW9roc",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "2cfa6e26-69a9-44be-a39e-f9bd25d06b37"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Accuracy: 73.33\n",
" F1 Score: nan\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"<ipython-input-36-5caed28a9aa0>:22: RuntimeWarning: invalid value encountered in long_scalars\n",
" precision = tp/(tp+fp)\n"
]
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "oCYZa1lQ8Jn8",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "d4b4b8ec-3fb9-45a4-d4cd-cd7fc579daab"
},
"source": [
"# Stack the per-batch logits and take the arg-max of each row to get one\n",
"# predicted label (0 or 1) per test sample.\n",
"flat_predictions = np.argmax(np.concatenate(predictions, axis=0), axis=1).flatten()\n",
"\n",
"# Likewise join the per-batch ground-truth labels into a single array.\n",
"flat_true_labels = np.concatenate(true_labels, axis=0)\n",
"\n",
"# Matthews correlation coefficient over the whole test set.\n",
"mcc = matthews_corrcoef(flat_true_labels, flat_predictions)\n",
"\n",
"print('Total MCC: %.3f' % mcc)"
],
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Total MCC: 0.000\n"
]
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xlQG7qgkmf4n"
},
"source": [
"This post demonstrates that with a pre-trained BERT model you can quickly and effectively create a high quality model with minimal effort and training time using the pytorch interface, regardless of the specific NLP task you are interested in."
]
},
{
"cell_type": "markdown",
"source": [
"#### Saving New model"
],
"metadata": {
"id": "eggfbrc6-WQg"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"# Keeping the default file names means the directory can later be reloaded\n",
"# with `from_pretrained()`.\n",
"output_dir = './preprocess_contrastive_model-2/'\n",
"\n",
"# Make sure the output directory exists.\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"print(\"Saving model to %s\" % output_dir)\n",
"\n",
"# A model wrapped for distributed/parallel training exposes the real model as\n",
"# `.module`; otherwise save the model object itself.\n",
"model_to_save = getattr(model, 'module', model)\n",
"\n",
"# Write the trained model, its configuration and the tokenizer.\n",
"model_to_save.save_pretrained(output_dir)\n",
"tokenizer.save_pretrained(output_dir)\n",
"\n",
"# Good practice: save your training arguments together with the trained model\n",
"# torch.save(args, os.path.join(output_dir, 'training_args.bin'))\n"
],
"metadata": {
"id": "wE8Sc7N1-ZP5",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "a475a6ce-f694-4571-9671-8cf9660bb2d9"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Saving model to ./contrastive_model/\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"('./contrastive_model/tokenizer_config.json',\n",
" './contrastive_model/special_tokens_map.json',\n",
" './contrastive_model/vocab.txt',\n",
" './contrastive_model/added_tokens.json')"
]
},
"metadata": {},
"execution_count": 38
}
]
},
{
"cell_type": "code",
"source": [
"# Mount Google Drive to this Notebook instance.\n",
"from google.colab import drive\n",
"\n",
"drive.mount('/content/drive')"
],
"metadata": {
"id": "apZeu_nr-oys",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "26841f4b-924c-4442-9347-fa84fceaeaa3"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Mounted at /content/drive\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Copy the model files to a directory in your Google Drive.\n",
"!cp -r ./preprocess_contrastive_model-2/ ./drive/MyDrive"
],
"metadata": {
"id": "lF10V-Dw-p9s"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {
"id": "NIWouvDrGVAi"
},
"source": [
"## Cross-Entropy Loss Fine-Tuning\n",
"\n"
]
},
{
"cell_type": "code",
"source": [
"# Load model\n",
"# Mount Google Drive to this Notebook instance.\n",
"from google.colab import drive\n",
"\n",
"drive.mount('/content/drive')\n",
"\n",
"!cp -r ./drive/MyDrive/preprocess_contrastive_model-2/ ."
],
"metadata": {
"id": "9kCmjivJIvay",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "3a988a43-65dd-4e07-d497-f48a8feaf816"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"#### Model Initialize and Tokenize"
],
"metadata": {
"id": "hD8eBR1axG9X"
}
},
{
"cell_type": "code",
"source": [
"output_dir = './preprocess_contrastive_model-2/'\n",
"# Load a trained model and vocabulary that you have fine-tuned\n",
"from transformers import BertTokenizer\n",
"from transformers import BertForSequenceClassification, AdamW, BertConfig\n",
"\n",
"# Load the BERT tokenizer saved alongside the model.\n",
"print('Loading BERT tokenizer...')\n",
"tokenizer = BertTokenizer.from_pretrained(output_dir)\n",
"\n",
"# Load BertForSequenceClassification, the pretrained BERT model with a single\n",
"# linear classification layer on top. `num_labels` is a model setting (it sizes\n",
"# the classifier head), so it is passed here rather than to the tokenizer.\n",
"model = BertForSequenceClassification.from_pretrained(output_dir, num_labels = 2)\n",
"\n",
"# Copy the model to the GPU.\n",
"model.to(device)\n",
"\n",
"# Get all of the model's parameters as a list of tuples.\n",
"params = list(model.named_parameters())\n",
"\n",
"print('The BERT model has {:} different named parameters.\\n'.format(len(params)))\n",
"\n",
"# Print the name and shape of a few representative parameter groups.\n",
"param_sections = [\n",
"    ('==== Embedding Layer ====\\n', params[0:5]),\n",
"    ('\\n==== First Transformer ====\\n', params[5:21]),\n",
"    ('\\n==== Output Layer ====\\n', params[-4:]),\n",
"]\n",
"\n",
"for heading, section_params in param_sections:\n",
"    print(heading)\n",
"    for p in section_params:\n",
"        print(\"{:<55} {:>12}\".format(p[0], str(tuple(p[1].size()))))"
],
"metadata": {
"id": "d6_nrom8I7AE",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "512a7e36-a588-409a-832d-3a76b87a7318"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Loading BERT tokenizer...\n",
"The BERT model has 201 different named parameters.\n",
"\n",
"==== Embedding Layer ====\n",
"\n",
"bert.embeddings.word_embeddings.weight (30522, 768)\n",
"bert.embeddings.position_embeddings.weight (512, 768)\n",
"bert.embeddings.token_type_embeddings.weight (2, 768)\n",
"bert.embeddings.LayerNorm.weight (768,)\n",
"bert.embeddings.LayerNorm.bias (768,)\n",
"\n",
"==== First Transformer ====\n",
"\n",
"bert.encoder.layer.0.attention.self.query.weight (768, 768)\n",
"bert.encoder.layer.0.attention.self.query.bias (768,)\n",
"bert.encoder.layer.0.attention.self.key.weight (768, 768)\n",
"bert.encoder.layer.0.attention.self.key.bias (768,)\n",
"bert.encoder.layer.0.attention.self.value.weight (768, 768)\n",
"bert.encoder.layer.0.attention.self.value.bias (768,)\n",
"bert.encoder.layer.0.attention.output.dense.weight (768, 768)\n",
"bert.encoder.layer.0.attention.output.dense.bias (768,)\n",
"bert.encoder.layer.0.attention.output.LayerNorm.weight (768,)\n",
"bert.encoder.layer.0.attention.output.LayerNorm.bias (768,)\n",
"bert.encoder.layer.0.intermediate.dense.weight (3072, 768)\n",
"bert.encoder.layer.0.intermediate.dense.bias (3072,)\n",
"bert.encoder.layer.0.output.dense.weight (768, 3072)\n",
"bert.encoder.layer.0.output.dense.bias (768,)\n",
"bert.encoder.layer.0.output.LayerNorm.weight (768,)\n",
"bert.encoder.layer.0.output.LayerNorm.bias (768,)\n",
"\n",
"==== Output Layer ====\n",
"\n",
"bert.pooler.dense.weight (768, 768)\n",
"bert.pooler.dense.bias (768,)\n",
"classifier.weight (2, 768)\n",
"classifier.bias (2,)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from transformers import get_linear_schedule_with_warmup\n",
"\n",
"# Use PyTorch's own AdamW: the copy bundled with `transformers` is deprecated\n",
"# in favour of `torch.optim.AdamW`. `weight_decay = 0.0` is passed explicitly\n",
"# because the transformers class defaulted to no weight decay, whereas the\n",
"# PyTorch default is 0.01.\n",
"optimizer = torch.optim.AdamW(model.parameters(),\n",
"                  lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5\n",
"                  eps = 1e-8, # args.adam_epsilon  - default is 1e-8.\n",
"                  weight_decay = 0.0\n",
"                )\n"
],
"metadata": {
"id": "X-M2UrjUf3_9"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Tokenize all of the sentences and map the tokens to their word IDs.\n",
"input_ids = []\n",
"attention_masks = []\n",
"\n",
"# For every sentence...\n",
"for sent in sentences:\n",
"    # `encode_plus` will:\n",
"    #   (1) Tokenize the sentence.\n",
"    #   (2) Prepend the `[CLS]` token to the start.\n",
"    #   (3) Append the `[SEP]` token to the end.\n",
"    #   (4) Map tokens to their IDs.\n",
"    #   (5) Pad or truncate the sentence to `max_length`\n",
"    #   (6) Create attention masks for [PAD] tokens.\n",
"    encoded_dict = tokenizer.encode_plus(\n",
"                        str(sent),                     # Sentence to encode.\n",
"                        add_special_tokens = True,     # Add '[CLS]' and '[SEP]'\n",
"                        max_length = 512,              # Pad & truncate all sentences.\n",
"                        truncation = True,\n",
"                        padding = 'max_length',        # Replaces the deprecated `pad_to_max_length=True`.\n",
"                        return_attention_mask = True,  # Construct attn. masks.\n",
"                        return_tensors = 'pt',         # Return pytorch tensors.\n",
"                   )\n",
"\n",
"    # Add the encoded sentence to the list.\n",
"    input_ids.append(encoded_dict['input_ids'])\n",
"\n",
"    # And its attention mask (simply differentiates padding from non-padding).\n",
"    attention_masks.append(encoded_dict['attention_mask'])\n",
"\n",
"\n",
"# Convert the lists into tensors.\n",
"input_ids = torch.cat(input_ids, dim=0)\n",
"attention_masks = torch.cat(attention_masks, dim=0)\n",
"labels = torch.tensor(labels)\n",
"\n",
"# Print sentence 0, now as a list of IDs.\n",
"print('Original: ', sentences[0])\n",
"print('Token IDs:', input_ids[0])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fQvnQxusf2TJ",
"outputId": "169a4d41-a9cc-4103-cc74-3c45184e93c4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2393: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
" warnings.warn(\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Original: this and below assignments also should be removed\n",
"Token IDs: tensor([ 101, 2023, 1998, 2917, 14799, 2036, 2323, 2022, 3718, 102,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0])\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import TensorDataset, random_split\n",
"\n",
"# Combine the training inputs into a TensorDataset.\n",
"dataset = TensorDataset(input_ids, attention_masks, labels)\n",
"\n",
"# Create a 98-2 train-validation split.\n",
"\n",
"# Calculate the number of samples to include in each set.\n",
"train_size = int(0.98 * len(dataset))\n",
"val_size = len(dataset) - train_size\n",
"\n",
"# Divide the dataset by randomly selecting samples.\n",
"train_dataset, val_dataset = random_split(dataset, [train_size, val_size])\n",
"\n",
"print('{:>5,} training samples'.format(train_size))\n",
"print('{:>5,} validation samples'.format(val_size))"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GgXH77OkgMLi",
"outputId": "4a6735aa-6ef9-4bc3-b643-ee540f9106de"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"19,257 training samples\n",
" 394 validation samples\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from torch.utils.data import DataLoader, RandomSampler, SequentialSampler\n",
"\n",
"# Both loaders use BATCH_SIZE (not defined in this cell). For fine-tuning BERT\n",
"# on a specific task, the authors recommend a batch size of 16 or 32.\n",
"\n",
"# Training batches are drawn in random order.\n",
"train_sampler = RandomSampler(train_dataset)\n",
"train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=BATCH_SIZE)\n",
"\n",
"# Order is irrelevant for validation, so read those samples sequentially.\n",
"val_sampler = SequentialSampler(val_dataset)\n",
"validation_dataloader = DataLoader(val_dataset, sampler=val_sampler, batch_size=BATCH_SIZE)"
],
"metadata": {
"id": "yXzRGzUogaDb"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QxSMw0FrptiL"
},
"source": [
"# Number of passes over the training set. The BERT authors recommend between\n",
"# 2 and 4; 4 is used here, although the later results suggest this may be\n",
"# over-fitting the training data.\n",
"epochs = 4\n",
"\n",
"# One scheduler step is taken per training batch, so the schedule spans\n",
"# [number of batches] x [number of epochs] steps (not the number of samples).\n",
"total_steps = epochs * len(train_dataloader)\n",
"\n",
"# Linear learning-rate decay with no warmup.\n",
"scheduler = get_linear_schedule_with_warmup(\n",
"    optimizer,\n",
"    num_warmup_steps=0,  # Default value in run_glue.py\n",
"    num_training_steps=total_steps,\n",
")"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"#### Cross-entropy training loop"
],
"metadata": {
"id": "O5Eeo6YKLMmX"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import time\n",
"import datetime\n",
"\n",
"# Function to calculate the accuracy of our predictions vs labels\n",
"def flat_accuracy(preds, labels):\n",
"    '''\n",
"    Return the fraction of rows in `preds` whose arg-max column equals the\n",
"    corresponding entry of `labels` (both are flattened before comparing).\n",
"    '''\n",
"    predicted_classes = np.argmax(preds, axis=1).flatten()\n",
"    true_classes = labels.flatten()\n",
"    n_correct = np.sum(predicted_classes == true_classes)\n",
"    return n_correct / len(true_classes)\n",
"\n",
"\n",
"\n",
"def format_time(elapsed):\n",
"    '''\n",
"    Convert a duration given in seconds into an `h:mm:ss` string.\n",
"\n",
"    The duration is rounded to the nearest whole second before formatting.\n",
"    '''\n",
"    whole_seconds = int(round(elapsed))\n",
"    return str(datetime.timedelta(seconds=whole_seconds))\n"
],
"metadata": {
"id": "owrPlWvnLQxP"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import random\n",
"import numpy as np\n",
"\n",
"# This training code is based on the `run_glue.py` script here:\n",
"# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128\n",
"\n",
"# Set the seed value all over the place to make this reproducible.\n",
"seed_val = 42\n",
"\n",
"random.seed(seed_val)\n",
"np.random.seed(seed_val)\n",
"torch.manual_seed(seed_val)\n",
"torch.cuda.manual_seed_all(seed_val)\n",
"\n",
"# We'll store a number of quantities such as training and validation loss,\n",
"# validation accuracy, and timings.\n",
"training_stats = []\n",
"\n",
"# Measure the total training time for the whole run.\n",
"total_t0 = time.time()\n",
"\n",
"# For each epoch...\n",
"for epoch_i in range(0, epochs):\n",
"\n",
" # ========================================\n",
" # Training\n",
" # ========================================\n",
"\n",
" # Perform one full pass over the training set.\n",
"\n",
" print(\"\")\n",
" print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))\n",
" print('Training...')\n",
"\n",
" # Measure how long the training epoch takes.\n",
" t0 = time.time()\n",
"\n",
" # Reset the total loss for this epoch.\n",
" total_train_loss = 0\n",
"\n",
" # Put the model into training mode. Don't be misled--the call to\n",
" # `train` just changes the *mode*, it doesn't *perform* the training.\n",
" # `dropout` and `batchnorm` layers behave differently during training\n",
" # vs. test (source: https://stackoverflow.com/questions/51433378/what-does-model-train-do-in-pytorch)\n",
" model.train()\n",
"\n",
" # For each batch of training data...\n",
" for step, batch in enumerate(train_dataloader):\n",
"\n",
" # Progress update every 40 batches.\n",
" if step % 40 == 0 and not step == 0:\n",
" # Format the elapsed time as h:mm:ss.\n",
" elapsed = format_time(time.time() - t0)\n",
"\n",
" # Report progress.\n",
" print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))\n",
"\n",
" # Unpack this training batch from our dataloader.\n",
" #\n",
" # As we unpack the batch, we'll also copy each tensor to the GPU using the\n",
" # `to` method.\n",
" #\n",
" # `batch` contains three pytorch tensors:\n",
" # [0]: input ids\n",
" # [1]: attention masks\n",
" # [2]: labels\n",
" b_input_ids = batch[0].to(device)\n",
" b_input_mask = batch[1].to(device)\n",
" b_labels = batch[2].to(device)\n",
"\n",
" # Always clear any previously calculated gradients before performing a\n",
" # backward pass. PyTorch doesn't do this automatically because\n",
" # accumulating the gradients is \"convenient while training RNNs\".\n",
" # (source: https://stackoverflow.com/questions/48001598/why-do-we-need-to-call-zero-grad-in-pytorch)\n",
" model.zero_grad()\n",
"\n",
" # Perform a forward pass (evaluate the model on this training batch).\n",
" # In PyTorch, calling `model` will in turn call the model's `forward`\n",
" # function and pass down the arguments. The `forward` function is\n",
" # documented here:\n",
" # https://huggingface.co/transformers/model_doc/bert.html#bertforsequenceclassification\n",
" # The results are returned in a results object, documented here:\n",
" # https://huggingface.co/transformers/main_classes/output.html#transformers.modeling_outputs.SequenceClassifierOutput\n",
" # Specifically, we'll get the loss (because we provided labels) and the\n",
" # \"logits\"--the model outputs prior to activation.\n",
" result = model(b_input_ids,\n",
" token_type_ids=None,\n",
" attention_mask=b_input_mask,\n",
" labels=b_labels,\n",
" return_dict=True)\n",
"\n",
" loss = result.loss\n",
" logits = result.logits\n",
"\n",
" # Accumulate the training loss over all of the batches so that we can\n",
" # calculate the average loss at the end. `loss` is a Tensor containing a\n",
" # single value; the `.item()` function just returns the Python value\n",
" # from the tensor.\n",
" total_train_loss += loss.item()\n",
"\n",
" # Perform a backward pass to calculate the gradients.\n",
" loss.backward()\n",
"\n",
" # Clip the norm of the gradients to 1.0.\n",
" # This is to help prevent the \"exploding gradients\" problem.\n",
" torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
"\n",
" # Update parameters and take a step using the computed gradient.\n",
" # The optimizer dictates the \"update rule\"--how the parameters are\n",
" # modified based on their gradients, the learning rate, etc.\n",
" optimizer.step()\n",
"\n",
" # Update the learning rate.\n",
" scheduler.step()\n",
"\n",
" # Calculate the average loss over all of the batches.\n",
" avg_train_loss = total_train_loss / len(train_dataloader)\n",
"\n",
" # Measure how long this epoch took.\n",
" training_time = format_time(time.time() - t0)\n",
"\n",
" print(\"\")\n",
" print(\" Average training loss: {0:.2f}\".format(avg_train_loss))\n",
" print(\" Training epcoh took: {:}\".format(training_time))\n",
"\n",
" # ========================================\n",
" # Validation\n",
" # ========================================\n",
" # After the completion of each training epoch, measure our performance on\n",
" # our validation set.\n",
"\n",
" print(\"\")\n",
" print(\"Running Validation...\")\n",
"\n",
" t0 = time.time()\n",
"\n",
" # Put the model in evaluation mode--the dropout layers behave differently\n",
" # during evaluation.\n",
" model.eval()\n",
"\n",
" # Tracking variables\n",
" total_eval_accuracy = 0\n",
" total_eval_loss = 0\n",
" nb_eval_steps = 0\n",
"\n",
" # Evaluate data for one epoch\n",
" for batch in validation_dataloader:\n",
"\n",
" # Unpack this validation batch from our dataloader.\n",
" #\n",
" # As we unpack the batch, we'll also copy each tensor to the GPU using\n",
" # the `to` method.\n",
" #\n",
" # `batch` contains three pytorch tensors:\n",
" # [0]: input ids\n",
" # [1]: attention masks\n",
" # [2]: labels\n",
" b_input_ids = batch[0].to(device)\n",
" b_input_mask = batch[1].to(device)\n",
" b_labels = batch[2].to(device)\n",
"\n",
" # Tell pytorch not to bother with constructing the compute graph during\n",
" # the forward pass, since this is only needed for backprop (training).\n",
" with torch.no_grad():\n",
"\n",
" # Forward pass, calculate logit predictions.\n",
" # token_type_ids is the same as the \"segment ids\", which\n",
" # differentiates sentence 1 and 2 in 2-sentence tasks.\n",
" result = model(b_input_ids,\n",
" token_type_ids=None,\n",
" attention_mask=b_input_mask,\n",
" labels=b_labels,\n",
" return_dict=True)\n",
"\n",
" # Get the loss and \"logits\" output by the model. The \"logits\" are the\n",
" # output values prior to applying an activation function like the\n",
" # softmax.\n",
" loss = result.loss\n",
" logits = result.logits\n",
"\n",
" # Accumulate the validation loss.\n",
" total_eval_loss += loss.item()\n",
"\n",
" # Move logits and labels to CPU\n",
" logits = logits.detach().cpu().numpy()\n",
" label_ids = b_labels.to('cpu').numpy()\n",
"\n",
" # Calculate the accuracy for this batch of validation sentences, and\n",
" # accumulate it over all batches.\n",
" total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
"\n",
"\n",
" # Report the final accuracy for this validation run.\n",
" avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)\n",
" print(\" Accuracy: {0:.2f}\".format(avg_val_accuracy))\n",
"\n",
" # Calculate the average loss over all of the batches.\n",
" avg_val_loss = total_eval_loss / len(validation_dataloader)\n",
"\n",
" # Measure how long the validation run took.\n",
" validation_time = format_time(time.time() - t0)\n",
"\n",
" print(\" Validation Loss: {0:.2f}\".format(avg_val_loss))\n",
" print(\" Validation took: {:}\".format(validation_time))\n",
"\n",
" # Record all statistics from this epoch.\n",
" training_stats.append(\n",
" {\n",
" 'epoch': epoch_i + 1,\n",
" 'Training Loss': avg_train_loss,\n",
" 'Valid. Loss': avg_val_loss,\n",
" 'Valid. Accur.': avg_val_accuracy,\n",
" 'Training Time': training_time,\n",
" 'Validation Time': validation_time\n",
" }\n",
" )\n",
"\n",
"print(\"\")\n",
"print(\"Training complete!\")\n",
"\n",
"print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()-total_t0)))"
],
"metadata": {
"id": "coCQazMoLZXx",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "f5f181e1-40fa-46a0-c883-8ef21900ac01"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"======== Epoch 1 / 4 ========\n",
"Training...\n",
" Batch 40 of 1,204. Elapsed: 0:00:54.\n",
" Batch 80 of 1,204. Elapsed: 0:01:48.\n",
" Batch 120 of 1,204. Elapsed: 0:02:43.\n",
" Batch 160 of 1,204. Elapsed: 0:03:37.\n",
" Batch 200 of 1,204. Elapsed: 0:04:31.\n",
" Batch 240 of 1,204. Elapsed: 0:05:26.\n",
" Batch 280 of 1,204. Elapsed: 0:06:20.\n",
" Batch 320 of 1,204. Elapsed: 0:07:14.\n",
" Batch 360 of 1,204. Elapsed: 0:08:08.\n",
" Batch 400 of 1,204. Elapsed: 0:09:03.\n",
" Batch 440 of 1,204. Elapsed: 0:09:57.\n",
" Batch 480 of 1,204. Elapsed: 0:10:51.\n",
" Batch 520 of 1,204. Elapsed: 0:11:46.\n",
" Batch 560 of 1,204. Elapsed: 0:12:40.\n",
" Batch 600 of 1,204. Elapsed: 0:13:34.\n",
" Batch 640 of 1,204. Elapsed: 0:14:28.\n",
" Batch 680 of 1,204. Elapsed: 0:15:23.\n",
" Batch 720 of 1,204. Elapsed: 0:16:17.\n",
" Batch 760 of 1,204. Elapsed: 0:17:11.\n",
" Batch 800 of 1,204. Elapsed: 0:18:06.\n",
" Batch 840 of 1,204. Elapsed: 0:19:00.\n",
" Batch 880 of 1,204. Elapsed: 0:19:54.\n",
" Batch 920 of 1,204. Elapsed: 0:20:49.\n",
" Batch 960 of 1,204. Elapsed: 0:21:43.\n",
" Batch 1,000 of 1,204. Elapsed: 0:22:37.\n",
" Batch 1,040 of 1,204. Elapsed: 0:23:32.\n",
" Batch 1,080 of 1,204. Elapsed: 0:24:26.\n",
" Batch 1,120 of 1,204. Elapsed: 0:25:20.\n",
" Batch 1,160 of 1,204. Elapsed: 0:26:15.\n",
" Batch 1,200 of 1,204. Elapsed: 0:27:09.\n",
"\n",
" Average training loss: 0.20\n",
" Training epcoh took: 0:27:14\n",
"\n",
"Running Validation...\n",
" Accuracy: 0.93\n",
" Validation Loss: 0.15\n",
" Validation took: 0:00:12\n",
"\n",
"======== Epoch 2 / 4 ========\n",
"Training...\n",
" Batch 40 of 1,204. Elapsed: 0:00:54.\n",
" Batch 80 of 1,204. Elapsed: 0:01:49.\n",
" Batch 120 of 1,204. Elapsed: 0:02:43.\n",
" Batch 160 of 1,204. Elapsed: 0:03:37.\n",
" Batch 200 of 1,204. Elapsed: 0:04:32.\n",
" Batch 240 of 1,204. Elapsed: 0:05:26.\n",
" Batch 280 of 1,204. Elapsed: 0:06:20.\n",
" Batch 320 of 1,204. Elapsed: 0:07:15.\n",
" Batch 360 of 1,204. Elapsed: 0:08:09.\n",
" Batch 400 of 1,204. Elapsed: 0:09:03.\n",
" Batch 440 of 1,204. Elapsed: 0:09:58.\n",
" Batch 480 of 1,204. Elapsed: 0:10:52.\n",
" Batch 520 of 1,204. Elapsed: 0:11:46.\n",
" Batch 560 of 1,204. Elapsed: 0:12:41.\n",
" Batch 600 of 1,204. Elapsed: 0:13:35.\n",
" Batch 640 of 1,204. Elapsed: 0:14:29.\n",
" Batch 680 of 1,204. Elapsed: 0:15:24.\n",
" Batch 720 of 1,204. Elapsed: 0:16:18.\n",
" Batch 760 of 1,204. Elapsed: 0:17:12.\n",
" Batch 800 of 1,204. Elapsed: 0:18:07.\n",
" Batch 840 of 1,204. Elapsed: 0:19:01.\n",
" Batch 880 of 1,204. Elapsed: 0:19:55.\n",
" Batch 920 of 1,204. Elapsed: 0:20:50.\n",
" Batch 960 of 1,204. Elapsed: 0:21:44.\n",
" Batch 1,000 of 1,204. Elapsed: 0:22:38.\n",
" Batch 1,040 of 1,204. Elapsed: 0:23:33.\n",
" Batch 1,080 of 1,204. Elapsed: 0:24:27.\n",
" Batch 1,120 of 1,204. Elapsed: 0:25:21.\n",
" Batch 1,160 of 1,204. Elapsed: 0:26:15.\n",
" Batch 1,200 of 1,204. Elapsed: 0:27:10.\n",
"\n",
" Average training loss: 0.12\n",
" Training epcoh took: 0:27:15\n",
"\n",
"Running Validation...\n",
" Accuracy: 0.95\n",
" Validation Loss: 0.17\n",
" Validation took: 0:00:12\n",
"\n",
"======== Epoch 3 / 4 ========\n",
"Training...\n",
" Batch 40 of 1,204. Elapsed: 0:00:54.\n",
" Batch 80 of 1,204. Elapsed: 0:01:49.\n",
" Batch 120 of 1,204. Elapsed: 0:02:43.\n",
" Batch 160 of 1,204. Elapsed: 0:03:37.\n",
" Batch 200 of 1,204. Elapsed: 0:04:32.\n",
" Batch 240 of 1,204. Elapsed: 0:05:26.\n",
" Batch 280 of 1,204. Elapsed: 0:06:20.\n",
" Batch 320 of 1,204. Elapsed: 0:07:15.\n",
" Batch 360 of 1,204. Elapsed: 0:08:09.\n",
" Batch 400 of 1,204. Elapsed: 0:09:03.\n",
" Batch 440 of 1,204. Elapsed: 0:09:57.\n",
" Batch 480 of 1,204. Elapsed: 0:10:52.\n",
" Batch 520 of 1,204. Elapsed: 0:11:46.\n",
" Batch 560 of 1,204. Elapsed: 0:12:40.\n",
" Batch 600 of 1,204. Elapsed: 0:13:35.\n",
" Batch 640 of 1,204. Elapsed: 0:14:29.\n",
" Batch 680 of 1,204. Elapsed: 0:15:23.\n",
" Batch 720 of 1,204. Elapsed: 0:16:18.\n",
" Batch 760 of 1,204. Elapsed: 0:17:12.\n",
" Batch 800 of 1,204. Elapsed: 0:18:06.\n",
" Batch 840 of 1,204. Elapsed: 0:19:01.\n",
" Batch 880 of 1,204. Elapsed: 0:19:55.\n",
" Batch 920 of 1,204. Elapsed: 0:20:49.\n",
" Batch 960 of 1,204. Elapsed: 0:21:44.\n",
" Batch 1,000 of 1,204. Elapsed: 0:22:38.\n",
" Batch 1,040 of 1,204. Elapsed: 0:23:32.\n",
" Batch 1,080 of 1,204. Elapsed: 0:24:27.\n",
" Batch 1,120 of 1,204. Elapsed: 0:25:21.\n",
" Batch 1,160 of 1,204. Elapsed: 0:26:15.\n",
" Batch 1,200 of 1,204. Elapsed: 0:27:10.\n",
"\n",
" Average training loss: 0.06\n",
" Training epcoh took: 0:27:14\n",
"\n",
"Running Validation...\n",
" Accuracy: 0.96\n",
" Validation Loss: 0.16\n",
" Validation took: 0:00:12\n",
"\n",
"======== Epoch 4 / 4 ========\n",
"Training...\n",
" Batch 40 of 1,204. Elapsed: 0:00:54.\n",
" Batch 80 of 1,204. Elapsed: 0:01:49.\n",
" Batch 120 of 1,204. Elapsed: 0:02:43.\n",
" Batch 160 of 1,204. Elapsed: 0:03:37.\n",
" Batch 200 of 1,204. Elapsed: 0:04:32.\n",
" Batch 240 of 1,204. Elapsed: 0:05:26.\n",
" Batch 280 of 1,204. Elapsed: 0:06:20.\n",
" Batch 320 of 1,204. Elapsed: 0:07:15.\n",
" Batch 360 of 1,204. Elapsed: 0:08:09.\n",
" Batch 400 of 1,204. Elapsed: 0:09:03.\n",
" Batch 440 of 1,204. Elapsed: 0:09:58.\n",
" Batch 480 of 1,204. Elapsed: 0:10:52.\n",
" Batch 520 of 1,204. Elapsed: 0:11:46.\n",
" Batch 560 of 1,204. Elapsed: 0:12:40.\n",
" Batch 600 of 1,204. Elapsed: 0:13:35.\n",
" Batch 640 of 1,204. Elapsed: 0:14:29.\n",
" Batch 680 of 1,204. Elapsed: 0:15:23.\n",
" Batch 720 of 1,204. Elapsed: 0:16:18.\n",
" Batch 760 of 1,204. Elapsed: 0:17:12.\n",
" Batch 800 of 1,204. Elapsed: 0:18:06.\n",
" Batch 840 of 1,204. Elapsed: 0:19:01.\n",
" Batch 880 of 1,204. Elapsed: 0:19:55.\n",
" Batch 920 of 1,204. Elapsed: 0:20:49.\n",
" Batch 960 of 1,204. Elapsed: 0:21:44.\n",
" Batch 1,000 of 1,204. Elapsed: 0:22:38.\n",
" Batch 1,040 of 1,204. Elapsed: 0:23:32.\n",
" Batch 1,080 of 1,204. Elapsed: 0:24:27.\n",
" Batch 1,120 of 1,204. Elapsed: 0:25:21.\n",
" Batch 1,160 of 1,204. Elapsed: 0:26:15.\n",
" Batch 1,200 of 1,204. Elapsed: 0:27:09.\n",
"\n",
" Average training loss: 0.03\n",
" Training epcoh took: 0:27:14\n",
"\n",
"Running Validation...\n",
" Accuracy: 0.96\n",
" Validation Loss: 0.19\n",
" Validation took: 0:00:12\n",
"\n",
"Training complete!\n",
"Total training took 1:49:46 (h:mm:ss)\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Cap the number of rows pandas prints for a displayed frame.\n",
"pd.options.display.max_rows = 5\n",
"print(training_stats)\n",
"\n",
"# Build the per-epoch statistics table, using 'epoch' as the row index.\n",
"df_stats = pd.DataFrame(data=training_stats).set_index('epoch')\n",
"\n",
"# Display the table.\n",
"df_stats"
],
"metadata": {
"id": "Hyz17eQ0Leez",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 244
},
"outputId": "5c7b9dce-7c27-4d61-c2de-57dd15092410"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"[{'epoch': 1, 'Training Loss': 0.20411855078583113, 'Valid. Loss': 0.14585627406835555, 'Valid. Accur.': 0.932, 'Training Time': '0:27:14', 'Validation Time': '0:00:12'}, {'epoch': 2, 'Training Loss': 0.1152435676600344, 'Valid. Loss': 0.1749358731787652, 'Valid. Accur.': 0.9484999999999999, 'Training Time': '0:27:15', 'Validation Time': '0:00:12'}, {'epoch': 3, 'Training Loss': 0.06414671170705484, 'Valid. Loss': 0.15540768602164462, 'Valid. Accur.': 0.9570000000000001, 'Training Time': '0:27:14', 'Validation Time': '0:00:12'}, {'epoch': 4, 'Training Loss': 0.031750682121859414, 'Valid. Loss': 0.1890620315662818, 'Valid. Accur.': 0.9620000000000001, 'Training Time': '0:27:14', 'Validation Time': '0:00:12'}]\n"
]
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
" Training Loss Valid. Loss Valid. Accur. Training Time Validation Time\n",
"epoch \n",
"1 0.204119 0.145856 0.9320 0:27:14 0:00:12\n",
"2 0.115244 0.174936 0.9485 0:27:15 0:00:12\n",
"3 0.064147 0.155408 0.9570 0:27:14 0:00:12\n",
"4 0.031751 0.189062 0.9620 0:27:14 0:00:12"
],
"text/html": [
"\n",
"\n",
" <div id=\"df-3eeb6233-d59e-42b6-a258-cca41fb9dfcb\">\n",
" <div class=\"colab-df-container\">\n",
" <div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Training Loss</th>\n",
" <th>Valid. Loss</th>\n",
" <th>Valid. Accur.</th>\n",
" <th>Training Time</th>\n",
" <th>Validation Time</th>\n",
" </tr>\n",
" <tr>\n",
" <th>epoch</th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" <th></th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>0.204119</td>\n",
" <td>0.145856</td>\n",
" <td>0.9320</td>\n",
" <td>0:27:14</td>\n",
" <td>0:00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>0.115244</td>\n",
" <td>0.174936</td>\n",
" <td>0.9485</td>\n",
" <td>0:27:15</td>\n",
" <td>0:00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>0.064147</td>\n",
" <td>0.155408</td>\n",
" <td>0.9570</td>\n",
" <td>0:27:14</td>\n",
" <td>0:00:12</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>0.031751</td>\n",
" <td>0.189062</td>\n",
" <td>0.9620</td>\n",
" <td>0:27:14</td>\n",
" <td>0:00:12</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>\n",
" <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-3eeb6233-d59e-42b6-a258-cca41fb9dfcb')\"\n",
" title=\"Convert this dataframe to an interactive table.\"\n",
" style=\"display:none;\">\n",
"\n",
" <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
" <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
" </svg>\n",
" </button>\n",
"\n",
"\n",
"\n",
" <div id=\"df-0e0addba-6b7a-492f-bc65-42df9cc9f206\">\n",
" <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-0e0addba-6b7a-492f-bc65-42df9cc9f206')\"\n",
" title=\"Suggest charts.\"\n",
" style=\"display:none;\">\n",
"\n",
"<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
" width=\"24px\">\n",
" <g>\n",
" <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
" </g>\n",
"</svg>\n",
" </button>\n",
" </div>\n",
"\n",
"<style>\n",
" .colab-df-quickchart {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-quickchart:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-quickchart:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
"</style>\n",
"\n",
" <script>\n",
" async function quickchart(key) {\n",
" const containerElement = document.querySelector('#' + key);\n",
" const charts = await google.colab.kernel.invokeFunction(\n",
" 'suggestCharts', [key], {});\n",
" }\n",
" </script>\n",
"\n",
" <script>\n",
"\n",
"function displayQuickchartButton(domScope) {\n",
" let quickchartButtonEl =\n",
" domScope.querySelector('#df-0e0addba-6b7a-492f-bc65-42df9cc9f206 button.colab-df-quickchart');\n",
" quickchartButtonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"}\n",
"\n",
" displayQuickchartButton(document);\n",
" </script>\n",
" <style>\n",
" .colab-df-container {\n",
" display:flex;\n",
" flex-wrap:wrap;\n",
" gap: 12px;\n",
" }\n",
"\n",
" .colab-df-convert {\n",
" background-color: #E8F0FE;\n",
" border: none;\n",
" border-radius: 50%;\n",
" cursor: pointer;\n",
" display: none;\n",
" fill: #1967D2;\n",
" height: 32px;\n",
" padding: 0 0 0 0;\n",
" width: 32px;\n",
" }\n",
"\n",
" .colab-df-convert:hover {\n",
" background-color: #E2EBFA;\n",
" box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
" fill: #174EA6;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert {\n",
" background-color: #3B4455;\n",
" fill: #D2E3FC;\n",
" }\n",
"\n",
" [theme=dark] .colab-df-convert:hover {\n",
" background-color: #434B5C;\n",
" box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
" filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
" fill: #FFFFFF;\n",
" }\n",
" </style>\n",
"\n",
" <script>\n",
" const buttonEl =\n",
" document.querySelector('#df-3eeb6233-d59e-42b6-a258-cca41fb9dfcb button.colab-df-convert');\n",
" buttonEl.style.display =\n",
" google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
"\n",
" async function convertToInteractive(key) {\n",
" const element = document.querySelector('#df-3eeb6233-d59e-42b6-a258-cca41fb9dfcb');\n",
" const dataTable =\n",
" await google.colab.kernel.invokeFunction('convertToInteractive',\n",
" [key], {});\n",
" if (!dataTable) return;\n",
"\n",
" const docLinkHtml = 'Like what you see? Visit the ' +\n",
" '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
" + ' to learn more about interactive tables.';\n",
" element.innerHTML = '';\n",
" dataTable['output_type'] = 'display_data';\n",
" await google.colab.output.renderOutput(dataTable, element);\n",
" const docLink = document.createElement('div');\n",
" docLink.innerHTML = docLinkHtml;\n",
" element.appendChild(docLink);\n",
" }\n",
" </script>\n",
" </div>\n",
" </div>\n"
]
},
"metadata": {},
"execution_count": 62
}
]
},
{
"cell_type": "code",
"source": [
"%matplotlib inline"
],
"metadata": {
"id": "QKrfLhYvLh5Q"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import seaborn as sns\n",
"\n",
"# Use plot styling from seaborn.\n",
"sns.set(style='darkgrid')\n",
"\n",
"# Increase the plot size and font size.\n",
"sns.set(font_scale=1.5)\n",
"plt.rcParams[\"figure.figsize\"] = (12,6)\n",
"\n",
"# Plot the learning curve (x values come from the 'epoch' index of df_stats).\n",
"plt.plot(df_stats['Training Loss'], 'b-o', label=\"Training\")\n",
"plt.plot(df_stats['Valid. Loss'], 'g-o', label=\"Validation\")\n",
"\n",
"# Label the plot.\n",
"plt.title(\"Training & Validation Loss\")\n",
"plt.xlabel(\"Epoch\")\n",
"plt.ylabel(\"Loss\")\n",
"plt.legend()\n",
"\n",
"# One tick per recorded epoch. `range(1, epochs)` stops at epochs - 1 and so\n",
"# dropped the tick for the final epoch; the table's index holds every epoch.\n",
"plt.xticks(list(df_stats.index))\n",
"\n",
"plt.show()"
],
"metadata": {
"id": "V1rjGrcQLm7U",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 592
},
"outputId": "6d0e2ed3-ab3f-4b5f-adf7-94b92da7457e"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"Saving the model"
],
"metadata": {
"id": "rJ6ylBCnMGBN"
}
},
{
"cell_type": "code",
"source": [
"import os\n",
"\n",
"# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()\n",
"\n",
"output_dir = './cross_entropy_model/'\n",
"\n",
"# Create the output directory if needed; `exist_ok=True` makes this a no-op\n",
"# when it is already there, so the cell can be re-run safely.\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"print(\"Saving model to %s\" % output_dir)\n",
"\n",
"# Save a trained model, configuration and tokenizer using `save_pretrained()`.\n",
"# They can then be reloaded using `from_pretrained()`\n",
"model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training\n",
"model_to_save.save_pretrained(output_dir)\n",
"tokenizer.save_pretrained(output_dir)\n",
"\n",
"# Good practice: save your training arguments together with the trained model\n",
"# torch.save(args, os.path.join(output_dir, 'training_args.bin'))\n",
"\n",
"\n",
"# Copy the model files to a directory in your Google Drive. `{output_dir}` is\n",
"# interpolated by IPython, so the path is only spelled out once in this cell.\n",
"!cp -r {output_dir} ./drive/MyDrive"
],
"metadata": {
"id": "OdCk0ZDvMIUL",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "46abf954-6e03-4677-879f-420b74c9e4d1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Saving model to ./cross_entropy_model/\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"#### Testing After Cross-entropy training"
],
"metadata": {
"id": "so-0GYfcLuaf"
}
},
{
"cell_type": "code",
"source": [
"import pandas as pd\n",
"\n",
"# Load the test dataset into a pandas dataframe.\n",
"# df = pd.read_csv(\"./cola_public/raw/out_of_domain_dev.tsv\", delimiter='\\t', header=None, names=['sentence_source', 'label', 'label_notes', 'sentence'])\n",
"\n",
"# Report the number of sentences.\n",
"print('Number of test sentences: {:,}\\n'.format(test_df.shape[0]))\n",
"\n",
"# Create sentence and label lists\n",
"sentences = test_df.sentence.values\n",
"labels = test_df.label.values\n",
"# sentences=[\"go fuck yourself\",\n",
"# \"this is crap\",\n",
"# \"thank you for the information\",\n",
"# \"yeah that sucked, fixed, Done.\",\n",
"# \"Crap, this is an artifact of a previous revision. It's simply the last time a change was made to Tuskar's cloud.\",\n",
"# \"Ah damn I misread the bug -_-\",\n",
"# \"wtf...\",\n",
"# \"I appreciate your help.\",\n",
"# \"fuuuuck\",\n",
"# \"what the f*ck\",\n",
"# \"absolute shit\",\n",
"# \"Get the hell outta here\",\n",
"# \"shi*tty code\",\n",
"# \"you are an absolute b!tch\",\n",
"# \"Nothing particular to worry about\"]\n",
"\n",
"# labels = [1,1,0,1,1,1,0,0,1,1,1,1,1,1,0]\n",
"# Tokenize all of the sentences and map the tokens to their word IDs.\n",
"input_ids = []\n",
"attention_masks = []\n",
"\n",
"# For every sentence...\n",
"for sent in sentences:\n",
"    # `encode_plus` will:\n",
"    #   (1) Tokenize the sentence.\n",
"    #   (2) Prepend the `[CLS]` token to the start.\n",
"    #   (3) Append the `[SEP]` token to the end.\n",
"    #   (4) Map tokens to their IDs.\n",
"    #   (5) Pad or truncate the sentence to `max_length`\n",
"    #   (6) Create attention masks for [PAD] tokens.\n",
"    #\n",
"    # `padding='max_length'` replaces the deprecated `pad_to_max_length=True`,\n",
"    # and `truncation=True` makes the truncation to `max_length` explicit\n",
"    # instead of relying on the tokenizer's fallback (which emits a warning).\n",
"    encoded_dict = tokenizer.encode_plus(\n",
"        str(sent),                     # Sentence to encode.\n",
"        add_special_tokens = True,     # Add '[CLS]' and '[SEP]'\n",
"        max_length = 512,              # Pad & truncate all sentences.\n",
"        padding = 'max_length',\n",
"        truncation = True,\n",
"        return_attention_mask = True,  # Construct attn. masks.\n",
"        return_tensors = 'pt',         # Return pytorch tensors.\n",
"    )\n",
"\n",
"    # Add the encoded sentence to the list.\n",
"    input_ids.append(encoded_dict['input_ids'])\n",
"\n",
"    # And its attention mask (simply differentiates padding from non-padding).\n",
"    attention_masks.append(encoded_dict['attention_mask'])\n",
"\n",
"# Convert the lists into tensors.\n",
"input_ids = torch.cat(input_ids, dim=0)\n",
"attention_masks = torch.cat(attention_masks, dim=0)\n",
"labels = torch.tensor(labels)\n",
"\n",
"# Set the batch size.\n",
"batch_size = 16\n",
"\n",
"# Create the DataLoader.\n",
"prediction_data = TensorDataset(input_ids, attention_masks, labels)\n",
"prediction_sampler = SequentialSampler(prediction_data)\n",
"prediction_dataloader = DataLoader(prediction_data, sampler=prediction_sampler, batch_size=batch_size)\n",
"\n",
"print('Predicting labels for {:,} test sentences...'.format(len(input_ids)))\n"
],
"metadata": {
"id": "Tx1rw9UuLyre",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "b504e8af-3403-496a-e9b9-dbfd9145923f"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Predicting labels for 15 test sentences...\n"
]
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/transformers/tokenization_utils_base.py:2393: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n",
" warnings.warn(\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Prediction on test set\n",
"print(\"Running Test set...\")\n",
"\n",
"t0 = time.time()\n",
"\n",
"# Put the model in evaluation mode--the dropout layers behave differently\n",
"# during evaluation.\n",
"model.eval()\n",
"\n",
"# Tracking variables\n",
"total_eval_accuracy = 0\n",
"total_eval_loss = 0\n",
"nb_eval_steps = 0\n",
"\n",
"\n",
"# Put model in evaluation mode\n",
"model.eval()\n",
"\n",
"# Tracking variables\n",
"predictions , true_labels = [], []\n",
"\n",
"# Predict\n",
"for batch in prediction_dataloader:\n",
" # Add batch to GPU\n",
" batch = tuple(t.to(device) for t in batch)\n",
"\n",
" # Unpack the inputs from our dataloader\n",
" b_input_ids, b_input_mask, b_labels = batch\n",
"\n",
" # Telling the model not to compute or store gradients, saving memory and\n",
" # speeding up prediction\n",
" with torch.no_grad():\n",
" # Forward pass, calculate logit predictions.\n",
" # result = model(b_input_ids,\n",
" # token_type_ids=None,\n",
" # attention_mask=b_input_mask,\n",
" # labels=b_labels,\n",
" # return_dict=True)\n",
"\n",
" result = model(b_input_ids,\n",
" token_type_ids=None,\n",
" attention_mask=b_input_mask,\n",
" return_dict=True)\n",
"\n",
" # Get the loss and \"logits\" output by the model. The \"logits\" are the\n",
" # output values prior to applying an activation function like the\n",
" # softmax.\n",
" # loss = result.loss\n",
" logits = result.logits\n",
"\n",
" # Accumulate the validation loss.\n",
" # total_eval_loss += loss.item()\n",
"\n",
" # Move logits and labels to CPU\n",
" logits = logits.detach().cpu().numpy()\n",
" label_ids = b_labels.to('cpu').numpy()\n",
"\n",
" # Store predictions and true labels\n",
" predictions.append(logits)\n",
" true_labels.append(label_ids)\n",
"\n",
" # Calculate the accuracy for this batch of test sentences, and\n",
" # accumulate it over all batches.\n",
" total_eval_accuracy += flat_accuracy(logits, label_ids)\n",
"\n",
"\n",
"print(' DONE.')\n",
"# Report the final accuracy for this run.\n",
"avg_val_accuracy = total_eval_accuracy / len(prediction_dataloader) * 100.0\n",
"print(\" Accuracy: {0:.2f}\".format(avg_val_accuracy))\n"
],
"metadata": {
"id": "4IanskVZL4Qn",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "e3707fd9-74f4-432b-ba42-fd15391f50db"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Running Test set...\n",
" DONE.\n",
" Accuracy: 86.67\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.metrics import matthews_corrcoef\n",
"\n",
"matthews_set = []\n",
"\n",
"# Evaluate each test batch using Matthew's correlation coefficient\n",
"print('Calculating Matthews Corr. Coef. for each batch...')\n",
"final_labels = []\n",
"# For each input batch...\n",
"for i in range(len(true_labels)):\n",
"\n",
" # The predictions for this batch are a 2-column ndarray (one column for \"0\"\n",
" # and one column for \"1\"). Pick the label with the highest value and turn this\n",
" # in to a list of 0s and 1s.\n",
" pred_labels_i = np.argmax(predictions[i], axis=1).flatten()\n",
" final_labels.append(pred_labels_i)\n",
" # Calculate and store the coef for this batch.\n",
" matthews = matthews_corrcoef(true_labels[i], pred_labels_i)\n",
" matthews_set.append(matthews)\n",
"\n",
"print(final_labels)"
],
"metadata": {
"id": "CHmFFznvL5L4",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "8dac4d8f-642c-40ac-c305-e419dc176596"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Calculating Matthews Corr. Coef. for each batch...\n",
"[array([1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0])]\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.metrics import f1_score\n",
"\n",
"count = 0\n",
"# Calculate accuracy\n",
"for i in range(len(true_labels[0])):\n",
" if true_labels[0][i] == final_labels[0][i]:\n",
" count = count + 1\n",
"\n",
"\n",
"print(\" Accuracy: {0:.2f}\".format(count/len(true_labels[0]) * 100.0))\n",
"\n",
"\n",
"def f1(actual, predicted, label):\n",
"\n",
" \"\"\" A helper function to calculate f1-score for the given `label` \"\"\"\n",
" # F1 = 2 * (precision * recall) / (precision + recall)\n",
" tp = np.sum((actual==label) & (predicted==label))\n",
" fp = np.sum((actual!=label) & (predicted==label))\n",
" fn = np.sum((predicted!=label) & (actual==label))\n",
"\n",
" precision = tp/(tp+fp)\n",
" recall = tp/(tp+fn)\n",
" f1 = 2 * (precision * recall) / (precision + recall)\n",
" return f1\n",
"\n",
"def f1_macro(actual, predicted):\n",
" # `macro` f1- unweighted mean of f1 per label\n",
" return np.mean([f1(actual, predicted, label)\n",
" for label in np.unique(actual)])\n",
"\n",
"\n",
"print(\" F1 Score: {0:.2f}\".format(f1_macro(true_labels[0], final_labels[0])))\n",
"\n",
"#define array of actual classes\n",
"# actual = np.repeat(true_labels, repeats=[160, 240])\n",
"\n",
"#define array of predicted classes\n",
"# pred = np.repeat(final_labels, repeats=[120, 40, 70, 170])\n",
"\n",
"#calculate F1 score\n",
"# f1_score(actual, pred)"
],
"metadata": {
"id": "6QmtX2DCL75J",
"colab": {
"base_uri": "https://localhost:8080/"
},
"outputId": "7ad9dd90-8ed8-4e73-d28a-d3519373cfec"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
" Accuracy: 86.67\n",
" F1 Score: 0.85\n"
]
}
]
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment