
@zacstewart
Last active March 27, 2023 15:59
Document Classification with scikit-learn
import os
import numpy
from pandas import DataFrame
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
# Note: sklearn.cross_validation was removed in later scikit-learn
# releases; newer versions provide KFold in sklearn.model_selection.
from sklearn.cross_validation import KFold
from sklearn.metrics import confusion_matrix, f1_score

NEWLINE = '\n'

HAM = 'ham'
SPAM = 'spam'

SOURCES = [
    ('data/spam',        SPAM),
    ('data/easy_ham',    HAM),
    ('data/hard_ham',    HAM),
    ('data/beck-s',      HAM),
    ('data/farmer-d',    HAM),
    ('data/kaminski-v',  HAM),
    ('data/kitchen-l',   HAM),
    ('data/lokay-m',     HAM),
    ('data/williams-w3', HAM),
    ('data/BG',          SPAM),
    ('data/GP',          SPAM),
    ('data/SH',          SPAM)
]

SKIP_FILES = {'cmds'}
def read_files(path):
    # os.walk already descends into every subdirectory, so no manual
    # recursion over dir_names is needed.
    for root, dir_names, file_names in os.walk(path):
        for file_name in file_names:
            if file_name not in SKIP_FILES:
                file_path = os.path.join(root, file_name)
                if os.path.isfile(file_path):
                    # Skip the e-mail headers: everything up to the
                    # first blank line.
                    past_header, lines = False, []
                    f = open(file_path, encoding='latin-1')
                    for line in f:
                        if past_header:
                            lines.append(line)
                        elif line == NEWLINE:
                            past_header = True
                    f.close()
                    content = NEWLINE.join(lines)
                    yield file_path, content
def build_data_frame(path, classification):
    rows = []
    index = []
    for file_name, text in read_files(path):
        rows.append({'text': text, 'class': classification})
        index.append(file_name)

    data_frame = DataFrame(rows, index=index)
    return data_frame


# Load every source directory into one frame, indexed by file path,
# then shuffle the rows so the folds are not grouped by corpus.
data = DataFrame({'text': [], 'class': []})
for path, classification in SOURCES:
    data = data.append(build_data_frame(path, classification))

data = data.reindex(numpy.random.permutation(data.index))

pipeline = Pipeline([
    ('count_vectorizer', CountVectorizer(ngram_range=(1, 2))),
    ('classifier',       MultinomialNB())
])
# 6-fold cross-validation: train on five folds, test on the sixth,
# and accumulate a single confusion matrix across all folds.
k_fold = KFold(n=len(data), n_folds=6)
scores = []
confusion = numpy.array([[0, 0], [0, 0]])
for train_indices, test_indices in k_fold:
    train_text = data.iloc[train_indices]['text'].values
    train_y = data.iloc[train_indices]['class'].values.astype(str)

    test_text = data.iloc[test_indices]['text'].values
    test_y = data.iloc[test_indices]['class'].values.astype(str)

    pipeline.fit(train_text, train_y)
    predictions = pipeline.predict(test_text)

    confusion += confusion_matrix(test_y, predictions)
    score = f1_score(test_y, predictions, pos_label=SPAM)
    scores.append(score)

print('Total emails classified:', len(data))
print('Score:', sum(scores) / len(scores))
print('Confusion matrix:')
print(confusion)
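After the loop finishes, the pipeline is still fitted on the final training fold, so it can classify unseen text directly; CountVectorizer accepts raw strings. A minimal sketch with made-up example strings (refit on all of the data first if you want a model trained on everything):

examples = ['Free Viagra call today!',                          # made-up spam-like text
            'I am attending the Linux users group tomorrow.']   # made-up ham-like text
predictions = pipeline.predict(examples)
print(predictions)  # e.g. ['spam' 'ham']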

This page has moved!

I finally got around to finishing this tutorial and put it on my blog. Please enjoy the finished version here.

@petervtzand

Try sklearn.linear_model.SGDClassifier (a linear SVM); it will probably give even better results!

Source:
http://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html
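A minimal sketch of that substitution, reusing the pipeline from the gist above; SGDClassifier's default hinge loss makes it a linear SVM:

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.pipeline import Pipeline

pipeline = Pipeline([
    ('count_vectorizer', CountVectorizer(ngram_range=(1, 2))),
    ('classifier', SGDClassifier())  # hinge loss by default, i.e. a linear SVM
])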

@marcfielding1

Just wanted to say thank you for a simply perfect article. I'm looking for a way to classify attachments coming in on emails, keeping an eye out for a certain type of document, and this is by far the best-written and best-explained example I've come across.

@zurez

zurez commented Jan 8, 2015

How do I save the trained model so that I don't have to retrain it every time I use the script?

@cyrisX2

cyrisX2 commented Feb 18, 2015

@zurez - you can simply dump the fitted model to a file and then load it back in later (here clf stands for the trained pipeline object):

from sklearn.externals import joblib
joblib.dump(clf, 'my_trained_data.pkl', compress=9)

Then, to load it back in:

from sklearn.externals import joblib
trained_data = joblib.load('my_trained_data.pkl')

@DanMossa

I've tried to edit the script to use my own text, and I get the following error on this line:

data = data.reindex(numpy.random.permutation(data.index))

cannot reindex from a duplicate axis
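That error usually means two files produced the same index label (the script indexes rows by file path), so pandas cannot build a unique index to reindex against. A minimal sketch of two ways around it, assuming it is acceptable to drop or renumber the duplicates:

# Option 1: keep only the first row for each duplicated index label.
data = data[~data.index.duplicated(keep='first')]

# Option 2: renumber the rows so the index is unique by construction.
data = data.reset_index(drop=True)

data = data.reindex(numpy.random.permutation(data.index))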

@mohapatras

I have a question about the point where you initialize SOURCES to build the list of data directories and their types (ham/spam). Can we automate it, or do we have to list every directory manually if there are 500 or more?
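One way to automate it, sketched under the assumption of a hypothetical layout where every corpus directory sits under a parent folder named after its label (data/ham/... and data/spam/...):

import os

SOURCES = []
for label in ('ham', 'spam'):
    parent = os.path.join('data', label)  # hypothetical layout: data/ham/*, data/spam/*
    for name in sorted(os.listdir(parent)):
        child = os.path.join(parent, name)
        if os.path.isdir(child):
            SOURCES.append((child, label))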

@NagabhushanS

NagabhushanS commented Apr 2, 2017

@zurez: pickle the pipeline (model) object to a file; later you can load the pickled object back in.
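A minimal sketch of that approach with the standard-library pickle module, assuming pipeline is the fitted Pipeline from the gist and using a hypothetical file name spam_model.pkl:

import pickle

# Save the fitted pipeline (vectorizer + classifier) to disk.
with open('spam_model.pkl', 'wb') as f:
    pickle.dump(pipeline, f)

# Later: load it back and predict without retraining.
with open('spam_model.pkl', 'rb') as f:
    model = pickle.load(f)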
