Naive bayes classifier using Python and Kyoto Cabinet
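What follows are two files: the classifier module (imported as `classifier` by the driver script) and `enron.py`, which trains and tests the classifier on a spam/ham corpus (the Enron spam data, judging by the filenames). A document is scored against a category as P(category) multiplied by the product of P(feature | category) over the document's features, with each per-feature estimate smoothed toward 0.5 when the feature has been seen only a few times; `classify()` returns the highest-scoring categories. All counts live on disk in a Kyoto Cabinet `.kct` file tree database.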
import operator
import struct

import kyotocabinet as kc


class ClassifierDB(kc.DB):
    """
    Wrapper for `kyotocabinet.DB` that provides utilities for working with
    features and categories.
    """
    def __init__(self, *args, **kwargs):
        super(ClassifierDB, self).__init__(*args, **kwargs)
        self._category_tmpl = 'category.%s'
        self._feature_to_category_tmpl = 'feature2category.%s.%s'
        self._total_count = 'total-count'

    def get_int(self, key):
        # Kyoto serializes ints as big-endian 8-byte values, so we need to
        # unpack them using the `struct` module.
        value = self.get(key)
        if value:
            return struct.unpack('>Q', value)[0]
        return 0
    def incr_feature_category(self, feature, category):
        """Increment the count for the feature in the given category."""
        return self.increment(
            self._feature_to_category_tmpl % (feature, category),
            1)

    def incr_category(self, category):
        """
        Increment the count for the given category, increasing the total
        count as well.
        """
        self.increment(self._total_count, 1)
        return self.increment(self._category_tmpl % category, 1)

    def category_count(self, category):
        """Return the number of documents in the given category."""
        return self.get_int(self._category_tmpl % category)

    def total_count(self):
        """Return the total number of documents overall."""
        return self.get_int(self._total_count)

    def get_feature_category_count(self, feature, category):
        """Get the count of the feature in the given category."""
        return self.get_int(
            self._feature_to_category_tmpl % (feature, category))

    def get_feature_counts(self, feature):
        """Get the total count for the feature across all categories."""
        prefix = self._feature_to_category_tmpl % (feature, '')
        total = 0
        # match_prefix() may return None rather than an empty list when
        # nothing matches, so fall back to an empty list.
        for key in self.match_prefix(prefix) or []:
            total += self.get_int(key)
        return total

    def iter_categories(self):
        """
        Return an iterable that successively yields all the categories
        that have been observed.
        """
        category_prefix = self._category_tmpl % ''
        prefix_len = len(category_prefix)
        for category_key in self.match_prefix(category_prefix) or []:
            yield category_key[prefix_len:]
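
# For reference, the key layout produced by ClassifierDB looks like this
# (hypothetical values):
#
#   'total-count'                  -> total number of documents trained
#   'category.spam'                -> number of documents labeled spam
#   'feature2category.money.spam'  -> times 'money' was seen in spam docs
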
class NBC(object):
    """
    Simple naive bayes classifier.
    """
    def __init__(self, filename, read_only=False):
        """
        Initialize the classifier by pointing it at a database file. If you
        intend to only use the classifier for classifying documents, specify
        `read_only=True`.
        """
        self.filename = filename
        if not self.filename.endswith('.kct'):
            raise RuntimeError('Database filename must have ".kct" extension.')
        self.db = ClassifierDB()
        self.connect(read_only=read_only)

    def connect(self, read_only=False):
        """
        Open the database. Since Kyoto Cabinet only allows a single writer
        at a time, the `connect()` method accepts a parameter allowing the
        database to be opened in read-only mode (supporting multiple readers).
        If you plan on training the classifier, specify `read_only=False`.
        If you plan only on classifying documents, it is safe to specify
        `read_only=True`.
        """
        if read_only:
            flags = kc.DB.OREADER
        else:
            flags = kc.DB.OWRITER
        self.db.open(self.filename, flags | kc.DB.OCREATE)

    def close(self):
        """Close the database."""
        self.db.close()

    def train(self, features, *categories):
        """
        Increment the counts for the features in the given categories.
        """
        for category in categories:
            for feature in features:
                self.db.incr_feature_category(feature, category)
            self.db.incr_category(category)
    def feature_probability(self, feature, category):
        """
        Calculate the probability that a particular feature is associated
        with the given category.
        """
        fcc = self.db.get_feature_category_count(feature, category)
        if fcc:
            category_count = self.db.category_count(category)
            return float(fcc) / category_count
        return 0

    def weighted_probability(self, feature, category, weight=1.0):
        """
        Determine the probability a feature corresponds to the given category.
        The probability is weighted by the importance of the feature, which
        is determined by looking at the feature across all categories in
        which it appears.
        """
        initial_prob = self.feature_probability(feature, category)
        totals = self.db.get_feature_counts(feature)
        # Smooth the raw estimate toward an assumed probability of 0.5,
        # weighted by how much evidence we have for this feature. E.g. a
        # feature seen only once, and only in this category, would yield
        # (0.5 + 1 * 1.0) / (1 + 1) = 0.75 rather than an overconfident 1.0.
        return ((weight * 0.5) + (totals * initial_prob)) / (weight + totals)
    def document_probability(self, features, category):
        """
        Calculate the probability that a set of features match the given
        category.
        """
        feature_probabilities = [
            self.weighted_probability(feature, category)
            for feature in features]
        return reduce(operator.mul, feature_probabilities, 1)

    def weighted_document_probability(self, features, category):
        """
        Calculate the probability that a set of features match the given
        category, and weight that score by the importance of the category.
        """
        if self.db.total_count() == 0:
            # Avoid division by zero.
            return 0
        cat_prob = (float(self.db.category_count(category)) /
                    self.db.total_count())
        doc_prob = self.document_probability(features, category)
        return doc_prob * cat_prob
    def classify(self, features, limit=5):
        """
        Classify the features by finding the categories that match the
        features with the highest probability.
        """
        probabilities = {}
        for category in self.db.iter_categories():
            probabilities[category] = self.weighted_document_probability(
                features,
                category)
        return sorted(
            probabilities.items(),
            key=operator.itemgetter(1),
            reverse=True)[:limit]
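
Here is a minimal interactive sketch of the API above (hypothetical features and labels; assumes the module is saved as `classifier.py`, which is what the driver script below expects):

from classifier import NBC

nbc = NBC('/tmp/demo.kct')
nbc.train(['cheap', 'pills', 'winner'], 'spam')
nbc.train(['meeting', 'tomorrow', 'agenda'], 'ham')

# classify() returns (category, score) pairs, best match first.
print nbc.classify(['cheap', 'meeting', 'pills'])

nbc.close()
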
import os

# import our classifier, assumed to be in the same directory
from classifier import NBC


def train(corpus='corpus'):
    classifier = NBC(filename='enron.kct')
    curdir = os.path.dirname(__file__)
    # paths to spam and ham documents
    spam_dir = os.path.join(curdir, corpus, 'spam')
    ham_dir = os.path.join(curdir, corpus, 'ham')
    # train the classifier with the spam documents
    train_category(classifier, spam_dir, 'spam')
    # train the classifier with the ham documents
    train_category(classifier, ham_dir, 'ham')
    return classifier


def train_category(classifier, path, label):
    files = os.listdir(path)
    print 'Preparing to train %s %s files' % (len(files), label)
    for filename in files:
        with open(os.path.join(path, filename)) as fh:
            contents = fh.read()
        # extract the words from the document
        features = extract_features(contents)
        # train the classifier to associate the features with the label
        classifier.train(features, label)
    print 'Trained %s files' % len(files)


def extract_features(s, min_len=2, max_len=20):
    """
    Extract all the words in the string `s` whose length falls strictly
    between the specified bounds.
    """
    words = []
    for w in s.lower().split():
        wlen = len(w)
        if wlen > min_len and wlen < max_len:
            words.append(w)
    return words
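
# Sanity check (hypothetical input): extract_features('Buy Viagra now!!!')
# returns ['buy', 'viagra', 'now!!!'] -- tokens are lowercased and
# length-filtered, but punctuation is left attached, which is crude
# but good enough here.
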
def test(classifier, corpus='corpus2'):
    curdir = os.path.dirname(__file__)
    # paths to spam and ham documents
    spam_dir = os.path.join(curdir, corpus, 'spam')
    ham_dir = os.path.join(curdir, corpus, 'ham')
    correct = total = 0
    for path, label in ((spam_dir, 'spam'), (ham_dir, 'ham')):
        filenames = os.listdir(path)
        print 'Preparing to test %s %s files from %s.' % (
            len(filenames),
            label,
            corpus)
        for filename in filenames:
            with open(os.path.join(path, filename)) as fh:
                contents = fh.read()
            # extract the words from the document
            features = extract_features(contents)
            results = classifier.classify(features)
            if results[0][0] == label:
                correct += 1
            total += 1
    pct = 100 * (float(correct) / total)
    print '[%s]: processed %s documents, %02f%% accurate' % (corpus, total, pct)


if __name__ == '__main__':
    classifier = train()
    test(classifier, 'corpus2')
    test(classifier, 'corpus3')
    classifier.close()
    os.unlink('enron.kct')
$ python enron.py
Preparing to train 1500 spam files
Trained 1500 files
Preparing to train 3672 ham files
Trained 3672 files
Preparing to test 3675 spam files from corpus2.
Preparing to test 1500 ham files from corpus2.
[corpus2]: processed 5175 documents, 90.318841% accurate
Preparing to test 4500 spam files from corpus3.
Preparing to test 1500 ham files from corpus3.
[corpus3]: processed 6000 documents, 85.533333% accurate
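
Since Kyoto Cabinet allows only a single writer at a time, a separate process could classify against the trained database by opening it read-only, as the `connect()` docstring describes. A minimal sketch, assuming the trained enron.kct is kept around rather than unlinked at the end of the run:

from classifier import NBC

# Any number of reader processes may open the database concurrently.
nbc = NBC('enron.kct', read_only=True)
print nbc.classify(['free', 'money', 'wire', 'transfer'])
nbc.close()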