Created
September 25, 2018 15:31
-
-
Save boxmein/10974b04ea98e13c6c08899decdf193a to your computer and use it in GitHub Desktop.
Text classification experiment + notes - SVM + Naive Bayes
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Simple text classifier\n", | |
"\n", | |
"Text classification basically works like this:\n", | |
"\n", | |
"1. Get a train dataset\n", | |
"2. Take a document\n", | |
" 1. Remove \"stop words\" - useless words such as \"the\", \"of\", \"or\", ...\n", | |
" 2. Normalize words into their stem form: \"normalize, normalizing, normalized, normalizable\" -> \"normalize\"\n", | |
" 3. Vectorize words in a document: Bag of words / count vectorization, or more advanced approaches\n", | |
" 4. Correct for statistics: divide word count by total word count in a doc, correct by word count in ALL docs\n", | |
"3. Feed the vectors into a classifier\n", | |
"4. Refine model by tuning hyperparameters, adding smarter word vectorizers (Google word2vec, ...)\n", | |
"\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"# 20newsgroups dataset\n", | |
"from sklearn.datasets import fetch_20newsgroups\n", | |
"\n", | |
"# text pipeline tools\n", | |
"from sklearn.feature_extraction.text import CountVectorizer\n", | |
"from sklearn.feature_extraction.text import TfidfTransformer\n", | |
"from sklearn.pipeline import Pipeline\n", | |
"\n", | |
"# learning machines: Naive Bayes + linear SVM (via SGD)\n", | |
"from sklearn.naive_bayes import MultinomialNB\n", | |
"from sklearn.linear_model import SGDClassifier\n", | |
"\n", | |
"# essentials\n", | |
"import numpy as np\n", | |
"\n", | |
"# grid searcher, takes a list of various param options \n", | |
"# & outputs the optimal ones from the list\n", | |
"from sklearn.model_selection import GridSearchCV" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"twenty_train = fetch_20newsgroups(subset='train', shuffle=True)\n", | |
"twenty_test = fetch_20newsgroups(subset='test', shuffle=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.81691449814126393" | |
] | |
}, | |
"execution_count": 7, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Naive Bayes\n", | |
"\n", | |
"# NOTE: previously we played through processing twenty_train with\n", | |
"# the count vectorizer & TFIDF transformer\n", | |
"\n", | |
"# count_vectorizer takes a string, and outputs a vector [count_w1, count_w2, ...]\n", | |
"# where each word's occurrence in the string is counted.\n", | |
"\n", | |
"# tfidfTransformer takes one of those vectors, and modifies it to \n", | |
"# lower-weight words that occur in all documents. (and, of, the, ...)\n", | |
"# additionally, it corrects for different document lengths\n", | |
"\n", | |
"nb_classifier = Pipeline([\n", | |
" ('count_vectorizer', CountVectorizer(stop_words='english')),\n", | |
" ('tfidf', TfidfTransformer()),\n", | |
" ('clf', MultinomialNB())\n", | |
"])\n", | |
"\n", | |
"# train\n", | |
"nb_classifier = nb_classifier.fit(twenty_train.data, twenty_train.target)\n", | |
"\n", | |
"# predict on testdata\n", | |
"nb_pred = nb_classifier.predict(twenty_test.data)\n", | |
"\n", | |
"# test score is pretty nice. :)\n", | |
"np.mean(nb_pred == twenty_test.target)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"0.82262347318109397" | |
] | |
}, | |
"execution_count": 6, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Support Vector Machine\n", | |
"\n", | |
"svm_classifier = Pipeline([\n", | |
" # same count vectorizer pipeline as in Naive Bayes\n", | |
" ('count_vectorizer', CountVectorizer(stop_words='english')),\n", | |
" ('tfidf', TfidfTransformer()),\n", | |
" \n", | |
" # generic classifier using SGD with the following params:\n", | |
" # * loss='hinge' to get a linear SVM\n", | |
" # * penalty='l2' adds standard L2 (ridge-style) regularization\n", | |
" # * alpha is the regularization strength (it also feeds the default 'optimal' learning-rate schedule)\n", | |
" # * max_iter = max epochs\n", | |
" # * random_state is the random seed\n", | |
" ('clf-svm', SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3, max_iter=10, random_state=42))\n", | |
"])\n", | |
"\n", | |
"# train\n", | |
"svm_classifier.fit(twenty_train.data, twenty_train.target)\n", | |
"\n", | |
"# predict on testdata\n", | |
"svm_pred = svm_classifier.predict(twenty_test.data)\n", | |
"\n", | |
"# test score at 82% is also pretty epic. :)\n", | |
"np.mean(svm_pred == twenty_test.target)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"(0.90675269577514583,\n", | |
" {'clf__alpha': 0.01,\n", | |
" 'count_vectorizer__ngram_range': (1, 2),\n", | |
" 'tfidf__use_idf': True})" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Grid-search the following parameters:\n", | |
"# Ngram ranges from 1-1 to 1-3\n", | |
"# IDF on/off\n", | |
"# Alpha either 1e-2 or 1e-3\n", | |
"\n", | |
"parameter_options = {\n", | |
" 'count_vectorizer__ngram_range': [ (1, 1), (1, 2), (1, 3) ],\n", | |
" 'tfidf__use_idf': (True, False),\n", | |
" 'clf__alpha': (1e-2, 1e-3)\n", | |
"}\n", | |
"\n", | |
"nb_grid_search = GridSearchCV(nb_classifier, parameter_options, n_jobs=2)\n", | |
"nb_grid_search.fit(twenty_train.data, twenty_train.target)\n", | |
"\n", | |
"(nb_grid_search.best_score_, nb_grid_search.best_params_)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.5.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment