Skip to content

Instantly share code, notes, and snippets.

@xiaohan2012
Created January 23, 2014 21:43
Show Gist options
  • Save xiaohan2012/8587397 to your computer and use it in GitHub Desktop.
Save xiaohan2012/8587397 to your computer and use it in GitHub Desktop.
A Naive Bayes spam classifier
from __future__ import division
from itertools import groupby
from collections import Counter
# Toy training corpus: a list of (class label, list of word tokens) pairs.
# NOTE(review): 'rsults' in the last sample looks like a misspelling of
# 'results'; it is kept verbatim because changing it alters the vocabulary
# and therefore every estimated probability.
texts = [
    ('spam', ['FREE', 'online', '!!!']),
    ('safe', ['results', 'repository', 'online']),
    ('spam', ['FREE', 'online', 'results', 'FREE', '!!!']),
    ('spam', ['!!!', 'registration', 'FREE', '!!!']),
    ('safe', ['conference', 'online', 'registration', 'conference']),
    ('safe', ['conference', 'results', 'repository', 'rsults']),
]
# Class prior table: pt[cls] = P(cls), the fraction of training texts that
# carry label cls. Relies on true division (the file's
# `from __future__ import division` on Python 2).
# A generator replaces the Python-2-only tuple-parameter lambda
# `lambda (cls, t): cls`, which is a SyntaxError on Python 3.
clsFreq = Counter(cls for cls, _ in texts)
# Hoisted: the total was previously recomputed for every class.
numTexts = sum(clsFreq.values())
pt = {cls: count / numTexts for cls, count in clsFreq.items()}
# Snapshot of the class labels, in the same order as pt's keys.
classes = list(pt)
# Vocabulary: every distinct word appearing in any training text, sorted
# alphabetically. A set comprehension replaces sorted(list(set([...]))) and,
# unlike a Python 2 list comprehension, does not leak its loop variables into
# module scope.
dictionary = sorted({w for _, words in texts for w in words})
# Conditional probability table (CPT): cpd[cls][w] = P(w | cls), estimated
# with add-one (Laplace) smoothing over the whole dictionary:
#     (count of w in cls texts + 1) / (total words in cls texts + |dictionary|)
cpd = {}
# groupby only merges *adjacent* items, so the texts are sorted by label first.
textsGroupedByCls = groupby(sorted(texts, key=lambda tpl: tpl[0]),
                            lambda tpl: tpl[0])
for cls, listOfTexts in textsGroupedByCls:
    # Pool and count the words of every text in this class. `_` is used so
    # the comprehension does not rebind the outer loop variable `cls`.
    wordFreq = Counter(w for _, ts in listOfTexts for w in ts)
    print('%s %s' % (cls, wordFreq))
    totalCount = sum(wordFreq.values())
    # Same denominator for every word of this class; computed once.
    denominator = totalCount + len(dictionary)
    # Counter returns 0 for unseen words, so every dictionary word gets a
    # strictly positive smoothed probability.
    cpd[cls] = {w: (wordFreq[w] + 1) / denominator for w in dictionary}
print(cpd)
# Tabular display of the CPT. For each class, print three lines: the class
# label, the words in alphabetical order, and their probabilities to four
# decimals joined with ' & ' (LaTeX table-row style). A blank line follows.
# Single-argument print(...) behaves the same on Python 2 and 3.
for cls, table in cpd.items():
    print(cls)
    words = sorted(table.keys())
    print(' '.join(words))
    print(' & '.join('%.4f' % table[w] for w in words))
    print('')
# Calculate the Naive Bayes posterior probability of each text.
def posterior(texts, cpd, pt):
    """Return the Naive Bayes posterior P(cls | text) for every text.

    Args:
        texts: iterable of token lists.
        cpd: mapping cls -> {word: P(word | cls)}.
        pt: mapping cls -> prior P(cls). Its keys are the classes scored;
            no module-level class list is consulted.

    Returns:
        A list with one dict per text, mapping each class in ``pt`` to its
        posterior probability. The values of each dict sum to 1. An empty
        ``pt`` yields an empty dict per text.

    Notes:
        Words missing from any class's table are ignored (out-of-vocabulary
        words carry no evidence) rather than raising KeyError.

        Scores are accumulated as log probabilities and normalised with the
        log-sum-exp trick. This stops long texts from underflowing to 0,
        which would make the normalising total 0.

        All probabilities must be strictly positive, because math.log raises
        ValueError on 0. The add-one smoothed CPT and the priors of observed
        classes built in this script satisfy this.
    """
    import math

    labels = list(pt)
    if not labels:
        return [{} for _ in texts]
    # Words that every class can score; anything else is skipped.
    vocab = set(cpd[labels[0]])
    for cls in labels[1:]:
        vocab &= set(cpd[cls])

    result = []
    for words in texts:
        known = [w for w in words if w in vocab]
        logScores = {}
        for cls in labels:
            table = cpd[cls]
            logScores[cls] = math.log(pt[cls]) + sum(math.log(table[w]) for w in known)
        # Shift by the maximum before exponentiating. The best class maps to
        # exp(0) = 1, so the total is >= 1 and never underflows to 0.
        peak = max(logScores.values())
        weights = {cls: math.exp(score - peak) for cls, score in logScores.items()}
        total = sum(weights.values())
        result.append({cls: weights[cls] / total for cls in labels})
    return result
# Posterior class distribution of every training sample, printed as LaTeX
# table rows (values joined by ' & ', each row ending in ' \\').
# A list comprehension replaces the Python-2-only tuple-parameter lambda.
pos = posterior([words for _, words in texts], cpd, pt)
# Fix the column order once, so the header and every row agree. Printing
# t.values() directly follows dict order, which need not match a hard-coded
# 'safe \t spam' header. For this corpus the header is still 'safe \t spam'.
columns = sorted(pt)
print(' \t '.join(columns))
for t in pos:
    print(' & '.join('%.3f' % t[cls] for cls in columns) + r' \\')
# Classify two texts that are not in the training set and print their
# posterior class distributions (a list of {class: probability} dicts).
# Single-argument print(...) behaves identically on Python 2 and 3.
testText = [
    ['FREE', 'online', 'conference', '!!!'],
    ['conference', 'registration', 'results', 'conference', 'online'],
]
print(posterior(testText, cpd, pt))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment