Skip to content

Instantly share code, notes, and snippets.

@swati210994
Last active January 14, 2024 21:24
Show Gist options
  • Save swati210994/963e084e8b76e8b5065a360d6d0741a0 to your computer and use it in GitHub Desktop.
Save swati210994/963e084e8b76e8b5065a360d6d0741a0 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Import all packages ###"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_hub as hub\n",
"import pandas as pd\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np\n",
"import re\n",
"import unicodedata\n",
"import nltk\n",
"from nltk.corpus import stopwords\n",
"from tensorflow import keras\n",
"from tensorflow.keras.layers import Dense,Dropout, Input\n",
"from tqdm import tqdm\n",
"import pickle\n",
"from sklearn.metrics import confusion_matrix,f1_score,classification_report\n",
"import matplotlib.pyplot as plt\n",
"import itertools\n",
"from sklearn.utils import shuffle\n",
"from tensorflow.keras import regularizers\n",
"from transformers import *\n",
"from transformers import BertTokenizer, TFBertModel, BertConfig,TFDistilBertModel,DistilBertTokenizer,DistilBertConfig"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preprocessing and cleaning functions ###"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"def unicode_to_ascii(s):\n",
"    \"\"\"Remove accents from s: NFD-decompose it and drop combining marks (category 'Mn').\n",
"\n",
"    Despite the name, other non-ASCII characters (e.g. '¿') are kept unchanged.\n",
"    \"\"\"\n",
"    decomposed = unicodedata.normalize('NFD', s)\n",
"    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']\n",
"    return ''.join(kept)\n",
"\n",
"_ENGLISH_STOPWORDS = None\n",
"\n",
"\n",
"def clean_stopwords_shortwords(w, stop_words=None):\n",
"    \"\"\"Drop stopwords and words of two characters or fewer from a whitespace-separated string.\n",
"\n",
"    Args:\n",
"        w: input text; it is split on whitespace.\n",
"        stop_words: optional collection of words to remove. Defaults to NLTK's English\n",
"            stopword list, loaded on first use and cached as a frozenset.\n",
"\n",
"    Returns:\n",
"        The remaining words joined by single spaces.\n",
"    \"\"\"\n",
"    global _ENGLISH_STOPWORDS\n",
"    if stop_words is None:\n",
"        # Calling stopwords.words() inside the function rebuilt the list once per message,\n",
"        # and list membership is a linear scan; build one set and reuse it.\n",
"        if _ENGLISH_STOPWORDS is None:\n",
"            _ENGLISH_STOPWORDS = frozenset(stopwords.words('english'))\n",
"        stop_words = _ENGLISH_STOPWORDS\n",
"    clean_words = [word for word in w.split() if word not in stop_words and len(word) > 2]\n",
"    return \" \".join(clean_words)\n",
"\n",
"def preprocess_sentence(w):\n",
"    \"\"\"Normalise one raw message into a cleaned, space-separated string of words.\n",
"\n",
"    Lowercases and strips the text, removes accents, removes @mentions, turns\n",
"    punctuation and every other non-letter character into spaces, and finally drops\n",
"    stopwords and words of two characters or fewer.\n",
"    \"\"\"\n",
"    w = unicode_to_ascii(w.lower().strip())\n",
"    # Remove @mentions before the non-letter filter below: that filter turns '@' into\n",
"    # a space, so applying this substitution last (as before) could never match.\n",
"    w = re.sub(r'@\\w+', ' ', w)\n",
"    w = re.sub(r\"([?.!,¿])\", r\" \", w)\n",
"    w = re.sub(r'[\" \"]+', \" \", w)\n",
"    w = re.sub(r\"[^a-zA-Z?.!,¿]+\", \" \", w)\n",
"    w = clean_stopwords_shortwords(w)\n",
"    return w"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Reading and Cleaning the Dataset ###"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>v1</th>\n",
" <th>v2</th>\n",
" <th>Unnamed: 2</th>\n",
" <th>Unnamed: 3</th>\n",
" <th>Unnamed: 4</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ham</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" v1 v2 Unnamed: 2 \\\n",
"0 ham Go until jurong point, crazy.. Available only ... NaN \n",
"1 ham Ok lar... Joking wif u oni... NaN \n",
"2 spam Free entry in 2 a wkly comp to win FA Cup fina... NaN \n",
"3 ham U dun say so early hor... U c already then say... NaN \n",
"4 ham Nah I don't think he goes to usf, he lives aro... NaN \n",
"\n",
" Unnamed: 3 Unnamed: 4 \n",
"0 NaN NaN \n",
"1 NaN NaN \n",
"2 NaN NaN \n",
"3 NaN NaN \n",
"4 NaN NaN "
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data_file = './data/spam.csv'\n",
"# Explicit single-byte encoding; presumably the CSV is not valid UTF-8 (TODO confirm).\n",
"data = pd.read_csv(filepath_or_buffer=data_file, encoding='ISO-8859-1')\n",
"\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Remove the unnamed columns, drop any rows containing NaN values, reset the index, and finally shuffle the dataset"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File has 5572 rows and 2 columns\n",
"File has 5572 rows and 2 columns\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>v1</th>\n",
" <th>v2</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2011</th>\n",
" <td>ham</td>\n",
" <td>Do whatever you want. You know what the rules ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2763</th>\n",
" <td>ham</td>\n",
" <td>Say this slowly.? GOD,I LOVE YOU &amp;amp; I NEED ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>388</th>\n",
" <td>spam</td>\n",
" <td>4mths half price Orange line rental &amp; latest c...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1938</th>\n",
" <td>ham</td>\n",
" <td>Excellent! Are you ready to moan and scream in...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1903</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a weekly comp for a chance to ...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" v1 v2\n",
"2011 ham Do whatever you want. You know what the rules ...\n",
"2763 ham Say this slowly.? GOD,I LOVE YOU &amp; I NEED ...\n",
"388 spam 4mths half price Orange line rental & latest c...\n",
"1938 ham Excellent! Are you ready to moan and scream in...\n",
"1903 spam Free entry in 2 a weekly comp for a chance to ..."
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Drop the 'Unnamed: N' columns (pandas' name for CSV columns without a header).\n",
"data = data.loc[:, ~data.columns.str.contains('^unnamed', case=False)]\n",
"print('File has {} rows and {} columns'.format(data.shape[0],data.shape[1]))\n",
"data = data.dropna()\n",
"data = data.reset_index(drop=True)\n",
"print('File has {} rows and {} columns'.format(data.shape[0],data.shape[1]))\n",
"# Fixed seed so Restart & Run All reproduces the same row order.\n",
"data = shuffle(data, random_state=42)\n",
"\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Rename `v1` to `label` and `v2` to `text`, map 'ham' to 0 and 'spam' to 1 in a new `gt` (ground truth) column, and apply the preprocessing function to the text"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Available labels: ['ham' 'spam']\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>label</th>\n",
" <th>text</th>\n",
" <th>gt</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>2011</th>\n",
" <td>ham</td>\n",
" <td>whatever want know rules talk earlier week sta...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2763</th>\n",
" <td>ham</td>\n",
" <td>say slowly god love amp need clean heart blood...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>388</th>\n",
" <td>spam</td>\n",
" <td>mths half price orange line rental latest came...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1938</th>\n",
" <td>ham</td>\n",
" <td>excellent ready moan scream ecstasy</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1903</th>\n",
" <td>spam</td>\n",
" <td>free entry weekly comp chance win ipod txt pod...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" label text gt\n",
"2011 ham whatever want know rules talk earlier week sta... 0\n",
"2763 ham say slowly god love amp need clean heart blood... 0\n",
"388 spam mths half price orange line rental latest came... 1\n",
"1938 ham excellent ready moan scream ecstasy 0\n",
"1903 spam free entry weekly comp chance win ipod txt pod... 1"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"data = data.rename(columns={'v1': 'label', 'v2': 'text'})\n",
"\n",
"data['gt'] = data['label'].map({'ham':0,'spam':1})\n",
"# Any label other than 'ham'/'spam' maps to NaN; fail loudly instead of training on it.\n",
"assert data['gt'].notna().all(), 'Unexpected labels: {}'.format(data.loc[data['gt'].isna(), 'label'].unique())\n",
"\n",
"print('Available labels: ',data.label.unique())\n",
"data['text'] = data['text'].map(preprocess_sentence)\n",
"\n",
"num_classes = data['label'].nunique()\n",
"\n",
"data.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Loading DistilBERT Tokenizer and the DistilBERT model ###"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# WordPiece tokenizer for the 'distilbert-base-uncased' checkpoint (note the '##' sub-word pieces in the tokenize() output further down).\n",
"dbert_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Pretrained TensorFlow DistilBERT encoder; create_model() reuses it as a layer, and the model summary lists its weights as trainable.\n",
"dbert_model = TFDistilBertModel.from_pretrained('distilbert-base-uncased')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Preparing input for the model ###"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5572, 5572)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Every message is padded/truncated to this many tokens before being fed to DistilBERT.\n",
"max_len = 32\n",
"sentences, labels = data['text'], data['gt']\n",
"(len(sentences), len(labels))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Let's take one sentence from the dataset to understand DistilBERT's input and output ####"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Tokenized sentence "
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['ju',\n",
" '##rong',\n",
" 'point',\n",
" 'crazy',\n",
" 'available',\n",
" 'bug',\n",
" '##is',\n",
" 'great',\n",
" 'world',\n",
" 'buffet',\n",
" 'ci',\n",
" '##ne',\n",
" 'got',\n",
" 'amore',\n",
" 'wat']"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sentences[0] is a label-based lookup: the index was reset before shuffling, so label 0\n",
"# is the dataset's original first message, not the first row after the shuffle.\n",
"dbert_tokenizer.tokenize(sentences[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Input ids and the attention masks from the tokenizer "
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': [101,\n",
" 18414,\n",
" 17583,\n",
" 2391,\n",
" 4689,\n",
" 2800,\n",
" 11829,\n",
" 2483,\n",
" 2307,\n",
" 2088,\n",
" 28305,\n",
" 25022,\n",
" 2638,\n",
" 2288,\n",
" 26297,\n",
" 28194,\n",
" 102,\n",
" 0,\n",
" 0,\n",
" 0],\n",
" 'attention_mask': [1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 1,\n",
" 0,\n",
" 0,\n",
" 0]}"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Encode one example with a short max_length (20) so the [PAD] positions are easy to see.\n",
"# padding='max_length' replaces the deprecated pad_to_max_length=True (same result).\n",
"dbert_inp = dbert_tokenizer.encode_plus(sentences[0], add_special_tokens=True, max_length=20, padding='max_length', truncation=True)\n",
"dbert_inp"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[101,\n",
" 18414,\n",
" 17583,\n",
" 2391,\n",
" 4689,\n",
" 2800,\n",
" 11829,\n",
" 2483,\n",
" 2307,\n",
" 2088,\n",
" 28305,\n",
" 25022,\n",
" 2638,\n",
" 2288,\n",
" 26297,\n",
" 28194,\n",
" 102,\n",
" 0,\n",
" 0,\n",
" 0]"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Token ids only: 101/102 are the [CLS]/[SEP] ids and 0 is the [PAD] id, as the decode() output further down shows.\n",
"dbert_inp['input_ids']"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
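"> DistilBERT model output: pass in the input_ids and attention_mask obtained from the tokenizer. The output is a tuple whose first element is a tensor of shape (1, sequence_length, 768); here it is (1, 20, 768) because this example was encoded with max_length=20"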
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tuple,\n",
" (<tf.Tensor: shape=(1, 20, 768), dtype=float32, numpy=\n",
" array([[[ 0.603038 , -0.87843955, -0.27702317, ..., 0.34013888,\n",
" -0.31951576, -0.02768148],\n",
" [ 0.8059453 , -1.2426811 , -0.3692848 , ..., -0.00915013,\n",
" -0.2044661 , -0.12683335],\n",
" [ 0.7864537 , -0.9070081 , -0.44475678, ..., -0.00204397,\n",
" -0.31890398, -0.23745532],\n",
" ...,\n",
" [ 0.56349653, -1.0353185 , -0.26982975, ..., 0.37219822,\n",
" -0.30490598, -0.09034443],\n",
" [ 0.5596978 , -1.0259491 , -0.29068822, ..., 0.33829993,\n",
" -0.29505792, -0.1154284 ],\n",
" [ 0.70843756, -0.9754098 , -0.21341297, ..., 0.42030877,\n",
" -0.3028642 , -0.06653835]]], dtype=float32)>,))"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Add a leading batch axis: the model is called on (batch, seq_len) arrays, here (1, 20).\n",
"id_inp = np.asarray(dbert_inp['input_ids'])\n",
"mask_inp = np.asarray(dbert_inp['attention_mask'])\n",
"out = dbert_model([id_inp[np.newaxis, :], mask_inp[np.newaxis, :]])\n",
"(type(out), out)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Obtain the sentence embedding from the output: the 768-dimensional vector at the first ([CLS]) token position"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<tf.Tensor: shape=(1, 768), dtype=float32, numpy=\n",
"array([[ 6.03038013e-01, -8.78439546e-01, -2.77023166e-01,\n",
" -7.81932235e-01, -2.24032372e-01, 1.76004171e-01,\n",
" -4.21196282e-01, -8.67987514e-01, 3.26672882e-01,\n",
" -6.30099177e-01, -3.55296314e-01, 3.10974181e-01,\n",
" -2.00993031e-01, 4.11007613e-01, -5.88603094e-02,\n",
" -1.61547720e-01, 3.93366128e-01, -2.55164415e-01,\n",
" -3.01635325e-01, 7.90954411e-01, 1.28573322e+00,\n",
" -6.68632388e-01, 2.54754089e-02, -2.28700653e-01,\n",
" -3.08614343e-01, 6.92516863e-01, -6.13005906e-02,\n",
" -2.24031866e-01, 3.97517495e-02, -1.10083513e-01,\n",
" -1.43123820e-01, -1.92341004e-02, -3.88225764e-02,\n",
" -4.60599959e-01, 1.04272123e-02, -3.33542168e-01,\n",
" 3.51539880e-01, 6.58564091e-01, -3.37058961e-01,\n",
" 1.02101535e-01, 9.40290630e-01, -3.21830750e-01,\n",
" -7.61335313e-01, 8.10378850e-01, -3.19713771e-01,\n",
" 3.22573185e-01, -1.69227576e+00, -4.91463065e-01,\n",
" -5.22750378e-01, 1.04521942e+00, 8.33612621e-01,\n",
" -1.85048893e-01, 6.69279337e-01, -6.02432609e-01,\n",
" 1.02250767e+00, 3.21514271e-02, -3.77190381e-01,\n",
" -1.27004194e+00, -2.17145652e-01, 2.42176414e-01,\n",
" 4.87879634e-01, 6.25644565e-01, 3.63948584e-01,\n",
" 6.12887181e-02, 1.08995032e+00, 2.86390800e-02,\n",
" -8.73474926e-02, 6.00923300e-02, 4.68736678e-01,\n",
" -1.06846416e+00, -6.91753030e-01, -1.20683730e+00,\n",
" 4.08894792e-02, 3.62308502e-01, -6.27320826e-01,\n",
" -1.19204938e+00, -3.80429924e-01, -3.61814857e-01,\n",
" -8.59443903e-01, 1.23737764e+00, 7.58002162e-01,\n",
" 1.10277310e-01, 1.26502132e+00, 1.60136008e+00,\n",
" 9.01422918e-01, -1.74658656e-01, 4.49645668e-01,\n",
" 4.34981465e-01, -5.68648398e-01, 6.84456527e-03,\n",
" -2.72482574e-01, -4.04388070e-01, -5.32054305e-01,\n",
" -6.13708086e-02, -7.89064243e-02, -2.93384135e-01,\n",
" 6.16099477e-01, 6.37774616e-02, 9.25561249e-01,\n",
" 4.33326215e-02, -1.06491053e+00, -3.66283581e-02,\n",
" -7.88599432e-01, -1.87204826e+00, -1.30752528e+00,\n",
" -2.77024359e-01, -7.01585293e-01, 5.99193215e-01,\n",
" -3.22149634e-01, -1.21525347e-01, -5.80495358e-01,\n",
" 1.33760840e-01, 2.68160999e-01, -2.68748909e-01,\n",
" 6.23604298e-01, -6.32791817e-01, -7.28992522e-01,\n",
" 2.73577273e-02, 4.96709526e-01, 3.28407705e-01,\n",
" 8.07063654e-02, 2.27984250e-01, 8.31354499e-01,\n",
" 1.64115340e-01, 6.46362841e-01, -7.88584769e-01,\n",
" 3.70997488e-01, 6.79060400e-01, -2.46167243e-01,\n",
" -2.42816746e-01, -4.94008660e-01, -2.26722956e-01,\n",
" -5.70130825e-01, -1.65110573e-01, -5.75526536e-01,\n",
" 7.95722753e-02, -4.43579674e-01, -1.80360958e-01,\n",
" 1.86605558e-01, 4.50560421e-01, 3.12413186e-01,\n",
" 5.83738387e-02, 4.98624407e-02, 9.47658896e-01,\n",
" 3.11616212e-01, -1.13319004e+00, -2.36583278e-01,\n",
" -4.60475087e-02, -2.79296875e-01, 1.16155052e+00,\n",
" 1.13577819e+00, 2.47922927e-01, -5.37598372e-01,\n",
" -1.05224144e+00, -4.29004937e-01, 6.94883466e-01,\n",
" 1.06838191e+00, 8.06059062e-01, -1.70146620e+00,\n",
" -3.60849142e-01, 5.34435272e-01, -4.29137468e-01,\n",
" -2.14618176e-01, -2.00021639e-01, -2.47490853e-01,\n",
" 4.15575691e-03, 1.28034502e-01, 2.79917419e-01,\n",
" -1.76520526e-01, 5.16117752e-01, 2.68898427e-01,\n",
" -5.98503165e-02, 6.01581454e-01, -3.77955884e-01,\n",
" -8.37433875e-01, 4.37795997e-01, 1.17776908e-01,\n",
" -4.19142731e-02, -5.80386877e-01, -7.86747098e-01,\n",
" -3.46670240e-01, -7.29465723e-01, 1.39434546e-01,\n",
" 1.19477224e+00, 1.47893623e-01, -3.96308541e-01,\n",
" 4.87730384e-01, -6.10263884e-01, 8.71831775e-01,\n",
" -9.81617346e-03, 2.12026909e-01, -3.35303277e-01,\n",
" -8.13672662e-01, 7.47541428e-01, 8.79059434e-01,\n",
" -2.73934007e-01, -2.70005614e-01, -3.04660082e-01,\n",
" 3.18873078e-01, -1.49186540e+00, 1.03478861e+00,\n",
" -1.45045981e-01, -4.47920859e-01, -2.01746285e-01,\n",
" -1.47374547e+00, 1.21385062e+00, 7.25501895e-01,\n",
" 2.03315675e-01, 2.11808294e-01, 3.08827817e-01,\n",
" 1.34706154e-01, 1.19417533e-01, -7.83641115e-02,\n",
" 1.34895757e-01, -1.85459405e-01, 1.20549336e-01,\n",
" 5.11147022e-01, 1.37983710e-01, -6.73445314e-03,\n",
" 5.61571956e-01, -6.91728950e-01, -2.88661756e-02,\n",
" -1.12136340e+00, 1.59873009e-01, -3.88660133e-01,\n",
" -3.70979011e-01, -7.27081776e-01, -2.36260980e-01,\n",
" 9.80740562e-02, 6.62679374e-01, 1.05562341e+00,\n",
" -4.23619747e-01, 1.09879464e-01, -5.76181471e-01,\n",
" 7.92896569e-01, 2.62166023e-01, 3.22766393e-01,\n",
" -6.80434167e-01, -1.11840799e-01, 1.39813051e-01,\n",
" 1.53463781e-01, 1.82946876e-01, 1.98349118e-01,\n",
" -5.28053880e-01, -6.28764212e-01, 3.03985495e-02,\n",
" 1.16870558e+00, -5.44939816e-01, 8.94414067e-01,\n",
" -2.67008662e-01, -2.51927316e-01, 4.20000851e-01,\n",
" -1.69454291e-01, 3.91806483e-01, 3.12675476e-01,\n",
" -4.41523910e-01, -1.50542736e+00, 6.69851303e-01,\n",
" -3.85377198e-01, 1.77042544e-01, -6.95514083e-02,\n",
" 3.35636027e-02, -4.65059161e-01, 4.26401466e-01,\n",
" 1.28072667e+00, -1.17895949e+00, 2.75932938e-01,\n",
" -4.06650305e-01, 1.14857435e+00, 5.15444219e-01,\n",
" 8.97431135e-01, 3.90559286e-01, 1.76408380e-01,\n",
" -1.68391034e-01, -1.48721266e+00, -6.39496863e-01,\n",
" 7.78892159e-01, 1.76998663e+00, 4.14469659e-01,\n",
" -9.89907026e-01, -9.60321426e-02, 1.21508919e-01,\n",
" 7.23461628e-01, -4.72908318e-01, -5.74852675e-02,\n",
" 6.02047682e-01, -9.90738034e-01, -4.86421496e-01,\n",
" -7.09693193e-01, 1.44216895e-01, 1.11436568e-01,\n",
" -4.99891862e-03, -6.46592140e-01, 3.29288155e-01,\n",
" -8.12867165e-01, -9.12498236e-02, -4.20294642e-01,\n",
" 2.99679488e-01, -2.30724383e-02, -5.27692735e-01,\n",
" 7.12708592e-01, -6.28750801e-01, -2.01221645e-01,\n",
" -5.55100963e-02, 2.07774356e-01, -2.77632654e-01,\n",
" -3.42278659e-01, -5.55952311e-01, -4.04374212e-01,\n",
" -3.20272803e-01, 2.77675450e-01, 4.52805966e-01,\n",
" 9.71656621e-01, 1.38616636e-02, 5.07998049e-01,\n",
" 3.83135080e-02, 4.17569935e-01, 5.26914179e-01,\n",
" 5.75926542e-01, 5.45349240e-01, 5.04254937e-01,\n",
" -3.00031126e-01, 1.08488023e+00, -3.27659726e-01,\n",
" 2.18448006e-02, -2.59728819e-01, 5.14609218e-01,\n",
" -4.06661555e-02, 4.28912103e-01, -4.16806132e-01,\n",
" -2.68253148e-01, -6.35986328e-02, 1.14468940e-01,\n",
" 6.01160347e-01, -8.13216627e-01, -1.65143609e-01,\n",
" -5.09384930e-01, -9.67750624e-02, -7.76142180e-01,\n",
" -5.34972608e-01, 9.33849096e-01, 1.36672229e-01,\n",
" -1.03093612e+00, -2.15662435e-01, -9.19019938e-01,\n",
" -9.32630450e-02, -3.22765499e-01, -1.20489919e+00,\n",
" 1.18114913e+00, -1.28427184e+00, 5.03147900e-01,\n",
" -2.58777589e-01, -4.45724726e-01, -9.47305143e-01,\n",
" -7.83599555e-01, -7.74042010e-01, 5.60701251e-01,\n",
" 3.75353009e-01, -4.40059602e-01, 6.83822513e-01,\n",
" -1.99001461e-01, -4.47226465e-01, -7.66293466e-01,\n",
" -1.76975593e-01, 4.13692445e-01, 3.47849429e-01,\n",
" 1.44688523e+00, -2.16559142e-01, 6.83537424e-01,\n",
" -8.69374573e-01, -6.56416774e-01, -1.01821315e+00,\n",
" 1.09332450e-01, -4.03954327e-01, -5.10778189e-01,\n",
" -5.83017170e-01, 7.74493515e-01, 6.33757770e-01,\n",
" -8.02680664e-03, -8.48268509e-01, 4.08677161e-01,\n",
" -3.87468368e-01, 7.57212162e-01, 2.29065239e-01,\n",
" -1.73584968e-01, -1.52334738e+00, 5.79953492e-01,\n",
" -9.02318180e-01, -3.37452501e-01, 3.04758400e-01,\n",
" 5.04641056e-01, -6.36874914e-01, -1.76794641e-02,\n",
" 1.18659437e+00, -6.05956316e-02, -3.02021772e-01,\n",
" -1.31736302e+00, 1.21303213e+00, 2.89690882e-01,\n",
" 7.71721244e-01, 6.06914997e-01, 2.22801834e-01,\n",
" -3.41530979e-01, -4.03286338e-01, 7.29780495e-01,\n",
" -1.21780179e-01, 1.12891698e+00, -8.17912996e-01,\n",
" -3.65746349e-01, 2.98856735e-01, -3.72117132e-01,\n",
" 6.81966901e-01, -5.58889031e-01, 3.17049235e-01,\n",
" 1.60192147e-01, -1.00157008e-01, 3.67980212e-01,\n",
" 3.83160487e-02, 3.78825784e-01, -9.74659324e-01,\n",
" -6.06691658e-01, 7.06549406e-01, -6.96622074e-01,\n",
" -2.89220780e-01, -7.81294048e-01, -2.24434718e-01,\n",
" -8.35678577e-02, -8.84586811e-01, 2.36391276e-01,\n",
" 5.21456659e-01, -5.31575918e-01, -8.08691382e-01,\n",
" 4.28600848e-01, 1.40789270e-01, 3.90130728e-02,\n",
" 7.63179004e-01, -1.53729856e+00, 6.73635483e-01,\n",
" -5.76094151e-01, -3.34772259e-01, -7.89971530e-01,\n",
" -4.99642402e-01, -2.58468449e-01, -6.15976810e-01,\n",
" -5.56715906e-01, 7.36451983e-01, -1.22028661e+00,\n",
" 2.75421858e-01, -4.09457117e-01, -7.62330413e-01,\n",
" 6.80023193e-01, 2.03611776e-02, 5.07364869e-01,\n",
" 1.27056861e+00, -3.35276932e-01, -5.38022101e-01,\n",
" 2.00237125e-01, -7.36508250e-01, 8.35546970e-01,\n",
" -5.83897352e-01, 2.31009647e-02, 4.94249910e-02,\n",
" -9.12698090e-01, -2.70418227e-01, -3.13920200e-01,\n",
" -6.72133386e-01, -1.10715842e+00, -1.43072093e+00,\n",
" 3.41811419e-01, 4.13841307e-02, -9.00741518e-01,\n",
" 2.59927660e-01, 1.76744908e-01, 4.92825247e-02,\n",
" -7.09833622e-01, -4.62444961e-01, -2.29689404e-01,\n",
" -4.23059464e-01, 4.64624539e-02, 9.19562221e-01,\n",
" 6.43446088e-01, 1.50866032e-01, -1.66563189e+00,\n",
" 6.91318095e-01, -2.54390210e-01, 6.44376516e-01,\n",
" 7.31475532e-01, 3.74760360e-01, 3.43573153e-01,\n",
" 4.48913842e-01, -1.03410316e+00, 1.37455910e-01,\n",
" 1.48356566e-03, -4.56609726e-01, -1.90604627e-01,\n",
" 2.77024180e-01, -9.05576706e-01, 2.20442131e-01,\n",
" -2.05784701e-02, 2.81469673e-01, -3.19618732e-01,\n",
" 3.97094548e-01, -1.89189211e-01, 5.80150306e-01,\n",
" -4.99626845e-01, -1.28104687e-01, 2.23559529e-01,\n",
" 2.96906561e-01, -4.35585171e-01, 8.14894557e-01,\n",
" -3.42589885e-01, -9.68980342e-02, 6.45970702e-01,\n",
" 4.59946066e-01, -9.86011326e-03, -6.75605983e-03,\n",
" -1.09264135e-01, -1.32763779e+00, 9.20030698e-02,\n",
" -1.35417238e-01, -1.02851018e-02, 3.71368259e-01,\n",
" -3.32512289e-01, -9.43855047e-01, 1.03147733e+00,\n",
" 3.85080457e-01, 3.07875872e-01, -3.63639414e-01,\n",
" -4.54725444e-01, 3.66942808e-02, -1.10807967e+00,\n",
" 5.16120484e-03, -6.40997112e-01, -3.77940565e-01,\n",
" -9.61077511e-01, -1.02243640e-01, 1.07516930e-01,\n",
" 8.06317091e-01, -6.29772842e-01, 6.92058444e-01,\n",
" -6.74996018e-01, 9.98790860e-01, 4.13442671e-01,\n",
" -5.87501526e-01, 1.34830284e+00, -2.53247619e-01,\n",
" 6.57434613e-02, -3.44445407e-01, -2.74323523e-01,\n",
" 9.41940308e-01, -8.51385117e-01, -7.17238486e-01,\n",
" 2.23294750e-01, 3.42073947e-01, -7.00151563e-01,\n",
" -7.35350788e-01, 4.51216787e-01, 1.06338024e+00,\n",
" -4.94434118e-01, 4.74332094e-01, 5.41221023e-01,\n",
" -4.70786572e-01, -1.10761851e-01, 1.74335867e-01,\n",
" -7.69726336e-01, 1.06428973e-02, 6.61674142e-01,\n",
" 2.25208431e-01, -1.32597953e-01, 5.16738534e-01,\n",
" -2.85375774e-01, 4.86573756e-01, -5.08984864e-01,\n",
" 5.65178514e-01, 5.64042807e-01, -1.42574728e-01,\n",
" -1.37143552e-01, -8.83973062e-01, 1.59324214e-01,\n",
" -1.37306765e-01, -5.55238128e-01, 4.68822509e-01,\n",
" 1.06851363e+00, -1.13501048e+00, -2.46283293e-01,\n",
" -1.41044967e-02, -4.88164634e-01, -5.61163127e-01,\n",
" -1.27986848e-01, 6.22749329e-01, 9.46501136e-01,\n",
" -2.97750253e-02, 2.37387717e-01, 8.73841271e-02,\n",
" 8.56920302e-01, -1.22735344e-01, -6.47104919e-01,\n",
" 9.68198419e-01, 6.32964194e-01, 3.04035604e-01,\n",
" -1.83739662e-01, 9.01215851e-01, 2.74185807e-01,\n",
" 1.16063453e-01, 1.16409987e-01, 4.41886872e-01,\n",
" -5.09184599e-01, -4.34587821e-02, -1.04615545e+00,\n",
" 1.34349144e+00, 3.28240216e-01, 5.11831164e-01,\n",
" -4.97375309e-01, -4.35965002e-01, 9.09861699e-02,\n",
" 3.75274777e-01, -1.38322973e+00, 6.86389983e-01,\n",
" 8.08061182e-01, -2.16697212e-02, -2.59509385e-01,\n",
" 1.10440838e+00, -4.42643166e-01, 7.78268814e-01,\n",
" -4.27443907e-02, -8.53625000e-01, 5.72073869e-02,\n",
" -4.17862535e-01, -3.06892216e-01, -5.07756591e-01,\n",
" 7.20375776e-01, -2.29387477e-01, 3.25722188e-01,\n",
" 2.70283282e-01, -3.94067764e-01, 1.71590179e-01,\n",
" 3.10447693e-01, 8.45310748e-01, 1.20628941e+00,\n",
" 5.81624731e-02, -5.00897706e-01, 2.73192406e-01,\n",
" 4.12979841e-01, -4.80810016e-01, -7.45794326e-02,\n",
" 4.46741641e-01, 5.07134259e-01, -4.63269413e-01,\n",
" 3.69267672e-01, -3.71461540e-01, -5.99944293e-01,\n",
" 2.54450083e-01, 7.62756526e-01, 1.22809298e-02,\n",
" 3.41042191e-01, 1.26630276e-01, 1.07698381e+00,\n",
" -3.90054017e-01, 9.72529575e-02, 4.55065370e-01,\n",
" 5.93872488e-01, 3.04408252e-01, -3.05748954e-02,\n",
" -7.07344636e-02, 1.42908096e-01, 6.98127389e-01,\n",
" 4.15675342e-01, 2.50390947e-01, -8.83328468e-02,\n",
" -6.90103173e-01, -2.74849921e-01, -4.97512460e-01,\n",
" -7.92889237e-01, -1.72456503e-01, -5.17004192e-01,\n",
" 9.61471051e-02, 6.25220612e-02, -5.31899214e-01,\n",
" 1.19089156e-01, -2.03830123e-01, 2.17729837e-01,\n",
" 7.03398347e-01, 4.58590627e-01, 3.45895708e-01,\n",
" 4.89512563e-01, 2.47522652e-01, -3.37672770e-01,\n",
" -7.65977800e-01, 4.40406382e-01, -1.51206374e-01,\n",
" -4.22124416e-01, -7.61940777e-01, 5.62069476e-01,\n",
" 6.65147305e-01, -1.48630470e-01, 3.96794468e-01,\n",
" -1.01848531e+00, 5.50713062e-01, -1.37245059e+00,\n",
" -3.25246781e-01, -6.54472828e-01, -1.57692814e+00,\n",
" 6.92537487e-01, -4.11718190e-01, 1.88343495e-01,\n",
" 7.11880922e-01, 2.75095046e-01, 1.39857098e-01,\n",
" 2.43308246e-01, 1.04106474e+00, -6.22403979e-01,\n",
" 4.39274341e-01, -9.14811045e-02, 6.97045982e-01,\n",
" -1.26863778e+00, 2.13712752e-01, 5.92919469e-01,\n",
" 2.12821156e-01, 2.60004923e-02, 2.71382332e-01,\n",
" -6.72967792e-01, 5.48906744e-01, -6.01199716e-02,\n",
" -1.03580225e+00, -2.12678000e-01, 3.37710142e-01,\n",
" -4.56657797e-01, 1.05779862e+00, -5.14126778e-01,\n",
" 5.96502542e-01, -3.85384947e-01, 2.41810536e+00,\n",
" 1.47140503e-01, -4.70191300e-01, 2.98708260e-01,\n",
" 7.87923813e-01, -2.47554213e-01, 4.51669917e-02,\n",
" -9.01018620e-01, 9.96429980e-01, -7.16759324e-01,\n",
" -2.05091119e-01, -6.52647495e-01, 6.88338876e-02,\n",
" -1.41500607e-02, -9.52726841e-01, -5.07680476e-01,\n",
" 6.55689001e-01, 8.05295110e-01, 5.53219579e-05,\n",
" 4.13542986e-01, 6.86698437e-01, -1.05570567e+00,\n",
" -1.81498122e+00, 5.10226429e-01, 9.01936412e-01,\n",
" 1.02417707e+00, -3.27444732e-01, 5.75429916e-01,\n",
" -3.47295731e-01, -2.70201862e-01, 4.71759140e-01,\n",
" -1.80455774e-01, -3.56086224e-01, -1.00847639e-01,\n",
" 3.40138882e-01, -3.19515765e-01, -2.76814811e-02]], dtype=float32)>"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# out[0] has shape (1, 20, 768); position 0 along the sequence axis is the [CLS] token.\n",
"out[0][:, 0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Decode the input ids back into text with the tokenizer (note the special tokens it added)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'[CLS] jurong point crazy available bugis great world buffet cine got amore wat [SEP] [PAD] [PAD] [PAD]'"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Decoding keeps the special tokens the tokenizer added: [CLS], [SEP] and the [PAD] padding.\n",
"dbert_tokenizer.decode(dbert_inp['input_ids'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a basic NN model using DistilBERT embeddings to get the predictions ###"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"def create_model():\n",
"    \"\"\"Build a Keras classifier on top of the DistilBERT [CLS] embedding.\n",
"\n",
"    Reads the notebook globals max_len, dbert_model and num_classes. Both inputs\n",
"    (token ids and attention mask) have shape (batch, max_len); the output is a\n",
"    softmax over num_classes. The returned model is not compiled.\n",
"\n",
"    Returns:\n",
"        tf.keras.Model with inputs [input_ids, attention_mask].\n",
"    \"\"\"\n",
"    inps = Input(shape=(max_len,), dtype='int64')\n",
"    masks = Input(shape=(max_len,), dtype='int64')\n",
"    # [0] is the hidden-state tensor (batch, max_len, 768); [:, 0, :] keeps the first ([CLS]) position.\n",
"    dbert_layer = dbert_model(inps, attention_mask=masks)[0][:, 0, :]\n",
"    dense = Dense(512, activation='relu', kernel_regularizer=regularizers.l2(0.01))(dbert_layer)\n",
"    dropout = Dropout(0.5)(dense)\n",
"    pred = Dense(num_classes, activation='softmax', kernel_regularizer=regularizers.l2(0.01))(dropout)\n",
"    model = tf.keras.Model(inputs=[inps, masks], outputs=pred)\n",
"    # summary() prints the table itself and returns None, so print(model.summary())\n",
"    # only added a stray 'None' line to the output.\n",
"    model.summary()\n",
"    return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Feel free to add more Dense and Dropout layers, and to vary their units and regularizers"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:AutoGraph could not transform <bound method TFDistilBertModel.call of <transformers.modeling_tf_distilbert.TFDistilBertModel object at 0x7fdff6ca9490>> and will run it as-is.\n",
"Model: \"functional_1\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_1 (InputLayer) [(None, 32)] 0 \n",
"__________________________________________________________________________________________________\n",
"input_2 (InputLayer) [(None, 32)] 0 \n",
"__________________________________________________________________________________________________\n",
"tf_distil_bert_model_4 (TFDisti ((None, 32, 768),) 66362880 input_1[0][0] \n",
" input_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_strided_slice (Tens [(None, 768)] 0 tf_distil_bert_model_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense (Dense) (None, 512) 393728 tf_op_layer_strided_slice[0][0] \n",
"__________________________________________________________________________________________________\n",
"dropout_95 (Dropout) (None, 512) 0 dense[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_1 (Dense) (None, 2) 1026 dropout_95[0][0] \n",
"==================================================================================================\n",
"Total params: 66,757,634\n",
"Trainable params: 66,757,634\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"None\n"
]
}
],
"source": [
"# Build the classifier; create_model() also prints the layer summary.\n",
"model=create_model()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Prepare the model input "
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# Tokenize every sentence into fixed-length (max_len) DistilBERT token ids and attention masks.\n",
"input_ids=[]\n",
"attention_masks=[]\n",
"\n",
"for sent in sentences:\n",
"    # padding='max_length' replaces the deprecated pad_to_max_length=True\n",
"    dbert_inps=dbert_tokenizer.encode_plus(sent,add_special_tokens=True,max_length=max_len,padding='max_length',return_attention_mask=True,truncation=True)\n",
"    input_ids.append(dbert_inps['input_ids'])\n",
"    attention_masks.append(dbert_inps['attention_mask'])\n",
"\n",
"input_ids=np.array(input_ids)\n",
"attention_masks=np.array(attention_masks)\n",
"labels=np.array(labels)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(5572, 5572, 5572)"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Every example needs token ids, a mask and a label; fail loudly if they ever drift apart.\n",
"assert len(input_ids) == len(attention_masks) == len(labels), 'inputs, masks and labels must be aligned'\n",
"len(input_ids),len(attention_masks),len(labels)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Save the model inputs to pickle files so they can be reloaded later without repeating the steps above."
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Preparing the pickle file.....\n",
"Pickle files saved as ./data/dbert_inp.pkl ./data/dbert_mask.pkl ./data/dbert_label.pkl\n"
]
}
],
"source": [
"print('Preparing the pickle file.....')\n",
"\n",
"pickle_inp_path='./data/dbert_inp.pkl'\n",
"pickle_mask_path='./data/dbert_mask.pkl'\n",
"pickle_label_path='./data/dbert_label.pkl'\n",
"\n",
"# Context managers guarantee each file is flushed and closed (bare open() leaked the handles).\n",
"for obj, path in ((input_ids, pickle_inp_path), (attention_masks, pickle_mask_path), (labels, pickle_label_path)):\n",
"    with open(path, 'wb') as fh:\n",
"        pickle.dump(obj, fh)\n",
"\n",
"print('Pickle files saved as ',pickle_inp_path,pickle_mask_path,pickle_label_path)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading the saved pickle files..\n",
"Input shape (5572, 32) Attention mask shape (5572, 32) Input label shape (5572,)\n"
]
}
],
"source": [
"print('Loading the saved pickle files..')\n",
"\n",
"# These files are the ones written by the cell above; never unpickle files from untrusted sources\n",
"# (pickle.load can execute arbitrary code).\n",
"with open(pickle_inp_path, 'rb') as fh:\n",
"    input_ids=pickle.load(fh)\n",
"with open(pickle_mask_path, 'rb') as fh:\n",
"    attention_masks=pickle.load(fh)\n",
"with open(pickle_label_path, 'rb') as fh:\n",
"    labels=pickle.load(fh)\n",
"\n",
"print('Input shape {} Attention mask shape {} Input label shape {}'.format(input_ids.shape,attention_masks.shape,labels.shape))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"label_class_dict={0:'ham',1:'spam'}\n",
"# Class names as a list ordered by label id, the form sklearn's target_names expects\n",
"target_names=[label_class_dict[k] for k in sorted(label_class_dict)]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Split the data into training and validation sets, and set up the loss function, accuracy metric and optimizer for the model."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train inp shape (4457, 32) Val input shape (1115, 32)\n",
"Train label shape (4457,) Val label shape (1115,)\n",
"Train attention mask shape (4457, 32) Val attention mask shape (1115, 32)\n"
]
}
],
"source": [
"# Stratify so the imbalanced ham/spam ratio is preserved in both splits; fix the seed so the split is reproducible.\n",
"train_inp,val_inp,train_label,val_label,train_mask,val_mask=train_test_split(input_ids,labels,attention_masks,test_size=0.2,random_state=42,stratify=labels)\n",
"\n",
"print('Train inp shape {} Val input shape {}\\nTrain label shape {} Val label shape {}\\nTrain attention mask shape {} Val attention mask shape {}'.format(train_inp.shape,val_inp.shape,train_label.shape,val_label.shape,train_mask.shape,val_mask.shape))\n",
"\n",
"\n",
"log_dir='dbert_model'\n",
"model_save_path='./dbert_model.h5'\n",
"\n",
"# Keep only the weights of the epoch with the lowest validation loss, and log curves for TensorBoard.\n",
"callbacks = [tf.keras.callbacks.ModelCheckpoint(filepath=model_save_path,save_weights_only=True,monitor='val_loss',mode='min',save_best_only=True),keras.callbacks.TensorBoard(log_dir=log_dir)]\n",
"\n",
"# NOTE(review): from_logits=True assumes create_model()'s final Dense layer emits raw logits (no softmax);\n",
"# confirm against create_model, otherwise use from_logits=False.\n",
"loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
"metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')\n",
"optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)\n",
"\n",
"model.compile(loss=loss,optimizer=optimizer, metrics=[metric])"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"# The callbacks list and model.compile(...) are already set up in the previous cell;\n",
"# repeating them here had no effect, so this cell intentionally does nothing."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Training ###"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/5\n",
"1/279 [..............................] - ETA: 0s - loss: 6.7663 - accuracy: 0.8750\n",
"Instructions for updating:\n",
"use `tf.profiler.experimental.stop` instead.\n",
"279/279 [==============================] - ETA: 0s - loss: 5.6354 - accuracy: 0.9435\n",
"279/279 [==============================] - 251s 898ms/step - loss: 5.6354 - accuracy: 0.9435 - val_loss: 4.7500 - val_accuracy: 0.9749\n",
"Epoch 2/5\n",
"279/279 [==============================] - 244s 875ms/step - loss: 4.0501 - accuracy: 0.9749 - val_loss: 3.4032 - val_accuracy: 0.9821\n",
"Epoch 3/5\n",
"279/279 [==============================] - 247s 884ms/step - loss: 2.8908 - accuracy: 0.9796 - val_loss: 2.4331 - val_accuracy: 0.9704\n",
"Epoch 4/5\n",
"279/279 [==============================] - 249s 894ms/step - loss: 2.0531 - accuracy: 0.9818 - val_loss: 1.7228 - val_accuracy: 0.9794\n",
"Epoch 5/5\n",
"279/279 [==============================] - 244s 874ms/step - loss: 1.4562 - accuracy: 0.9877 - val_loss: 1.2264 - val_accuracy: 0.9857\n"
]
}
],
"source": [
"# Train on (token ids, attention mask) pairs; validation loss drives the checkpoint callback.\n",
"history = model.fit(\n",
"    [train_inp, train_mask],\n",
"    train_label,\n",
"    batch_size=16,\n",
"    epochs=5,\n",
"    validation_data=([val_inp, val_mask], val_label),\n",
"    callbacks=callbacks,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TensorBoard visualization (training and validation curves) ###"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"# Load the TensorBoard notebook extension, which provides the %tensorboard magic\n",
"%load_ext tensorboard"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <iframe id=\"tensorboard-frame-1d79222d114abbf0\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
" </iframe>\n",
" <script>\n",
" (function() {\n",
" const frame = document.getElementById(\"tensorboard-frame-1d79222d114abbf0\");\n",
" const url = new URL(\"/\", window.location);\n",
" const port = 6007;\n",
" if (port) {\n",
" url.port = port;\n",
" }\n",
" frame.src = url;\n",
" })();\n",
" </script>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Show TensorBoard inline for the logs written by the TensorBoard callback (log_dir)\n",
"%tensorboard --logdir {log_dir}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Validation loss decreased every epoch and was still falling at epoch 5, so increasing the number of epochs may reduce the loss further."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load the saved model to make predictions and compute evaluation metrics ###"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"functional_5\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"input_5 (InputLayer) [(None, 32)] 0 \n",
"__________________________________________________________________________________________________\n",
"input_6 (InputLayer) [(None, 32)] 0 \n",
"__________________________________________________________________________________________________\n",
"tf_distil_bert_model_4 (TFDisti ((None, 32, 768),) 66362880 input_5[0][0] \n",
" input_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"tf_op_layer_strided_slice_2 (Te [(None, 768)] 0 tf_distil_bert_model_4[2][0] \n",
"__________________________________________________________________________________________________\n",
"dense_4 (Dense) (None, 512) 393728 tf_op_layer_strided_slice_2[0][0]\n",
"__________________________________________________________________________________________________\n",
"dropout_97 (Dropout) (None, 512) 0 dense_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"dense_5 (Dense) (None, 2) 1026 dropout_97[0][0] \n",
"==================================================================================================\n",
"Total params: 66,757,634\n",
"Trainable params: 66,757,634\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"None\n"
]
}
],
"source": [
"# Rebuild the same architecture and load the best checkpoint saved by ModelCheckpoint during training.\n",
"# NOTE(review): the printed summary reuses the tf_distil_bert_model_4 layer seen in `model`,\n",
"# so the DistilBERT weights appear to be shared between both models.\n",
"trained_model = create_model()\n",
"trained_model.compile(optimizer=optimizer, loss=loss, metrics=[metric])\n",
"trained_model.load_weights(model_save_path)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"WARNING:tensorflow:AutoGraph could not transform <function Model.make_predict_function.<locals>.predict_function at 0x7fdf88126e60> and will run it as-is.\n"
]
},
{
"data": {
"text/plain": [
"0.9424460431654677"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Per-class model outputs for every validation example\n",
"preds = trained_model.predict([val_inp, val_mask], batch_size=16)\n",
"pred_labels = np.argmax(preds, axis=1)  # index of the highest-scoring class\n",
"f1 = f1_score(val_label, pred_labels, pos_label=1)  # F1 of the 'spam' class (label 1)\n",
"f1"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"F1 score 0.9424460431654677\n",
"Classification Report\n",
" precision recall f1-score support\n",
"\n",
" ham 0.99 0.99 0.99 973\n",
" spam 0.96 0.92 0.94 142\n",
"\n",
" accuracy 0.99 1115\n",
" macro avg 0.98 0.96 0.97 1115\n",
"weighted avg 0.99 0.99 0.99 1115\n",
"\n"
]
}
],
"source": [
"# Derive class names from the id -> name mapping so report rows line up with label ids.\n",
"target_names=[label_class_dict[k] for k in sorted(label_class_dict)]\n",
"print('F1 score',f1)\n",
"print('Classification Report')\n",
"# Pass labels explicitly so both classes are reported even if one is never predicted.\n",
"print(classification_report(val_label,pred_labels,labels=sorted(label_class_dict),target_names=target_names))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
@kalidoracle
Copy link

this error has been appeared when I run the code "
RuntimeError: Failed to import transformers.models.audio_spectrogram_transformer.feature_extraction_audio_spectrogram_transformer because of the following error (look up to see its traceback): partially initialized module 'torchaudio' has no attribute 'lib' (most likely due to a circular import)"

@rsuryae
Copy link

rsuryae commented Jan 14, 2024

Hi - Can u pl share the data file - spam.csv?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment