Last active
July 19, 2022 12:20
-
-
Save lowener/139a63a9b0637173cf340734d1bc01cd to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| { | |
| "cells": [ | |
| { | |
| "cell_type": "markdown", | |
| "id": "67f23185", | |
| "metadata": {}, | |
| "source": [ | |
| "# Categorical\n", | |
| "\n", | |
| "To transform the text to categorical data, you can apply a clustering technique to merge the terms that are similar.\n", | |
| "\n", | |
| "To create these clusters, you can reuse a fitted Multinomial Naive Bayes model, purely for the purpose of clustering the words: words whose learned per-class log-probabilities are similar end up in the same cluster." | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "aef23dff", | |
| "metadata": {}, | |
| "source": [ | |
| "## Preprocessing" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 14, | |
| "id": "84ed876e", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "count 1000.000000\n", | |
| "mean 14.967000\n", | |
| "std 19.519501\n", | |
| "min 1.000000\n", | |
| "25% 6.000000\n", | |
| "50% 11.000000\n", | |
| "75% 18.000000\n", | |
| "max 254.000000\n", | |
| "dtype: float64\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# First fit a TfIdf on the train dataset\n", | |
| "tfidfvec = TfidfVectorizer(stop_words='english', min_df=10)\n", | |
| "x_train = tfidfvec.fit_transform(X_train_text)\n", | |
| "\n", | |
| "# Fit a Multinomial NB on the TfIdf data\n", | |
| "mnb = MultinomialNB().fit(x_train, y_train)\n", | |
| "\n", | |
| "# Use a KMeans algorithm to cluster on what the Multinomial NB learned of the TfIdf.\n", | |
| "# This means that the words that contribute similarly to a category will be clustered together\n", | |
| "km = KMeans(n_clusters=1000, random_state=1)\n", | |
| "feature_to_cluster = km.fit_predict(mnb.feature_log_prob_.T)\n", | |
| "feats2cluster = OneHotEncoder().fit_transform(feature_to_cluster)\n", | |
| "\n", | |
| "# Print statistics on the distribution of the words across the clusters\n", | |
| "print(cudf.Series(feats2cluster.sum(0)[0]).describe())" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "d04509d7", | |
| "metadata": {}, | |
| "source": [ | |
| "Here each cluster holds, on average, around 15 words." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 225, | |
| "id": "ada1038b", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "47 117\n", | |
| "1597 beam\n", | |
| "2114 broadband\n", | |
| "2435 carriers\n", | |
| "2618 charter\n", | |
| "3788 defective\n", | |
| "4056 dire\n", | |
| "4406 dual\n", | |
| "5367 fixes\n", | |
| "8365 materials\n", | |
| "9072 networking\n", | |
| "10900 recognition\n", | |
| "11466 rollout\n", | |
| "13666 tracker\n", | |
| "14088 unveiling\n", | |
| "Name: token, dtype: object\n", | |
| "\n", | |
| "\n", | |
| "3293 core\n", | |
| "3603 cyber\n", | |
| "4751 enterprise\n", | |
| "5719 gaming\n", | |
| "8738 models\n", | |
| "9801 pc\n", | |
| "10074 platform\n", | |
| "14338 virtual\n", | |
| "Name: token, dtype: object\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAOfklEQVR4nO3db4zdVZ3H8fdnqcCKWcqfCalt3amx0RATF9IghM1mQ11X0VgeoMGYtWGb9AmuKCZadh+QfSaJESHZEBurixvjn0WyENZo3IIP9oHdbZXwr7CMKLRNgdEA7mqMNn73wT3VS22ZOzN3ZnrPvF/Jzfx+55w795w5zadnzv3d36SqkCT15Y9WugOSpPEz3CWpQ4a7JHXIcJekDhnuktShNSvdAYALL7ywpqenV7obkjRRDhw48NOqmjpZ3WkR7tPT0+zfv3+luyFJEyXJM6eqc1tGkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6dFp8QlWnn+ld/z6v9j/59HuWqCeSFsKVuyR1aOJX7q4wJekPuXKXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6NFK4J/l4kseSPJrkq0nOTrIpyb4kM0m+nuTM1vasdj7T6qeXdASSpD8wZ7gnWQ98FNhSVW8FzgCuA24FbquqNwEvAjvaU3YAL7by21o7SdIyGnVbZg3wx0nWAK8FjgJXAXe3+ruAa9rxtnZOq9+aJGPprSRpJHOGe1UdAT4DPMsg1F8GDgAvVdWx1uwwsL4drwcOtecea+0vOPH7JtmZZH+S/bOzs4sdhyRpyCjbMucxWI1vAl4PnAO8a7EvXFW7q2pLVW2Zmppa7LeTJA0ZZVvmHcCPq2q2qn4D3ANcCaxt2zQAG4Aj7fgIsBGg1Z8L/GysvZYkvapRwv1Z4PIkr21751uBx4EHgWtbm+3Ave34vnZOq3+gqmp8XZYkzWWUPfd9DN4Y/QHwSHvObuBTwE1JZhjsqe9pT9kDXNDKbwJ2LUG/JUmvYqS/oVpVtwC3nFD8NHDZSdr+Cnj/4ru2evh3YCWN28T/gezlYPhKmjTefkCSOmS4S1KH3JZRl9xK02rnyl2SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchb/mosvMWudHpx5S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR3yE6qrwHw/PSpp8rlyl6QOGe6S1CHDXZI6tOr23N1/lrQauHKXpA4Z7pLUoVW3LSPp9OUffRkfV+6S1CHDXZI6NFK4J1mb5O4kTyQ5mOSKJOcn+W6Sp9rX81rbJLkjyUySh5NcurRDkCSdaNSV++3At6vqLcDbgIPALmBvVW0G9rZzgHcDm9tjJ3DnWHssSZrTnOGe5FzgL4A9AFX166p6CdgG3NWa3QVc0463AV+uge8Da5OsG3O/JUmvYpSrZTYBs8CXkrwNOADcCFxUVUdbm+eAi9rxeuDQ0PMPt7KjQ2Uk2clgZc8b3vCGhfZfkk4bC/mQ5FJd8TPKtswa4FLgzqq6BPgFv9+CAaCqCqj5vHBV7a6qLVW1ZWpqaj5PlSTNYZSV+2HgcFXta+d3Mwj355Osq6qjbdvlhVZ/BNg49PwNrUxa1byGW8tpznCvqueSHEry5qp6EtgKPN4e24FPt6/3tqfcB3wkydeAtwMvD23fSIBBJy21UT+h+nfAV5KcCTwNXM9gS+cbSXYAzwAfaG2/BVwNzAC/bG0lSctopHCvqoeALSep2nqStgXcsLhuSZIWw3vLLAFvKyxppXn7AUnqkCt3SSM5na7h1txcuUtShwx3SeqQ2zLSKuUb/30z3CXcT1Z/3JaRpA65cpc64TaLhrlyl6QOGe6S1CHDXZI6ZLhL
UocMd0nqkFfLSJpY/tGXU3PlLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIa+WkU5T3itGi+HKXZI65MpdE8FVrDQ/rtwlqUOGuyR1yHCXpA4Z7pLUId9QnUC+uShpLoa7JJ3CJC+kDHdJS2aSw3HSuecuSR0y3CWpQ27LSAvkloNOZ67cJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUodGDvckZyT5YZL72/mmJPuSzCT5epIzW/lZ7Xym1U8vUd8lSacwn5X7jcDBofNbgduq6k3Ai8COVr4DeLGV39baSZKW0UjhnmQD8B7gC+08wFXA3a3JXcA17XhbO6fVb23tJUnLZNSV++eATwK/becXAC9V1bF2fhhY347XA4cAWv3Lrf0rJNmZZH+S/bOzswvrvSTppOYM9yTvBV6oqgPjfOGq2l1VW6pqy9TU1Di/tSSteqPcW+ZK4H1JrgbOBv4EuB1Ym2RNW51vAI609keAjcDhJGuAc4Gfjb3nkjRPq+l+QHOu3Kvq5qraUFXTwHXAA1X1IeBB4NrWbDtwbzu+r53T6h+oqhprryVJr2ox17l/CrgpyQyDPfU9rXwPcEErvwnYtbguSpLma163/K2q7wHfa8dPA5edpM2vgPePoW+SpAXyE6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUNzhnuSjUkeTPJ4kseS3NjKz0/y3SRPta/ntfIkuSPJTJKHk1y61IOQJL3SKCv3Y8Anqupi4HLghiQXA7uAvVW1GdjbzgHeDWxuj53AnWPvtSTpVc0Z7lV1tKp+0I7/FzgIrAe2AXe1ZncB17TjbcCXa+D7wNok68bdcUnSqc1rzz3JNHAJsA+4qKqOtqrngIva8Xrg0NDTDreyE7/XziT7k+yfnZ2db78lSa9i5HBP8jrgm8DHqurnw3VVVUDN54WrandVbamqLVNTU/N5qiRpDiOFe5LXMAj2r1TVPa34+ePbLe3rC638CLBx6OkbWpkkaZmMcrVMgD3Awar67FDVfcD2drwduHeo/MPtqpnLgZeHtm8kSctgzQhtrgT+BngkyUOt7O+BTwPfSLIDeAb4QKv7FnA1MAP8Erh+nB2WJM1tznCvqv8EcorqrSdpX8ANi+yXJGkR/ISqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHVqScE/yriRPJplJsmspXkOSdGpjD/ckZwD/BLwbuBj4YJKLx/06kqRTW4qV+2XATFU9XVW/Br4GbFuC15EkncKaJfie64FDQ+eHgbef2CjJTmBnO/2/JE8u8PUuBH66wOdOKse8OjjmVSC3LmrMf3qqiqUI95FU1W5g92K/T5L9VbVlDF2aGI55dXDMq8NSjXkptmWOABuHzje0MknSMlmKcP9vYHOSTUnOBK4D7luC15EkncLYt2Wq6liSjwDfAc4AvlhVj437dYYsemtnAjnm1cExrw5LMuZU1VJ8X0nSCvITqpLUIcNdkjo00eHe620OkmxM8mCSx5M8luTGVn5+ku8meap9Pa+VJ8kd7efwcJJLV3YEC5PkjCQ/THJ/O9+UZF8b19fbG/QkOaudz7T66RXt+AIlWZvk7iRPJDmY
5IpVMMcfb/+mH03y1SRn9zjPSb6Y5IUkjw6VzXtuk2xv7Z9Ksn0+fZjYcO/8NgfHgE9U1cXA5cANbWy7gL1VtRnY285h8DPY3B47gTuXv8tjcSNwcOj8VuC2qnoT8CKwo5XvAF5s5be1dpPoduDbVfUW4G0Mxt7tHCdZD3wU2FJVb2VwwcV19DnP/wy864Syec1tkvOBWxh8CPQy4Jbj/yGMpKom8gFcAXxn6Pxm4OaV7tcSjfVe4K+AJ4F1rWwd8GQ7/jzwwaH2v2s3KQ8Gn4fYC1wF3A+Ewaf21pw43wyuxLqiHa9p7bLSY5jneM8Ffnxivzuf4+OfXj+/zdv9wF/3Os/ANPDoQucW+CDw+aHyV7Sb6zGxK3dOfpuD9SvUlyXTfhW9BNgHXFRVR1vVc8BF7biHn8XngE8Cv23nFwAvVdWxdj48pt+Nt9W/3NpPkk3ALPClthX1hSTn0PEcV9UR4DPAs8BRBvN2gL7nedh853ZRcz7J4d69JK8Dvgl8rKp+PlxXg//Ku7iONcl7gReq6sBK92UZrQEuBe6sqkuAX/D7X9OBvuYYoG0pbGPwH9vrgXP4w62LVWE55naSw73r2xwkeQ2DYP9KVd3Tip9Psq7VrwNeaOWT/rO4Enhfkp8wuIvoVQz2o9cmOf5Bu+Ex/W68rf5c4GfL2eExOAwcrqp97fxuBmHf6xwDvAP4cVXNVtVvgHsYzH3P8zxsvnO7qDmf5HDv9jYHSQLsAQ5W1WeHqu4Djr9jvp3BXvzx8g+3d90vB14e+vXvtFdVN1fVhqqaZjCPD1TVh4AHgWtbsxPHe/zncG1rP1Er3Kp6DjiU5M2taCvwOJ3OcfMscHmS17Z/48fH3O08n2C+c/sd4J1Jzmu/9byzlY1mpd90WOQbFlcD/wP8CPiHle7PGMf15wx+ZXsYeKg9rmaw37gXeAr4D+D81j4Mrhz6EfAIg6sRVnwcCxz7XwL3t+M3Av8FzAD/CpzVys9u5zOt/o0r3e8FjvXPgP1tnv8NOK/3OQb+EXgCeBT4F+CsHucZ+CqD9xV+w+C3tB0LmVvgb9v4Z4Dr59MHbz8gSR2a5G0ZSdIpGO6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ/8PtcPqML+N1vkAAAAASUVORK5CYII=", | |
| "text/plain": [ | |
| "<Figure size 432x288 with 1 Axes>" | |
| ] | |
| }, | |
| "metadata": { | |
| "needs_background": "light" | |
| }, | |
| "output_type": "display_data" | |
| } | |
| ], | |
| "source": [ | |
| "# Let's plot the distribution of the words across the clusters\n", | |
| "# And print the words in a few clusters\n", | |
| "plt.hist(feature_to_cluster.get(), bins='auto')\n", | |
| "print(tfidfvec.vocabulary_[cp.where(feature_to_cluster == 127)[0]])\n", | |
| "print(\"\\n\")\n", | |
| "print(tfidfvec.vocabulary_[cp.where(feature_to_cluster == 632)[0]])" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 226, | |
| "id": "eb0d6f7d", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "(316814, 14967)\n", | |
| "(14967, 1000)\n" | |
| ] | |
| } | |
| ], | |
| "source": [ | |
| "# For Categorical Naive Bayes, the word counts are transformed into counts per cluster\n", | |
| "vocab = tfidfvec.vocabulary_\n", | |
| "countvec = CountVectorizer(stop_words='english')\n", | |
| "countvec.vocabulary_ = vocab\n", | |
| "\n", | |
| "x_train = countvec.transform(X_train_text)\n", | |
| "x_test = countvec.transform(X_test_text)\n", | |
| "print(x_train.shape)\n", | |
| "print(feats2cluster.shape)\n", | |
| "\n", | |
| "x_train_cluster = (x_train @ feats2cluster)\n", | |
| "x_test_cluster = (x_test @ feats2cluster)\n", | |
| "\n", | |
| "# For each cluster we will have:\n", | |
| "# - 0: absence of those words\n", | |
| "# - 1: presence of those words\n", | |
| "# - 2: multiple presence of those words (2+)\n", | |
| "\n", | |
| "x_train_cluster.data[x_train_cluster.data > 2] = 2\n", | |
| "x_test_cluster.data[x_test_cluster.data > 2] = 2" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "1a61a25e", | |
| "metadata": {}, | |
| "source": [ | |
| "Small hack: if a cluster's maximum value is 1 in the training set, cap its values at 1 in the test set as well." | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 227, | |
| "id": "23582cf5", | |
| "metadata": {}, | |
| "outputs": [], | |
| "source": [ | |
| "max_one = cp.where(x_train_cluster.max(0).todense() == 1)[1]\n", | |
| "for cluster in max_one:\n", | |
| " samples = (x_test_cluster[:, cluster] > 1)\n", | |
| " if samples.nnz == 0:\n", | |
| " continue\n", | |
| " samples = cp.where(samples.todense())[0]\n", | |
| " x_test_cluster[samples, cluster] = 1" | |
| ] | |
| }, | |
| { | |
| "cell_type": "markdown", | |
| "id": "ac2198bd", | |
| "metadata": {}, | |
| "source": [ | |
| "## Categorical model training\n", | |
| "\n", | |
| "Now that the preprocessing is done we can train the Categorical model and see how it performs on these clusters" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 239, | |
| "id": "6e561526", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cuml/naive_bayes/naive_bayes.py:1498: UserWarning: X dtype is not int32. X will be converted, which will increase memory consumption\n", | |
| " warnings.warn(\"X dtype is not int32. X will be \"\n", | |
| "/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cupyx/scipy/sparse/compressed.py:545: UserWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n", | |
| " warnings.warn('Changing the sparsity structure of a '\n", | |
| "/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cuml/naive_bayes/naive_bayes.py:1516: UserWarning: X dtype is not int32. X will be converted, which will increase memory consumption\n", | |
| " warnings.warn(\"X dtype is not int32. X will be \"\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 110 ms, sys: 4.87 ms, total: 115 ms\n", | |
| "Wall time: 112 ms\n", | |
| "CPU times: user 64.7 ms, sys: 127 ms, total: 191 ms\n", | |
| "Wall time: 193 ms\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9256380200386047" | |
| ] | |
| }, | |
| "execution_count": 239, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "%time cnb = CategoricalNB().fit(x_train_cluster, y_train)\n", | |
| "%time cnb.score(x_test_cluster, y_test)" | |
| ] | |
| }, | |
| { | |
| "cell_type": "code", | |
| "execution_count": 240, | |
| "id": "56a11bc1", | |
| "metadata": {}, | |
| "outputs": [ | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/dev/shm/rapids22.04_env/lib/python3.8/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n", | |
| " warnings.warn(\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 13.7 s, sys: 434 ms, total: 14.2 s\n", | |
| "Wall time: 14.2 s\n" | |
| ] | |
| }, | |
| { | |
| "name": "stderr", | |
| "output_type": "stream", | |
| "text": [ | |
| "/dev/shm/rapids22.04_env/lib/python3.8/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n", | |
| " warnings.warn(\n" | |
| ] | |
| }, | |
| { | |
| "name": "stdout", | |
| "output_type": "stream", | |
| "text": [ | |
| "CPU times: user 4.4 s, sys: 110 ms, total: 4.51 s\n", | |
| "Wall time: 4.51 s\n" | |
| ] | |
| }, | |
| { | |
| "data": { | |
| "text/plain": [ | |
| "0.9256379906254438" | |
| ] | |
| }, | |
| "execution_count": 240, | |
| "metadata": {}, | |
| "output_type": "execute_result" | |
| } | |
| ], | |
| "source": [ | |
| "x_train_cluster_np = x_train_cluster.get().todense()\n", | |
| "x_test_cluster_np = x_test_cluster.get().todense()\n", | |
| "y_train_np, y_test_np = y_train.to_numpy(), y_test.to_numpy()\n", | |
| "\n", | |
| "%time cnb = CategoricalNB_sk().fit(x_train_cluster_np, y_train_np)\n", | |
| "%time cnb.score(x_test_cluster_np, y_test_np)" | |
| ] | |
| } | |
| ], | |
| "metadata": { | |
| "kernelspec": { | |
| "display_name": "Python 3 (ipykernel)", | |
| "language": "python", | |
| "name": "python3" | |
| }, | |
| "language_info": { | |
| "codemirror_mode": { | |
| "name": "ipython", | |
| "version": 3 | |
| }, | |
| "file_extension": ".py", | |
| "mimetype": "text/x-python", | |
| "name": "python", | |
| "nbconvert_exporter": "python", | |
| "pygments_lexer": "ipython3", | |
| "version": "3.8.13" | |
| } | |
| }, | |
| "nbformat": 4, | |
| "nbformat_minor": 5 | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment