Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
Document Classification with scikit-learn
import os

import numpy
from pandas import DataFrame
from pandas import concat
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

try:
    from sklearn.cross_validation import KFold
except ImportError:
    # sklearn.cross_validation was removed in scikit-learn 0.20; KFold now
    # lives in sklearn.model_selection.
    from sklearn.model_selection import KFold
# Line separator.  A line equal to this marks the end of a message's headers,
# and it is also used to join the body lines back together.
NEWLINE = '\n'

# Class labels attached to every message.
HAM, SPAM = 'ham', 'spam'

# (directory, label) pairs: every file found below a directory is given that
# directory's label.
SOURCES = [
    ('data/spam', SPAM),
    ('data/easy_ham', HAM),
    ('data/hard_ham', HAM),
    ('data/beck-s', HAM),
    ('data/farmer-d', HAM),
    ('data/kaminski-v', HAM),
    ('data/kitchen-l', HAM),
    ('data/lokay-m', HAM),
    ('data/williams-w3', HAM),
    ('data/BG', SPAM),
    ('data/GP', SPAM),
    ('data/SH', SPAM),
]

# File names that are skipped when a directory is read.
SKIP_FILES = {'cmds'}
def read_files(path):
    """Yield ``(file_path, body)`` for every message file below *path*.

    The directory tree rooted at *path* is walked recursively.  Files whose
    name is listed in ``SKIP_FILES`` are ignored.  Each remaining file is
    decoded as latin-1; everything up to and including the first blank line
    is treated as the header and dropped, and the remaining lines form the
    body.  A file that contains no blank line yields an empty body.

    Args:
        path: Root directory to search.

    Yields:
        Tuples of the file's path and its body text.
    """
    # os.walk already descends into every sub-directory, so no explicit
    # recursion is needed.  (The previous recursive call only created a
    # generator that was never iterated, and its loop variable shadowed the
    # ``path`` parameter.)
    for root, _dir_names, file_names in os.walk(path):
        for file_name in file_names:
            if file_name in SKIP_FILES:
                continue
            file_path = os.path.join(root, file_name)
            if not os.path.isfile(file_path):
                continue
            past_header, lines = False, []
            # The context manager closes the file even if reading raises.
            with open(file_path, encoding="latin-1") as f:
                for line in f:
                    if past_header:
                        lines.append(line)
                    elif line == NEWLINE:
                        past_header = True
            # Lines keep their own trailing newline, so joining with NEWLINE
            # puts an extra blank line between them; kept as-is so the
            # produced text is unchanged for existing callers.
            yield file_path, NEWLINE.join(lines)
def build_data_frame(path, classification):
    """Build a DataFrame of the messages found below *path*.

    Args:
        path: Directory handed to ``read_files``.
        classification: Label stored in the ``class`` column of every row.

    Returns:
        A DataFrame indexed by file path with the columns ``text`` (the
        message body) and ``class`` (the given label).
    """
    rows = []
    index = []
    for file_name, text in read_files(path):
        rows.append({'text': text, 'class': classification})
        index.append(file_name)
    # Naming the columns explicitly keeps the frame's shape stable even when
    # the directory contains no messages (otherwise it would have no columns).
    return DataFrame(rows, index=index, columns=['text', 'class'])
# --- Load the corpus -------------------------------------------------------
# Build one frame per source directory and concatenate them in a single step.
# (DataFrame.append copied the accumulated frame on every call and was
# removed in pandas 2.0.)
data = concat(
    [build_data_frame(path, classification) for path, classification in SOURCES]
)

# Shuffle the rows by position rather than by label: reindexing by label
# raises "cannot reindex from a duplicate axis" when index labels repeat.
data = data.iloc[numpy.random.permutation(len(data))]

# --- Model -----------------------------------------------------------------
# Unigram + bigram counts fed into a multinomial naive Bayes classifier.
pipeline = Pipeline([
    ('count_vectorizer', CountVectorizer(ngram_range=(1, 2))),
    ('classifier', MultinomialNB()),
])

# --- 6-fold cross-validation -----------------------------------------------
# sklearn.model_selection.KFold exposes split(); the legacy
# sklearn.cross_validation.KFold is sized up front and iterated directly.
if hasattr(KFold, 'split'):
    k_fold = KFold(n_splits=6).split(data)
else:
    k_fold = KFold(n=len(data), n_folds=6)

scores = []
confusion = numpy.array([[0, 0], [0, 0]])
for train_indices, test_indices in k_fold:
    train_text = data.iloc[train_indices]['text'].values
    train_y = data.iloc[train_indices]['class'].values.astype(str)
    test_text = data.iloc[test_indices]['text'].values
    test_y = data.iloc[test_indices]['class'].values.astype(str)

    pipeline.fit(train_text, train_y)
    predictions = pipeline.predict(test_text)

    # Fix the label order so every fold yields a 2x2 matrix; without it a
    # fold containing a single class gives a 1x1 matrix that the in-place
    # addition would silently broadcast into all four cells.
    confusion += confusion_matrix(test_y, predictions, labels=[HAM, SPAM])
    score = f1_score(test_y, predictions, pos_label=SPAM)
    scores.append(score)

# --- Report ----------------------------------------------------------------
print('Total emails classified:', len(data))
print('Score:', sum(scores)/len(scores))
print('Confusion matrix:')
print(confusion)

This page has moved!

I finally got around to finishing this tutorial and put it on my blog. Please enjoy the finished version here.

@petervtzand

This comment has been minimized.

Copy link

commented Nov 7, 2014

Try the sklearn.linear_model.SGDClassifier (SVM), it will probably give even better results!

Source:
http://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html

@marcfielding1

This comment has been minimized.

Copy link

commented Nov 10, 2014

Just wanted to say thank you for a simply perfect article. I'm looking for a way to classify attachments coming in on emails, keeping an eye out for a certain type of document, and this is by far the best-written and best-explained example I've come across.

@zurez

This comment has been minimized.

Copy link

commented Jan 8, 2015

How can I save the trained data, so that I don't have to train every time I use the script?

@cyrisX2

This comment has been minimized.

Copy link

commented Feb 18, 2015

@zurez - you can simply dump it to a file then load it in later.

from sklearn.externals import joblib
joblib.dump(clf, 'my_trained_data.pkl', compress=9)
Then, to load it back in:
from sklearn.externals import joblib
trained_data = joblib.load('my_trained_data.pkl')

@DanMossa

This comment has been minimized.

Copy link

commented Sep 21, 2016

I've tried to edit the script to use my own text. I get the following error on this line

data = data.reindex(numpy.random.permutation(data.index))

cannot reindex from a duplicate axis

@mohapatras

This comment has been minimized.

Copy link

commented Feb 27, 2017

I have a question about the point where you initialize SOURCES with the list of data directories and their types (ham/spam). Can this be automated? That is, do I have to build the list manually if I have 500 or more data sources?

@NagabhushanS

This comment has been minimized.

Copy link

commented Apr 2, 2017

@zurez: Pickle the pipeline (model) object to a file. Later, you can load the pickled object.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.