Naive Bayes classifier using Python and Kyoto Cabinet
classifier.py
import operator
import struct

import kyotocabinet as kc


class ClassifierDB(kc.DB):
    """
    Wrapper for `kyotocabinet.DB` that provides utilities for working with
    features and categories.
    """
    def __init__(self, *args, **kwargs):
        super(ClassifierDB, self).__init__(*args, **kwargs)
        self._category_tmpl = 'category.%s'
        self._feature_to_category_tmpl = 'feature2category.%s.%s'
        self._total_count = 'total-count'

    def get_int(self, key):
        # Kyoto Cabinet serializes integers as 8-byte big-endian values, so
        # they need to be unpacked using the `struct` module.
        value = self.get(key)
        if value:
            return struct.unpack('>Q', value)[0]
        return 0

    def incr_feature_category(self, feature, category):
        """Increment the count for the feature in the given category."""
        return self.increment(
            self._feature_to_category_tmpl % (feature, category),
            1)

    def incr_category(self, category):
        """
        Increment the count for the given category, increasing the total
        count as well.
        """
        self.increment(self._total_count, 1)
        return self.increment(self._category_tmpl % category, 1)

    def category_count(self, category):
        """Return the number of documents in the given category."""
        return self.get_int(self._category_tmpl % category)

    def total_count(self):
        """Return the total number of documents overall."""
        return self.get_int(self._total_count)

    def get_feature_category_count(self, feature, category):
        """Get the count of the feature in the given category."""
        return self.get_int(
            self._feature_to_category_tmpl % (feature, category))

    def get_feature_counts(self, feature):
        """Get the total count for the feature across all categories."""
        prefix = self._feature_to_category_tmpl % (feature, '')
        total = 0
        for key in self.match_prefix(prefix):
            total += self.get_int(key)
        return total

    def iter_categories(self):
        """
        Return an iterable that successively yields all the categories
        that have been observed.
        """
        category_prefix = self._category_tmpl % ''
        prefix_len = len(category_prefix)
        for category_key in self.match_prefix(category_prefix):
            yield category_key[prefix_len:]
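
# Illustrative example of the key layout (the feature 'prize' and category
# 'spam' are hypothetical): calling db.incr_feature_category('prize', 'spam')
# followed by db.incr_category('spam') on an empty database leaves three keys,
# each holding an 8-byte big-endian integer:
#
#   'feature2category.prize.spam' -> 1   (count of the feature in the category)
#   'category.spam'               -> 1   (number of documents in the category)
#   'total-count'                 -> 1   (total number of documents seen)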

class NBC(object):
    """
    Simple Naive Bayes classifier.
    """
    def __init__(self, filename, read_only=False):
        """
        Initialize the classifier by pointing it at a database file. If you
        intend to only use the classifier for classifying documents, specify
        `read_only=True`.
        """
        self.filename = filename
        if not self.filename.endswith('.kct'):
            raise RuntimeError('Database filename must have "kct" extension.')
        self.db = ClassifierDB()
        self.connect(read_only=read_only)

    def connect(self, read_only=False):
        """
        Open the database. Since Kyoto Cabinet only allows a single writer
        at a time, the `connect()` method accepts a parameter allowing the
        database to be opened in read-only mode (supporting multiple readers).
        If you plan on training the classifier, specify `read_only=False`.
        If you plan only on classifying documents, it is safe to specify
        `read_only=True`.
        """
        if read_only:
            flags = kc.DB.OREADER
        else:
            flags = kc.DB.OWRITER
        self.db.open(self.filename, flags | kc.DB.OCREATE)

    def close(self):
        """Close the database."""
        self.db.close()

    def train(self, features, *categories):
        """
        Increment the counts for the features in the given categories.
        """
        for category in categories:
            for feature in features:
                self.db.incr_feature_category(feature, category)
            self.db.incr_category(category)

    def feature_probability(self, feature, category):
        """
        Calculate the probability that a particular feature is associated
        with the given category.
        """
        fcc = self.db.get_feature_category_count(feature, category)
        if fcc:
            category_count = self.db.category_count(category)
            return float(fcc) / category_count
        return 0

    def weighted_probability(self, feature, category, weight=1.0):
        """
        Determine the probability a feature corresponds to the given category.
        The probability is weighted by the importance of the feature, which
        is determined by looking at the feature across all categories in
        which it appears.
        """
        initial_prob = self.feature_probability(feature, category)
        totals = self.db.get_feature_counts(feature)
        return ((weight * 0.5) + (totals * initial_prob)) / (weight + totals)

    def document_probability(self, features, category):
        """
        Calculate the probability that a set of features match the given
        category.
        """
        feature_probabilities = [
            self.weighted_probability(feature, category)
            for feature in features]
        return reduce(operator.mul, feature_probabilities, 1)

    def weighted_document_probability(self, features, category):
        """
        Calculate the probability that a set of features match the given
        category, and weight that score by the importance of the category.
        """
        if self.db.total_count() == 0:
            # Avoid division by zero.
            return 0
        cat_prob = (float(self.db.category_count(category)) /
                    self.db.total_count())
        doc_prob = self.document_probability(features, category)
        return doc_prob * cat_prob

    def classify(self, features, limit=5):
        """
        Classify the features by finding the categories that match the
        features with the highest probability.
        """
        probabilities = {}
        for category in self.db.iter_categories():
            probabilities[category] = self.weighted_document_probability(
                features,
                category)
        return sorted(
            probabilities.items(),
            key=operator.itemgetter(1),
            reverse=True)[:limit]
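
A minimal usage sketch, assuming `kyotocabinet` is installed and the module above is saved as classifier.py; the database name `demo.kct` and the toy feature lists are made up for illustration:

from classifier import NBC

# Train on two tiny hand-made "documents" (lists of features).
nbc = NBC('demo.kct')
nbc.train(['cheap', 'pills', 'winner'], 'spam')
nbc.train(['meeting', 'tomorrow', 'agenda'], 'ham')

# classify() returns up to `limit` (category, probability) pairs,
# ordered from most to least probable.
print nbc.classify(['cheap', 'meeting', 'pills'])

nbc.close()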
enron.py
import os

# import our classifier, assumed to be in same directory
from classifier import NBC


def train(corpus='corpus'):
    classifier = NBC(filename='enron.kct')
    curdir = os.path.dirname(__file__)

    # paths to spam and ham documents
    spam_dir = os.path.join(curdir, corpus, 'spam')
    ham_dir = os.path.join(curdir, corpus, 'ham')

    # train the classifier with the spam documents
    train_category(classifier, spam_dir, 'spam')

    # train the classifier with the ham documents
    train_category(classifier, ham_dir, 'ham')

    return classifier


def train_category(classifier, path, label):
    files = os.listdir(path)
    print 'Preparing to train %s %s files' % (len(files), label)
    for filename in files:
        with open(os.path.join(path, filename)) as fh:
            contents = fh.read()

        # extract the words from the document
        features = extract_features(contents)

        # train the classifier to associate the features with the label
        classifier.train(features, label)
    print 'Trained %s files' % len(files)


def extract_features(s, min_len=2, max_len=20):
    """
    Extract all the words in the string `s` whose length falls strictly
    between the specified bounds.
    """
    words = []
    for w in s.lower().split():
        wlen = len(w)
        if wlen > min_len and wlen < max_len:
            words.append(w)
    return words
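
# Illustrative examples (hypothetical inputs, not from the corpus). The
# tokenizer only lowercases and splits on whitespace, so punctuation stays
# attached to words; tokens of 2 characters or fewer, and of 20 or more,
# are dropped:
#
#   extract_features('Buy CHEAP pills now!!')  ->  ['buy', 'cheap', 'pills', 'now!!']
#   extract_features('a an to be or not')      ->  ['not']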

def test(classifier, corpus='corpus2'):
    curdir = os.path.dirname(__file__)

    # paths to spam and ham documents
    spam_dir = os.path.join(curdir, corpus, 'spam')
    ham_dir = os.path.join(curdir, corpus, 'ham')

    correct = total = 0
    for path, label in ((spam_dir, 'spam'), (ham_dir, 'ham')):
        filenames = os.listdir(path)
        print 'Preparing to test %s %s files from %s.' % (
            len(filenames),
            label,
            corpus)
        for filename in os.listdir(path):
            with open(os.path.join(path, filename)) as fh:
                contents = fh.read()

            # extract the words from the document
            features = extract_features(contents)
            results = classifier.classify(features)
            if results[0][0] == label:
                correct += 1
            total += 1

    pct = 100 * (float(correct) / total)
    print '[%s]: processed %s documents, %02f%% accurate' % (corpus, total, pct)


if __name__ == '__main__':
    classifier = train()
    test(classifier, 'corpus2')
    test(classifier, 'corpus3')
    classifier.close()
    os.unlink('enron.kct')
$ python enron.py
Preparing to train 1500 spam files
Trained 1500 files
Preparing to train 3672 ham files
Trained 3672 files
Preparing to test 3675 spam files from corpus2.
Preparing to test 1500 ham files from corpus2.
[corpus2]: processed 5175 documents, 90.318841% accurate
Preparing to test 4500 spam files from corpus3.
Preparing to test 1500 ham files from corpus3.
[corpus3]: processed 6000 documents, 85.533333% accurate