@lowener
Last active July 19, 2022 12:20
{
"cells": [
{
"cell_type": "markdown",
"id": "67f23185",
"metadata": {},
"source": [
"# Categorical\n",
"\n",
"To turn the text into categorical data, you can apply a clustering technique that merges similar terms.\n",
"\n",
"To build these clusters, you can reuse a fitted Multinomial naive Bayes model. It is used here only to describe each word, not for prediction: the per-class log-probabilities it learned for a word become that word's features for clustering."
]
},
{
"cell_type": "markdown",
"id": "aef23dff",
"metadata": {},
"source": [
"## Preprocessing"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "84ed876e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"count 1000.000000\n",
"mean 14.967000\n",
"std 19.519501\n",
"min 1.000000\n",
"25% 6.000000\n",
"50% 11.000000\n",
"75% 18.000000\n",
"max 254.000000\n",
"dtype: float64\n"
]
}
],
"source": [
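"import cudf\n",
"import cupy as cp\n",
"import matplotlib.pyplot as plt\n",
"from cuml.cluster import KMeans\n",
"from cuml.feature_extraction.text import CountVectorizer, TfidfVectorizer\n",
"from cuml.naive_bayes import CategoricalNB, MultinomialNB\n",
"from cuml.preprocessing import OneHotEncoder\n",
"from sklearn.naive_bayes import CategoricalNB as CategoricalNB_sk\n",
"\n",
"# X_train_text, X_test_text, y_train and y_test are assumed to be defined beforehand (this notebook does not create them)\n",
"\n",
"# First, fit a TF-IDF vectorizer on the training text (keep words found in at least 10 documents)\n",
"tfidfvec = TfidfVectorizer(stop_words='english', min_df=10)\n",
"x_train = tfidfvec.fit_transform(X_train_text)\n",
"\n",
"# Fit a Multinomial naive Bayes model on the TF-IDF features\n",
"mnb = MultinomialNB().fit(x_train, y_train)\n",
"\n",
"# Cluster the words with KMeans on what the Multinomial NB learned from the TF-IDF features:\n",
"# feature_log_prob_.T has one row per word and one column per class, so words that\n",
"# contribute similarly to each category end up in the same cluster\n",
"km = KMeans(n_clusters=1000, random_state=1)\n",
"feature_to_cluster = km.fit_predict(mnb.feature_log_prob_.T)\n",
"# One-hot encode the assignments: feats2cluster[w, c] is 1 when word w belongs to cluster c\n",
"feats2cluster = OneHotEncoder().fit_transform(feature_to_cluster)\n",
"\n",
"# Print statistics on the number of words per cluster\n",
"print(cudf.Series(feats2cluster.sum(0)[0]).describe())"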
]
},
{
"cell_type": "markdown",
"id": "d04509d7",
"metadata": {},
"source": [
"On average, each cluster holds about 15 words (mean 14.97, median 11). The largest cluster holds 254 words."
]
},
{
"cell_type": "code",
"execution_count": 225,
"id": "ada1038b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"47 117\n",
"1597 beam\n",
"2114 broadband\n",
"2435 carriers\n",
"2618 charter\n",
"3788 defective\n",
"4056 dire\n",
"4406 dual\n",
"5367 fixes\n",
"8365 materials\n",
"9072 networking\n",
"10900 recognition\n",
"11466 rollout\n",
"13666 tracker\n",
"14088 unveiling\n",
"Name: token, dtype: object\n",
"\n",
"\n",
"3293 core\n",
"3603 cyber\n",
"4751 enterprise\n",
"5719 gaming\n",
"8738 models\n",
"9801 pc\n",
"10074 platform\n",
"14338 virtual\n",
"Name: token, dtype: object\n"
]
},
{
"data": {
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Plot a histogram of the cluster ID assigned to each word\n",
"# and print the words belonging to two example clusters\n",
"plt.hist(feature_to_cluster.get(), bins='auto')\n",
"print(tfidfvec.vocabulary_[cp.where(feature_to_cluster == 127)[0]])\n",
"print(\"\\n\")\n",
"print(tfidfvec.vocabulary_[cp.where(feature_to_cluster == 632)[0]])"
]
},
{
"cell_type": "code",
"execution_count": 226,
"id": "eb0d6f7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(316814, 14967)\n",
"(14967, 1000)\n"
]
}
],
"source": [
"# For Categorical naive Bayes, word counts are turned into per-cluster counts.\n",
"# Reuse the TF-IDF vocabulary so the count columns line up with the rows of feats2cluster\n",
"vocab = tfidfvec.vocabulary_\n",
"countvec = CountVectorizer(stop_words='english')\n",
"countvec.vocabulary_ = vocab\n",
"\n",
"x_train = countvec.transform(X_train_text)\n",
"x_test = countvec.transform(X_test_text)\n",
"print(x_train.shape)\n",
"print(feats2cluster.shape)\n",
"\n",
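"# (n_docs x n_words) @ (n_words x n_clusters): per-document count of each cluster's words\n",
"x_train_cluster = x_train @ feats2cluster\n",
"x_test_cluster = x_test @ feats2cluster\n",
"\n",
"# Each cluster feature then takes one of three values:\n",
"# - 0: none of the cluster's words occur\n",
"# - 1: the cluster's words occur once in total\n",
"# - 2: the cluster's words occur two or more times\n",
"# Only stored (non-zero) entries can exceed 2, so the sparse .data arrays are capped in place\n",
"x_train_cluster.data[x_train_cluster.data > 2] = 2\n",
"x_test_cluster.data[x_test_cluster.data > 2] = 2"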
]
},
{
"cell_type": "markdown",
"id": "1a61a25e",
"metadata": {},
"source": [
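"A small workaround: for every cluster whose maximum value in the training set is 1, cap its test-set values at 1 as well. This ensures the test set never contains a value the model did not see for that cluster during training."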
]
},
{
"cell_type": "code",
"execution_count": 227,
"id": "23582cf5",
"metadata": {},
"outputs": [],
"source": [
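"# Clusters (columns) whose maximum value over the training set is 1\n",
"max_one = cp.where(x_train_cluster.max(0).todense() == 1)[1]\n",
"for cluster in max_one:\n",
"    # Test documents where this cluster has a value above 1\n",
"    samples = (x_test_cluster[:, cluster] > 1)\n",
"    if samples.nnz == 0:\n",
"        continue\n",
"    samples = cp.where(samples.todense())[0]\n",
"    # Cap those entries at 1, the largest value seen in training\n",
"    x_test_cluster[samples, cluster] = 1"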
]
},
{
"cell_type": "markdown",
"id": "ac2198bd",
"metadata": {},
"source": [
"## Categorical model training\n",
"\n",
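"Now that preprocessing is done, we can train the Categorical naive Bayes model and see how it performs on these cluster features. For comparison, the last cell trains scikit-learn's `CategoricalNB` on the same features."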
]
},
{
"cell_type": "code",
"execution_count": 239,
"id": "6e561526",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cuml/naive_bayes/naive_bayes.py:1498: UserWarning: X dtype is not int32. X will be converted, which will increase memory consumption\n",
" warnings.warn(\"X dtype is not int32. X will be \"\n",
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cupyx/scipy/sparse/compressed.py:545: UserWarning: Changing the sparsity structure of a csr_matrix is expensive. lil_matrix is more efficient.\n",
" warnings.warn('Changing the sparsity structure of a '\n",
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/cuml/naive_bayes/naive_bayes.py:1516: UserWarning: X dtype is not int32. X will be converted, which will increase memory consumption\n",
" warnings.warn(\"X dtype is not int32. X will be \"\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 110 ms, sys: 4.87 ms, total: 115 ms\n",
"Wall time: 112 ms\n",
"CPU times: user 64.7 ms, sys: 127 ms, total: 191 ms\n",
"Wall time: 193 ms\n"
]
},
{
"data": {
"text/plain": [
"0.9256380200386047"
]
},
"execution_count": 239,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time cnb = CategoricalNB().fit(x_train_cluster, y_train)\n",
"%time cnb.score(x_test_cluster, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 240,
"id": "56a11bc1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 13.7 s, sys: 434 ms, total: 14.2 s\n",
"Wall time: 14.2 s\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/dev/shm/rapids22.04_env/lib/python3.8/site-packages/sklearn/utils/validation.py:593: FutureWarning: np.matrix usage is deprecated in 1.0 and will raise a TypeError in 1.2. Please convert to a numpy array with np.asarray. For more information see: https://numpy.org/doc/stable/reference/generated/numpy.matrix.html\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4.4 s, sys: 110 ms, total: 4.51 s\n",
"Wall time: 4.51 s\n"
]
},
{
"data": {
"text/plain": [
"0.9256379906254438"
]
},
"execution_count": 240,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
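"# scikit-learn runs on the CPU: copy the sparse matrices to the host and densify them.\n",
"# toarray() returns a NumPy ndarray, avoiding the np.matrix FutureWarning that todense() triggers\n",
"x_train_cluster_np = x_train_cluster.get().toarray()\n",
"x_test_cluster_np = x_test_cluster.get().toarray()\n",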
"y_train_np, y_test_np = y_train.to_numpy(), y_test.to_numpy()\n",
"\n",
"%time cnb = CategoricalNB_sk().fit(x_train_cluster_np, y_train_np)\n",
"%time cnb.score(x_test_cluster_np, y_test_np)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
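The first code cell clusters words by what the Multinomial naive Bayes model learned about them. Each word is represented by its column of `feature_log_prob_` (one log-probability per class), and KMeans groups words with similar profiles. Below is a minimal CPU sketch of the same idea in plain NumPy, not the notebook's cuML code. `multinomial_nb_log_prob` and `kmeans` are hypothetical helpers written for illustration, and the four-document corpus and its vocabulary are made up. The log-probabilities use the additive-smoothing estimate (N_yi + alpha) / (N_y + alpha * n_words) described in the scikit-learn naive Bayes user guide.

```python
import numpy as np

def multinomial_nb_log_prob(X, y, alpha=1.0):
    """log P(word | class) with additive smoothing, shape (n_classes, n_words)."""
    classes = np.unique(y)
    counts = np.vstack([X[y == c].sum(axis=0) for c in classes])
    smoothed = counts + alpha
    return np.log(smoothed / smoothed.sum(axis=1, keepdims=True))

def kmeans(points, k, n_iter=20):
    """Plain Lloyd's k-means with deterministic, evenly spaced initial centers."""
    centers = points[np.linspace(0, len(points) - 1, k).astype(int)].copy()
    for _ in range(n_iter):
        # Squared distance from every point to every center: (n_points, k)
        dists = ((points[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        for j in range(k):
            if np.any(labels == j):  # keep the previous center if a cluster empties
                centers[j] = points[labels == j].mean(axis=0)
    return labels

# Hypothetical toy corpus: 4 documents x 4 words, 2 classes
vocab = np.array(["beam", "broadband", "gaming", "pc"])
X = np.array([[3, 2, 0, 0],
              [1, 2, 0, 0],
              [0, 0, 2, 3],
              [0, 0, 1, 1]])
y = np.array([0, 0, 1, 1])

log_prob = multinomial_nb_log_prob(X, y)      # (n_classes, n_words)
feature_to_cluster = kmeans(log_prob.T, k=2)  # one cluster id per word
feats2cluster = np.eye(2, dtype=int)[feature_to_cluster]  # one-hot, (n_words, n_clusters)

print(feats2cluster.sum(axis=0))  # words per cluster -> [2 2]
print(vocab[feature_to_cluster == feature_to_cluster[2]])  # -> ['gaming' 'pc']
```

The notebook runs the same three steps on the GPU with cuML's `MultinomialNB`, `KMeans(n_clusters=1000, random_state=1)` and `OneHotEncoder` over 14,967 words. It then lists a cluster's words by indexing `tfidfvec.vocabulary_` with the matching positions.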
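The third code cell turns per-word counts into per-cluster counts with a single sparse matrix product, `x_train @ feats2cluster`, and then caps every value at 2. A dense NumPy sketch makes the arithmetic visible. The count matrix `x_counts` and the word-to-cluster assignment are hypothetical.

```python
import numpy as np

# Hypothetical counts: 3 documents x 5 vocabulary words
x_counts = np.array([[1, 0, 2, 0, 0],
                     [0, 1, 0, 0, 3],
                     [0, 0, 0, 1, 0]])
# Hypothetical word -> cluster assignment for the 5 words (3 clusters)
feature_to_cluster = np.array([0, 0, 1, 2, 2])
feats2cluster = np.eye(3, dtype=int)[feature_to_cluster]  # (n_words, n_clusters) one-hot

# (n_docs x n_words) @ (n_words x n_clusters): total count of each cluster's words per document
x_cluster = x_counts @ feats2cluster
# Cap at 2 so each cluster feature is categorical: 0 = absent, 1 = once, 2 = two or more
x_cluster = np.minimum(x_cluster, 2)
print(x_cluster)
```

In a sparse matrix, only the stored non-zero entries can exceed 2. That is why the notebook caps `x_train_cluster.data` in place instead of building a dense result.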
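The workaround cell loops over the clusters whose training maximum is 1 and caps their test values at 1. In the sketch below, `cap_like_notebook` is a dense NumPy transcription of that loop. `cap_to_train_max` is a hypothetical vectorized alternative that the notebook does not use: it clips each cluster to its own training maximum. The two differ on clusters never seen in training (maximum 0, the last column of the made-up data): the loop leaves those values alone, while the vectorized rule sets them to 0.

```python
import numpy as np

def cap_like_notebook(x_train, x_test):
    """Dense transcription of the notebook's loop: only clusters whose training max is 1."""
    x_test = x_test.copy()
    for cluster in np.where(x_train.max(axis=0) == 1)[0]:
        x_test[x_test[:, cluster] > 1, cluster] = 1
    return x_test

def cap_to_train_max(x_train, x_test):
    """Hypothetical vectorized variant: clip every cluster to its own training maximum."""
    return np.minimum(x_test, x_train.max(axis=0))  # the row of maxima broadcasts over test rows

# Hypothetical cluster features with values in {0, 1, 2}
x_train_cluster = np.array([[0, 1, 2, 0],
                            [1, 0, 2, 0],
                            [1, 1, 0, 0]])
x_test_cluster = np.array([[2, 2, 2, 1],
                           [0, 2, 1, 2]])

print(cap_like_notebook(x_train_cluster, x_test_cluster))
print(cap_to_train_max(x_train_cluster, x_test_cluster))
```

Whether unseen category values cause trouble depends on the implementation. scikit-learn's `CategoricalNB`, for instance, infers each feature's number of categories from the training data unless `min_categories` is set.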
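The last two cells fit cuML's `CategoricalNB` and scikit-learn's `CategoricalNB` on the same 0/1/2 cluster features. Both reach about 0.9256 accuracy on the test set. The recorded wall times are 112 ms (fit) and 193 ms (score) for cuML, against 14.2 s and 4.51 s for scikit-learn.

To show what such a model computes, here is a minimal NumPy sketch. It assumes Laplace smoothing (`alpha=1`) and class priors taken from label frequencies. `fit_categorical_nb` and `predict_categorical_nb` are hypothetical helpers, not the cuML or scikit-learn implementation, and the data are made up.

```python
import numpy as np

def fit_categorical_nb(X, y, alpha=1.0):
    """Class log-priors and, per feature, a table of log P(x_i = v | class)."""
    classes = np.unique(y)
    n_categories = X.max(axis=0) + 1  # categories per feature, learned from training data
    log_prior = np.log(np.array([(y == c).mean() for c in classes]))
    tables = []
    for i, n_cat in enumerate(n_categories):
        # counts[c, v] = number of class-c training rows where feature i equals v
        counts = np.array([np.bincount(X[y == c, i], minlength=n_cat) for c in classes])
        probs = (counts + alpha) / (counts.sum(axis=1, keepdims=True) + alpha * n_cat)
        tables.append(np.log(probs))  # shape (n_classes, n_cat)
    return classes, log_prior, tables

def predict_categorical_nb(model, X):
    classes, log_prior, tables = model
    jll = np.tile(log_prior, (X.shape[0], 1))  # joint log-likelihood, (n_samples, n_classes)
    for i, table in enumerate(tables):
        # Raises IndexError if a test value exceeds the categories seen in training
        jll += table[:, X[:, i]].T
    return classes[jll.argmax(axis=1)]

# Hypothetical cluster features (2 clusters) and labels
X_train = np.array([[0, 2], [0, 2], [1, 0], [1, 1]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([[0, 2], [1, 0]])
y_test = np.array([0, 1])

model = fit_categorical_nb(X_train, y_train)
y_pred = predict_categorical_nb(model, X_test)
print(y_pred, (y_pred == y_test).mean())  # -> [0 1] 1.0
```

The indexing step in `predict_categorical_nb` shows why the workaround cell matters in this sketch. Each feature's table only has columns for the values seen in training.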