Skip to content

Instantly share code, notes, and snippets.

@lowener
Last active November 24, 2023 03:09
Show Gist options
  • Save lowener/cf2358ee1d595884b1292be0ad91c0d1 to your computer and use it in GitHub Desktop.
Save lowener/cf2358ee1d595884b1292be0ad91c0d1 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "aa11c40d",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "48f4c264",
"metadata": {},
"outputs": [],
"source": [
"# Standard library\n",
"import itertools\n",
"import math\n",
"\n",
"# GPU data frames / arrays and the cuML estimators under test\n",
"import cudf\n",
"import cupy as cp\n",
"from cuml.cluster import KMeans\n",
"from cuml.feature_extraction.text import TfidfVectorizer, HashingVectorizer, CountVectorizer\n",
"from cuml.naive_bayes import GaussianNB, MultinomialNB, CategoricalNB, BernoulliNB, ComplementNB\n",
"from cuml.preprocessing import LabelEncoder, OneHotEncoder\n",
"\n",
"# CPU stack: NumPy, Matplotlib and the scikit-learn baselines (suffixed _sk)\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.naive_bayes import (\n",
"    BernoulliNB as BernoulliNB_sk,\n",
"    CategoricalNB as CategoricalNB_sk,\n",
"    ComplementNB as ComplementNB_sk,\n",
"    GaussianNB as GaussianNB_sk,\n",
"    MultinomialNB as MultinomialNB_sk,\n",
")\n",
"\n",
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"id": "60ed554e",
"metadata": {},
"source": [
"# Loading News Aggregator dataset \n",
"\n",
"The dataset is loaded with cudf. For more information on cudf see the documentation [here](https://docs.rapids.ai/api/cudf/stable).\n",
"\n",
"Then we check the class distribution and the sparsity of the data."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "3eb0dba4",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>ID</th>\n",
" <th>TITLE</th>\n",
" <th>URL</th>\n",
" <th>PUBLISHER</th>\n",
" <th>CATEGORY</th>\n",
" <th>STORY</th>\n",
" <th>HOSTNAME</th>\n",
" <th>TIMESTAMP</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1</td>\n",
" <td>Fed official says weak data caused by weather,...</td>\n",
" <td>http://www.latimes.com/business/money/la-fi-mo...</td>\n",
" <td>Los Angeles Times</td>\n",
" <td>1</td>\n",
" <td>ddUyU0VZz0BRneMioxUPQVP6sIxvM</td>\n",
" <td>www.latimes.com</td>\n",
" <td>1394470370698</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>Fed's Charles Plosser sees high bar for change...</td>\n",
" <td>http://www.livemint.com/Politics/H2EvwJSK2VE6O...</td>\n",
" <td>Livemint</td>\n",
" <td>1</td>\n",
" <td>ddUyU0VZz0BRneMioxUPQVP6sIxvM</td>\n",
" <td>www.livemint.com</td>\n",
" <td>1394470371207</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>3</td>\n",
" <td>US open: Stocks fall after Fed official hints ...</td>\n",
" <td>http://www.ifamagazine.com/news/us-open-stocks...</td>\n",
" <td>IFA Magazine</td>\n",
" <td>1</td>\n",
" <td>ddUyU0VZz0BRneMioxUPQVP6sIxvM</td>\n",
" <td>www.ifamagazine.com</td>\n",
" <td>1394470371550</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>Fed risks falling 'behind the curve', Charles ...</td>\n",
" <td>http://www.ifamagazine.com/news/fed-risks-fall...</td>\n",
" <td>IFA Magazine</td>\n",
" <td>1</td>\n",
" <td>ddUyU0VZz0BRneMioxUPQVP6sIxvM</td>\n",
" <td>www.ifamagazine.com</td>\n",
" <td>1394470371793</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5</td>\n",
" <td>Fed's Plosser: Nasty Weather Has Curbed Job Gr...</td>\n",
" <td>http://www.moneynews.com/Economy/federal-reser...</td>\n",
" <td>Moneynews</td>\n",
" <td>1</td>\n",
" <td>ddUyU0VZz0BRneMioxUPQVP6sIxvM</td>\n",
" <td>www.moneynews.com</td>\n",
" <td>1394470372027</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" ID TITLE \\\n",
"0 1 Fed official says weak data caused by weather,... \n",
"1 2 Fed's Charles Plosser sees high bar for change... \n",
"2 3 US open: Stocks fall after Fed official hints ... \n",
"3 4 Fed risks falling 'behind the curve', Charles ... \n",
"4 5 Fed's Plosser: Nasty Weather Has Curbed Job Gr... \n",
"\n",
" URL PUBLISHER \\\n",
"0 http://www.latimes.com/business/money/la-fi-mo... Los Angeles Times \n",
"1 http://www.livemint.com/Politics/H2EvwJSK2VE6O... Livemint \n",
"2 http://www.ifamagazine.com/news/us-open-stocks... IFA Magazine \n",
"3 http://www.ifamagazine.com/news/fed-risks-fall... IFA Magazine \n",
"4 http://www.moneynews.com/Economy/federal-reser... Moneynews \n",
"\n",
" CATEGORY STORY HOSTNAME TIMESTAMP \n",
"0 1 ddUyU0VZz0BRneMioxUPQVP6sIxvM www.latimes.com 1394470370698 \n",
"1 1 ddUyU0VZz0BRneMioxUPQVP6sIxvM www.livemint.com 1394470371207 \n",
"2 1 ddUyU0VZz0BRneMioxUPQVP6sIxvM www.ifamagazine.com 1394470371550 \n",
"3 1 ddUyU0VZz0BRneMioxUPQVP6sIxvM www.ifamagazine.com 1394470371793 \n",
"4 1 ddUyU0VZz0BRneMioxUPQVP6sIxvM www.moneynews.com 1394470372027 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Load the UCI News Aggregator headlines into a GPU DataFrame\n",
"path = '~/Downloads/uci-news-aggregator.csv'\n",
"dataset = cudf.read_csv(path, sep=',')\n",
"\n",
"# Replace the category letters with integer codes:\n",
"#   b -> 1 (business), t -> 2 (technology), e -> 3 (entertainment), m -> 4 (health)\n",
"dataset['CATEGORY'] = dataset['CATEGORY'].map({'b': 1, 't': 2, 'e': 3, 'm': 4})\n",
"dataset.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f8d63575",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3 152469\n",
"1 115967\n",
"2 108344\n",
"4 45639\n",
"Name: CATEGORY, dtype: int32"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Number of headlines per (integer-coded) category\n",
"dataset.CATEGORY.value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "34d6df57",
"metadata": {},
"outputs": [],
"source": [
"# Headline titles are the inputs; the categories are re-encoded by LabelEncoder into y\n",
"X = dataset['TITLE']\n",
"Y = dataset['CATEGORY']\n",
"le = LabelEncoder()\n",
"y = le.fit_transform(Y)\n",
"\n",
"# Default split proportions, fixed seed for reproducibility\n",
"X_train_text, X_test_text, y_train, y_test = train_test_split(X, y, random_state=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "83d5a920",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.013490494373383653\n",
"0.013362668958684477\n"
]
}
],
"source": [
"vec = TfidfVectorizer(stop_words='english', ngram_range=(1, 1))\n",
"x_train = vec.fit_transform(X_train_text)\n",
"x_test = vec.transform(X_test_text)\n",
"\n",
"# Density of the train and test matrices, as a percentage of non-zero entries\n",
"# (a tiny value means the matrices are very sparse)\n",
"for matrix in (x_train, x_test):\n",
"    n_rows, n_cols = matrix.shape\n",
"    print(matrix.nnz / (n_rows * n_cols) * 100)"
]
},
{
"cell_type": "markdown",
"id": "92903d9d",
"metadata": {
"tags": []
},
"source": [
"# Gaussian NB\n",
"\n",
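"Transform the text with a TF-IDF vectorizer, then iterate over the dataset in chunks and call `partial_fit` on each chunk to train Gaussian naive Bayes incrementally. Chunking matters because Gaussian naive Bayes works on dense features: scikit-learn's `GaussianNB` does not accept sparse input, so each chunk is densified with `.toarray()`. Densifying the whole TF-IDF matrix at once could require a very large amount of memory."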
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1079a683",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 12.3 s, sys: 2.23 s, total: 14.5 s\n",
"Wall time: 22 s\n",
"0.8769999742507935\n",
"0.8840000033378601\n",
"0.878083348274231\n",
"0.8805833458900452\n",
"0.8756666779518127\n",
"0.8796666860580444\n",
"0.8786666393280029\n",
"0.8777499794960022\n",
"0.8823529481887817\n",
"CPU times: user 4.36 s, sys: 2.74 s, total: 7.1 s\n",
"Wall time: 22.8 s\n"
]
}
],
"source": [
"vec = TfidfVectorizer(stop_words='english', ngram_range=(1,3), min_df=5)\n",
"x_train = vec.fit_transform(X_train_text)\n",
"x_test = vec.transform(X_test_text)\n",
"\n",
"def dataset_traversal(X, Y, partial_function, chunk_size=12000):\n",
"    \"\"\"Call partial_function(X_chunk, Y_chunk, classes) on consecutive row chunks.\n",
"\n",
"    X is split into slices of at most chunk_size rows (the last slice holds the\n",
"    remainder) and Y is sliced the same way. classes holds the distinct labels of\n",
"    Y, computed once up front so every partial_fit call receives the full set.\n",
"    Inputs with chunk_size rows or fewer are processed as a single chunk.\n",
"    \"\"\"\n",
"    classes = cp.unique(Y)\n",
"    n_rows = X.shape[0]\n",
"    for lower in range(0, n_rows, chunk_size):\n",
"        upper = min(lower + chunk_size, n_rows)\n",
"        partial_function(X[lower:upper], Y[lower:upper], classes)\n",
"\n",
"gnb = GaussianNB()\n",
"%time dataset_traversal(x_train, y_train, lambda x, y, c: gnb.partial_fit(x, y, c))\n",
"\n",
"# Score each test chunk separately (prints one accuracy per chunk)\n",
"%time dataset_traversal(x_test, y_test, lambda x, y, c: print(gnb.score(x, y)))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "a0c9ccc9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2min 47s, sys: 1min 29s, total: 4min 17s\n",
"Wall time: 4min 17s\n",
"0.885\n",
"0.8736\n",
"0.8802\n",
"0.8828\n",
"0.8836\n",
"0.8738\n",
"0.8806\n",
"0.881\n",
"0.8832\n",
"0.8784\n",
"0.8714\n",
"0.879\n",
"0.8754\n",
"0.8782\n",
"0.8816\n",
"0.8844\n",
"0.875\n",
"0.8764\n",
"0.877\n",
"0.8864\n",
"0.8796\n",
"0.8842975206611571\n",
"CPU times: user 3min 8s, sys: 2min 7s, total: 5min 16s\n",
"Wall time: 5min 16s\n"
]
}
],
"source": [
"vec = TfidfVectorizer(stop_words='english', ngram_range=(1,3), min_df=5)\n",
"x_train = vec.fit_transform(X_train_text)\n",
"x_test = vec.transform(X_test_text)\n",
"x_train_np, x_test_np = x_train.get(), x_test.get()\n",
"y_train_np, y_test_np = y_train.to_numpy(), y_test.to_numpy()\n",
"\n",
"def dataset_traversal(X, Y, partial_function, chunk_size=5000):\n",
"    \"\"\"Call partial_function(X_chunk, Y_chunk, classes) on consecutive row chunks.\n",
"\n",
"    Host (NumPy/SciPy) version of the cuML traversal. Each chunk is densified with\n",
"    .toarray() before scikit-learn's GaussianNB sees it, which is presumably why\n",
"    the default chunk is smaller than in the cuML run. The last chunk holds the\n",
"    remaining rows; inputs with chunk_size rows or fewer form a single chunk.\n",
"    \"\"\"\n",
"    classes = np.unique(Y)\n",
"    n_rows = X.shape[0]\n",
"    for lower in range(0, n_rows, chunk_size):\n",
"        upper = min(lower + chunk_size, n_rows)\n",
"        partial_function(X[lower:upper], Y[lower:upper], classes)\n",
"\n",
"gnb_sk = GaussianNB_sk()\n",
"%time dataset_traversal(x_train_np, y_train_np, lambda x, y, c: gnb_sk.partial_fit(x.toarray(), y, c))\n",
"\n",
"# Score each test chunk separately (prints one accuracy per chunk)\n",
"%time dataset_traversal(x_test_np, y_test_np, lambda x, y, c: print(gnb_sk.score(x.toarray(), y)))"
]
},
{
"cell_type": "markdown",
"id": "5dea07cd",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"# Bernoulli + CountVectorizer\n",
"\n",
"\n",
"In the Bernoulli variant, the feature vector is binarized. That's why a CountVectorizer with `binary=True` is a good fit: we only care whether a word is present, not its frequency."
]
},
{
"cell_type": "code",
"execution_count": 241,
"id": "33c0cb40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 44.4 ms, sys: 12.1 ms, total: 56.5 ms\n",
"Wall time: 56.5 ms\n",
"CPU times: user 14.9 ms, sys: 19.6 ms, total: 34.5 ms\n",
"Wall time: 34.2 ms\n"
]
},
{
"data": {
"text/plain": [
"0.8568723201751709"
]
},
"execution_count": 241,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Binary presence/absence features over unigrams to trigrams, then Bernoulli NB on the GPU\n",
"vec = CountVectorizer(stop_words='english', binary=True, ngram_range=(1, 3))\n",
"x_train, x_test = vec.fit_transform(X_train_text), vec.transform(X_test_text)\n",
"\n",
"bnb = BernoulliNB()\n",
"%time bnb.fit(x_train, y_train)\n",
"%time bnb.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 247,
"id": "b735ce40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 293 ms, sys: 72.1 ms, total: 365 ms\n",
"Wall time: 365 ms\n",
"CPU times: user 141 ms, sys: 90.9 ms, total: 232 ms\n",
"Wall time: 231 ms\n"
]
},
{
"data": {
"text/plain": [
"0.8568817764310402"
]
},
"execution_count": 247,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same binary count features as the cuML run, copied to host memory for scikit-learn\n",
"vec = CountVectorizer(stop_words='english', binary=True, ngram_range=(1, 3))\n",
"x_train, x_test = vec.fit_transform(X_train_text), vec.transform(X_test_text)\n",
"x_train_np = x_train.get()\n",
"x_test_np = x_test.get()\n",
"y_train_np = y_train.to_numpy()\n",
"y_test_np = y_test.to_numpy()\n",
"\n",
"bnb = BernoulliNB_sk()\n",
"%time bnb.fit(x_train_np, y_train_np)\n",
"%time bnb.score(x_test_np, y_test_np)"
]
},
{
"cell_type": "markdown",
"id": "5957db9b",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"# TF-IDF + Multinomial\n",
"\n",
"Transform the text through a TF-IDF vectorizer, and run a multinomial naive Bayes model."
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "637d0f2e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 55.4 ms, sys: 7.57 ms, total: 63 ms\n",
"Wall time: 63 ms\n",
"CPU times: user 20.3 ms, sys: 8.16 ms, total: 28.4 ms\n",
"Wall time: 28.2 ms\n"
]
},
{
"data": {
"text/plain": [
"0.9248046875"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# TF-IDF over unigrams to trigrams, then multinomial naive Bayes on the GPU\n",
"vec = TfidfVectorizer(stop_words='english', ngram_range=(1, 3))\n",
"x_train, x_test = vec.fit_transform(X_train_text), vec.transform(X_test_text)\n",
"\n",
"mnb = MultinomialNB()\n",
"%time mnb.fit(x_train, y_train)\n",
"%time mnb.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "da09815a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 264 ms, sys: 67.6 ms, total: 332 ms\n",
"Wall time: 332 ms\n",
"CPU times: user 31.8 ms, sys: 27.9 ms, total: 59.7 ms\n",
"Wall time: 59.4 ms\n"
]
},
{
"data": {
"text/plain": [
"0.9248046967473131"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Same TF-IDF features as the cuML run, copied to host memory for scikit-learn\n",
"vec = TfidfVectorizer(stop_words='english', ngram_range=(1, 3))\n",
"x_train, x_test = vec.fit_transform(X_train_text), vec.transform(X_test_text)\n",
"x_train_np = x_train.get()\n",
"x_test_np = x_test.get()\n",
"y_train_np = y_train.to_numpy()\n",
"y_test_np = y_test.to_numpy()\n",
"\n",
"mnb = MultinomialNB_sk()\n",
"%time mnb.fit(x_train_np, y_train_np)\n",
"%time mnb.score(x_test_np, y_test_np)"
]
},
{
"cell_type": "markdown",
"id": "5ff606c1-a2c1-461f-9ef1-e912d5735c79",
"metadata": {},
"source": [
"# CountVectorizer + Complement\n",
"Complement naive Bayes is designed to cope with imbalanced classes and is typically paired with raw term counts from a CountVectorizer. We first look at how imbalanced the classes are."
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "767cd8ed-bd27-4e3a-9a6d-6c516bfef80c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3 0.360943\n",
"1 0.274531\n",
"2 0.256485\n",
"4 0.108042\n",
"Name: CATEGORY, dtype: float64"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visualize the class imbalance: bar chart of the raw counts, then the class proportions\n",
"class_counts = dataset['CATEGORY'].value_counts()\n",
"class_counts.to_pandas().plot(kind='bar', title='histogram of the class distributions')\n",
"class_counts / len(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 252,
"id": "65a127f2-be0d-4ab8-b87f-45aabbe4a7fa",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 56.8 ms, sys: 12.3 ms, total: 69.1 ms\n",
"Wall time: 69.2 ms\n",
"CPU times: user 22.4 ms, sys: 7.27 ms, total: 29.7 ms\n",
"Wall time: 29.8 ms\n"
]
},
{
"data": {
"text/plain": [
"0.9502959251403809"
]
},
"execution_count": 252,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Raw (non-binary) term counts over unigrams to trigrams, then complement naive Bayes\n",
"vec = CountVectorizer(stop_words='english', ngram_range=(1, 3))\n",
"x_train, x_test = vec.fit_transform(X_train_text), vec.transform(X_test_text)\n",
"\n",
"cnb = ComplementNB()\n",
"%time cnb.fit(x_train, y_train)\n",
"%time cnb.score(x_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 253,
"id": "73005bfc-6ffb-45af-9809-00d5345c597a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 67.5 ms, sys: 31.8 ms, total: 99.3 ms\n",
"Wall time: 99.5 ms\n",
"CPU times: user 26.6 ms, sys: 11.4 ms, total: 38 ms\n",
"Wall time: 37.7 ms\n"
]
},
{
"data": {
"text/plain": [
"0.9449836611747742"
]
},
"execution_count": 253,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vec = CountVectorizer(stop_words='english', ngram_range=(1,3))\n",
"x_train = vec.fit_transform(X_train_text)\n",
"x_test = vec.transform(X_test_text)\n",
"x_train_np, x_test_np = x_train.get(), x_test.get()\n",
"y_train_np, y_test_np = y_train.to_numpy(), y_test.to_numpy()\n",
"\n",
"# scikit-learn ComplementNB baseline on the same count features; fit and score\n",
"# must use `cnb` so the timings and accuracy describe the complement model\n",
"cnb = ComplementNB_sk()\n",
"%time cnb.fit(x_train_np, y_train_np)\n",
"%time cnb.score(x_test_np, y_test_np)"
]
},
{
"cell_type": "markdown",
"id": "67f23185",
"metadata": {},
"source": [
"# Categorical\n",
"\n",
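"To turn the text into categorical data, we cluster similar terms together and describe each document by how often the terms of each cluster occur in it.\n",
"\n",
"To build these clusters, we reuse a fitted Multinomial naive Bayes model purely as a source of word features. Each word is represented by its per-class log-probabilities (`feature_log_prob_`), and KMeans groups words that contribute similarly to each category."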
]
},
{
"cell_type": "markdown",
"id": "aef23dff",
"metadata": {},
"source": [
"## Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "84ed876e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"count 1000.000000\n",
"mean 14.967000\n",
"std 19.519501\n",
"min 1.000000\n",
"25% 6.000000\n",
"50% 11.000000\n",
"75% 18.000000\n",
"max 254.000000\n",
"dtype: float64\n"
]
}
],
"source": [
"# 1. TF-IDF over unigrams that occur in at least 10 training documents\n",
"tfidfvec = TfidfVectorizer(stop_words='english', min_df=10)\n",
"x_train = tfidfvec.fit_transform(X_train_text)\n",
"\n",
"# 2. Multinomial NB on the TF-IDF features; only its learned per-class\n",
"#    log-probabilities are used below\n",
"mnb = MultinomialNB().fit(x_train, y_train)\n",
"\n",
"# 3. KMeans on the transposed log-probability matrix (one row per word):\n",
"#    words that contribute similarly to each category share a cluster\n",
"km = KMeans(n_clusters=1000, random_state=1)\n",
"feature_to_cluster = km.fit_predict(mnb.feature_log_prob_.T)\n",
"\n",
"# 4. One-hot word -> cluster assignment matrix\n",
"feats2cluster = OneHotEncoder().fit_transform(feature_to_cluster)\n",
"\n",
"# Summary statistics of how many words each cluster holds\n",
"words_per_cluster = feats2cluster.sum(0)[0]\n",
"print(cudf.Series(words_per_cluster).describe())"
]
},
{
"cell_type": "markdown",
"id": "d04509d7",
"metadata": {},
"source": [
"Each cluster holds around 15 words on average (median 11, maximum 254)."
]
},
{
"cell_type": "code",
"execution_count": 225,
"id": "ada1038b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"47 117\n",
"1597 beam\n",
"2114 broadband\n",
"2435 carriers\n",
"2618 charter\n",
"3788 defective\n",
"4056 dire\n",
"4406 dual\n",
"5367 fixes\n",
"8365 materials\n",
"9072 networking\n",
"10900 recognition\n",
"11466 rollout\n",
"13666 tracker\n",
"14088 unveiling\n",
"Name: token, dtype: object\n",
"\n",
"\n",
"3293 core\n",
"3603 cyber\n",
"4751 enterprise\n",
"5719 gaming\n",
"8738 models\n",
"9801 pc\n",
"10074 platform\n",
"14338 virtual\n",
"Name: token, dtype: object\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAOfklEQVR4nO3db4zdVZ3H8fdnqcCKWcqfCalt3amx0RATF9IghM1mQ11X0VgeoMGYtWGb9AmuKCZadh+QfSaJESHZEBurixvjn0WyENZo3IIP9oHdbZXwr7CMKLRNgdEA7mqMNn73wT3VS22ZOzN3ZnrPvF/Jzfx+55w795w5zadnzv3d36SqkCT15Y9WugOSpPEz3CWpQ4a7JHXIcJekDhnuktShNSvdAYALL7ywpqenV7obkjRRDhw48NOqmjpZ3WkR7tPT0+zfv3+luyFJEyXJM6eqc1tGkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6dFp8QlWnn+ld/z6v9j/59HuWqCeSFsKVuyR1aOJX7q4wJekPuXKXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6NFK4J/l4kseSPJrkq0nOTrIpyb4kM0m+nuTM1vasdj7T6qeXdASSpD8wZ7gnWQ98FNhSVW8FzgCuA24FbquqNwEvAjvaU3YAL7by21o7SdIyGnVbZg3wx0nWAK8FjgJXAXe3+ruAa9rxtnZOq9+aJGPprSRpJHOGe1UdAT4DPMsg1F8GDgAvVdWx1uwwsL4drwcOtecea+0vOPH7JtmZZH+S/bOzs4sdhyRpyCjbMucxWI1vAl4PnAO8a7EvXFW7q2pLVW2Zmppa7LeTJA0ZZVvmHcCPq2q2qn4D3ANcCaxt2zQAG4Aj7fgIsBGg1Z8L/GysvZYkvapRwv1Z4PIkr21751uBx4EHgWtbm+3Ave34vnZOq3+gqmp8XZYkzWWUPfd9DN4Y/QHwSHvObuBTwE1JZhjsqe9pT9kDXNDKbwJ2LUG/JUmvYqS/oVpVtwC3nFD8NHDZSdr+Cnj/4ru2evh3YCWN28T/gezlYPhKmjTefkCSOmS4S1KH3JZRl9xK02rnyl2SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchb/mosvMWudHpx5S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR3yE6qrwHw/PSpp8rlyl6QOGe6S1CHDXZI6tOr23N1/lrQauHKXpA4Z7pLUoVW3LSPp9OUffRkfV+6S1CHDXZI6NFK4J1mb5O4kTyQ5mOSKJOcn+W6Sp9rX81rbJLkjyUySh5NcurRDkCSdaNSV++3At6vqLcDbgIPALmBvVW0G9rZzgHcDm9tjJ3DnWHssSZrTnOGe5FzgL4A9AFX166p6CdgG3NWa3QVc0463AV+uge8Da5OsG3O/JUmvYpSrZTYBs8CXkrwNOADcCFxUVUdbm+eAi9rxeuDQ0PMPt7KjQ2Uk2clgZc8b3vCGhfZfkk4bC/mQ5FJd8TPKtswa4FLgzqq6BPgFv9+CAaCqCqj5vHBV7a6qLVW1ZWpqaj5PlSTNYZSV+2HgcFXta+d3Mwj355Osq6qjbdvlhVZ/BNg49PwNrUxa1byGW8tpznCvqueSHEry5qp6EtgKPN4e24FPt6/3tqfcB3wkydeAtwMvD23fSIBBJy21UT+h+nfAV5KcCTwNXM9gS+cbSXYAzwAfaG2/BVwNzAC/bG0lSctopHCvqoeALSep2nqStgXcsLhuSZIWw3vLLAFvKyxppXn7AUnqkCt3SSM5na7h1txcuUtShwx3SeqQ2zLSKuUb/30z3CXcT1Z/3JaRpA65cpc64TaLhrlyl6QOGe6S1CHDXZI6ZLhLUo
cMd0nqkFfLSJpY/tGXU3PlLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIa+WkU5T3itGi+HKXZI65MpdE8FVrDQ/rtwlqUOGuyR1yHCXpA4Z7pLUId9QnUC+uShpLoa7JJ3CJC+kDHdJS2aSw3HSuecuSR0y3CWpQ27LSAvkloNOZ67cJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUodGDvckZyT5YZL72/mmJPuSzCT5epIzW/lZ7Xym1U8vUd8lSacwn5X7jcDBofNbgduq6k3Ai8COVr4DeLGV39baSZKW0UjhnmQD8B7gC+08wFXA3a3JXcA17XhbO6fVb23tJUnLZNSV++eATwK/becXAC9V1bF2fhhY347XA4cAWv3Lrf0rJNmZZH+S/bOzswvrvSTppOYM9yTvBV6oqgPjfOGq2l1VW6pqy9TU1Di/tSSteqPcW+ZK4H1JrgbOBv4EuB1Ym2RNW51vAI609keAjcDhJGuAc4Gfjb3nkjRPq+l+QHOu3Kvq5qraUFXTwHXAA1X1IeBB4NrWbDtwbzu+r53T6h+oqhprryVJr2ox17l/CrgpyQyDPfU9rXwPcEErvwnYtbguSpLma163/K2q7wHfa8dPA5edpM2vgPePoW+SpAXyE6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUNzhnuSjUkeTPJ4kseS3NjKz0/y3SRPta/ntfIkuSPJTJKHk1y61IOQJL3SKCv3Y8Anqupi4HLghiQXA7uAvVW1GdjbzgHeDWxuj53AnWPvtSTpVc0Z7lV1tKp+0I7/FzgIrAe2AXe1ZncB17TjbcCXa+D7wNok68bdcUnSqc1rzz3JNHAJsA+4qKqOtqrngIva8Xrg0NDTDreyE7/XziT7k+yfnZ2db78lSa9i5HBP8jrgm8DHqurnw3VVVUDN54WrandVbamqLVNTU/N5qiRpDiOFe5LXMAj2r1TVPa34+ePbLe3rC638CLBx6OkbWpkkaZmMcrVMgD3Awar67FDVfcD2drwduHeo/MPtqpnLgZeHtm8kSctgzQhtrgT+BngkyUOt7O+BTwPfSLIDeAb4QKv7FnA1MAP8Erh+nB2WJM1tznCvqv8EcorqrSdpX8ANi+yXJGkR/ISqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHVqScE/yriRPJplJsmspXkOSdGpjD/ckZwD/BLwbuBj4YJKLx/06kqRTW4qV+2XATFU9XVW/Br4GbFuC15EkncKaJfie64FDQ+eHgbef2CjJTmBnO/2/JE8u8PUuBH66wOdOKse8OjjmVSC3LmrMf3qqiqUI95FU1W5g92K/T5L9VbVlDF2aGI55dXDMq8NSjXkptmWOABuHzje0MknSMlmKcP9vYHOSTUnOBK4D7luC15EkncLYt2Wq6liSjwDfAc4AvlhVj437dYYsemtnAjnm1cExrw5LMuZU1VJ8X0nSCvITqpLUIcNdkjo00eHe620OkmxM8mCSx5M8luTGVn5+ku8meap9Pa+VJ8kd7efwcJJLV3YEC5PkjCQ/THJ/O9+UZF8b19fbG/QkOaudz7T66RXt+AIlWZvk7iRPJDmY5I
pVMMcfb/+mH03y1SRn9zjPSb6Y5IUkjw6VzXtuk2xv7Z9Ksn0+fZjYcO/8NgfHgE9U1cXA5cANbWy7gL1VtRnY285h8DPY3B47gTuXv8tjcSNwcOj8VuC2qnoT8CKwo5XvAF5s5be1dpPoduDbVfUW4G0Mxt7tHCdZD3wU2FJVb2VwwcV19DnP/wy864Syec1tkvOBWxh8CPQy4Jbj/yGMpKom8gFcAXxn6Pxm4OaV7tcSjfVe4K+AJ4F1rWwd8GQ7/jzwwaH2v2s3KQ8Gn4fYC1wF3A+Ewaf21pw43wyuxLqiHa9p7bLSY5jneM8Ffnxivzuf4+OfXj+/zdv9wF/3Os/ANPDoQucW+CDw+aHyV7Sb6zGxK3dOfpuD9SvUlyXTfhW9BNgHXFRVR1vVc8BF7biHn8XngE8Cv23nFwAvVdWxdj48pt+Nt9W/3NpPkk3ALPClthX1hSTn0PEcV9UR4DPAs8BRBvN2gL7nedh853ZRcz7J4d69JK8Dvgl8rKp+PlxXg//Ku7iONcl7gReq6sBK92UZrQEuBe6sqkuAX/D7X9OBvuYYoG0pbGPwH9vrgXP4w62LVWE55naSw73r2xwkeQ2DYP9KVd3Tip9Psq7VrwNeaOWT/rO4Enhfkp8wuIvoVQz2o9cmOf5Bu+Ex/W68rf5c4GfL2eExOAwcrqp97fxuBmHf6xwDvAP4cVXNVtVvgHsYzH3P8zxsvnO7qDmf5HDv9jYHSQLsAQ5W1WeHqu4Djr9jvp3BXvzx8g+3d90vB14e+vXvtFdVN1fVhqqaZjCPD1TVh4AHgWtbsxPHe/zncG1rP1Er3Kp6DjiU5M2taCvwOJ3OcfMscHmS17Z/48fH3O08n2C+c/sd4J1Jzmu/9byzlY1mpd90WOQbFlcD/wP8CPiHle7PGMf15wx+ZXsYeKg9rmaw37gXeAr4D+D81j4Mrhz6EfAIg6sRVnwcCxz7XwL3t+M3Av8FzAD/CpzVys9u5zOt/o0r3e8FjvXPgP1tnv8NOK/3OQb+EXgCeBT4F+CsHucZ+CqD9xV+w+C3tB0LmVvgb9v4Z4Dr59MHbz8gSR2a5G0ZSdIpGO6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ/8PtcPqML+N1vkAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Let's plot the distribution of words over the cluster ids, then print the\n",
"# vocabulary of two example clusters (ids chosen from this KMeans run)\n",
"plt.hist(feature_to_cluster.get(), bins='auto')\n",
"for position, cluster_id in enumerate((127, 632)):\n",
"    if position > 0:\n",
"        print(\"\\n\")\n",
"    print(tfidfvec.vocabulary_[cp.where(feature_to_cluster == cluster_id)[0]])"
]
},
{
"cell_type": "code",
"execution_count": 226,
"id": "eb0d6f7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(316814, 14967)\n",
"(14967, 1000)\n"
]
}
],
"source": [
"# Categorical NB works on counts of clusters instead of counts of words.\n",
"# Reuse the TF-IDF vocabulary so the CountVectorizer yields the same word columns.\n",
"vocab = tfidfvec.vocabulary_\n",
"countvec = CountVectorizer(stop_words='english')\n",
"countvec.vocabulary_ = vocab\n",
"\n",
"x_train, x_test = countvec.transform(X_train_text), countvec.transform(X_test_text)\n",
"print(x_train.shape)\n",
"print(feats2cluster.shape)\n",
"\n",
"# (documents x words) @ (words x clusters) -> per-document cluster counts\n",
"x_train_cluster = x_train @ feats2cluster\n",
"x_test_cluster = x_test @ feats2cluster\n",
"\n",
"# Cap every cluster count at 2, giving three categories per cluster:\n",
"#   0: none of the cluster's words, 1: one occurrence, 2: two or more occurrences\n",
"for cluster_matrix in (x_train_cluster, x_test_cluster):\n",
"    cluster_matrix.data[cluster_matrix.data > 2] = 2"
]
},
{
"cell_type": "markdown",
"id": "1a61a25e",
"metadata": {},
"source": [
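"Small workaround: if a cluster's count never exceeds 1 in the training set, the model has only seen categories 0 and 1 for it. Any test count of 2 for that cluster is therefore capped to 1, so the model is never given a category it did not see during training."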
]
},
{
"cell_type": "code",
"execution_count": 227,
"id": "23582cf5",
"metadata": {},
"outputs": [],
"source": [
"# Clusters whose count never exceeds 1 in the training set only have categories\n",
"# {0, 1} there, so their test counts are capped at 1 as well.\n",
"max_one = cp.where(x_train_cluster.max(0).todense() == 1)[1]\n",
"\n",
"# Vectorized over the stored entries of the CSR matrix (column index in `indices`,\n",
"# value in `data`) instead of slicing and reassigning one column at a time.\n",
"assert x_test_cluster.format == 'csr'\n",
"to_cap = cp.isin(x_test_cluster.indices, max_one) & (x_test_cluster.data > 1)\n",
"x_test_cluster.data[to_cap] = 1"
]
},
{
"cell_type": "markdown",
"id": "ac2198bd",
"metadata": {},
"source": [
"## Categorical model training\n",
"\n",
"Now that the preprocessing is done we can train the Categorical model and see how it performs on these clusters"
]
},
{
"cell_type": "code",
"execution_count": 239,
"id": "6e561526",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cuml/naive_bayes/naive_bayes.py:1498: UserWarning: X dtype is not int32. X will be converted, which will increase memory consumption\n",
" warnings.warn(\"X dtype is not int32. X will be \"\n",
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cupyx/scipy/sparse/compressed.py:545: UserWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
" warnings.warn('Changing the sparsity structure of a '\n",
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cuml/naive_bayes/naive_bayes.py:1516: UserWarning: X dtype is not int32. X will be converted, which will increase memory consumption\n",
" warnings.warn(\"X dtype is not int32. X will be \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 110 ms, sys: 4.87 ms, total: 115 ms\n",
"Wall time: 112 ms\n",
"CPU times: user 64.7 ms, sys: 127 ms, total: 191 ms\n",
"Wall time: 193 ms\n"
]
},
{
"data": {
"text/plain": [
"0.9256380200386047"
]
},
"execution_count": 239,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Train and evaluate cuML's CategoricalNB on the capped cluster counts\n",
"cnb = CategoricalNB()\n",
"%time cnb.fit(x_train_cluster, y_train)\n",
"%time cnb.score(x_test_cluster, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 240,
"id": "56a11bc1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 13.7 s, sys: 434 ms, total: 14.2 s\n",
"Wall time: 14.2 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.4 s, sys: 110 ms, total: 4.51 s\n",
"Wall time: 4.51 s\n"
]
},
{
"data": {
"text/plain": [
"0.9256379906254438"
]
},
"execution_count": 240,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Densify for scikit-learn's CategoricalNB. toarray() returns a plain ndarray;\n",
"# todense() returns np.matrix, which scikit-learn deprecates as estimator input.\n",
"x_train_cluster_np = x_train_cluster.get().toarray()\n",
"x_test_cluster_np = x_test_cluster.get().toarray()\n",
"y_train_np, y_test_np = y_train.to_numpy(), y_test.to_numpy()\n",
"\n",
"%time cnb = CategoricalNB_sk().fit(x_train_cluster_np, y_train_np)\n",
"%time cnb.score(x_test_cluster_np, y_test_np)"
]
},
{
"cell_type": "markdown",
"id": "14d61f69",
"metadata": {},
"source": [
"# Performance gain"
]
},
{
"cell_type": "code",
"execution_count": 254,
"id": "ae51c493",
"metadata": {},
"outputs": [],
"source": [
"# Timings in milliseconds, transcribed from the %time outputs above, converted to seconds.\n",
"# Every entry is a *wall* time. CPU user time excludes time the host spends blocked\n",
"# (e.g. waiting on the GPU), so it can understate cuML's runtime\n",
"# (Gaussian cuML: 12.3 s user vs 22 s wall for training).\n",
"variants = ['Gaussian', 'Bernoulli', 'Multinomial', 'Complement', 'Categorical']\n",
"time_training_cuml = np.array([22000, 56.5, 63, 69, 112]) / 1000\n",
"time_testing_cuml = np.array([22800, 34, 28, 30, 193]) / 1000\n",
"# TODO: the Complement sklearn numbers (99 ms / 38 ms) come from a cell that timed\n",
"# `mnb` (MultinomialNB) instead of `cnb`; re-measure and update them.\n",
"time_training_sk = np.array([257000, 365, 332, 99, 14200]) / 1000\n",
"time_testing_sk = np.array([316000, 231, 59, 38, 4510]) / 1000\n",
"training_gain = time_training_sk / time_training_cuml\n",
"testing_gain = time_testing_sk / time_testing_cuml"
]
},
{
"cell_type": "code",
"execution_count": 256,
"id": "174c66d2",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 612x720 with 5 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"x = np.arange(2)  # x positions of the train / test bar groups\n",
"width = 0.33  # width of each bar\n",
"fig = plt.figure(figsize=(8.5,10))\n",
"\n",
"# One panel per naive Bayes variant: sklearn bars on the left, cuML bars on the right\n",
"for i, variant in enumerate(variants):\n",
"    ax = plt.subplot(3, 2, i + 1)\n",
"    cuml_times = (time_training_cuml[i], time_testing_cuml[i])\n",
"    sk_times = (time_training_sk[i], time_testing_sk[i])\n",
"    for pos, cuml_t, sk_t in zip(x, cuml_times, sk_times):\n",
"        ax.bar(pos + width/2 + 0.01, cuml_t, width, color='green')\n",
"        ax.bar(pos - width/2 - 0.01, sk_t, width, color='royalblue')\n",
"\n",
"    speedup_labels = [\"train speedup\\n{:.1f}x\".format(training_gain[i]),\n",
"                      \"test speedup\\n{:.1f}x\".format(testing_gain[i])]\n",
"    ax.set_ylabel('Time (s)')\n",
"    ax.set_title(variant)\n",
"    ax.set_xticks(x)\n",
"    ax.set_xticklabels(speedup_labels, fontdict={'fontsize': 11,})\n",
"    # Legend labels map onto the first two bars drawn: cuML (green), then sklearn (blue)\n",
"    ax.legend(['cuml', 'sklearn'])\n",
"\n",
"fig.tight_layout()\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b513fa16",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment