-
-
Save jcrubino/3342581 to your computer and use it in GitHub Desktop.
Naive Bayes Classifier
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
from __future__ import with_statement | |
import collections, operator, math, random, pprint | |
class Classifier(object):
    """Naive Bayes classifier over discrete {feature: value} observations.

    Counts are Laplace-smoothed: every (class, feature, value) triple is
    treated as having a pseudo-count of 1 even before it is observed.
    """

    # Attributes persisted by dump() / restored by load(), in this order.
    AttrsToDump = ["value_counts", "class_counts", "features", "feature_counts"]

    def __init__(self, features=None, verbose=False):
        """
        @param features optional initial {feature: set of values} mapping
                        (copied in; the argument is never mutated)
        @param verbose  when True, classify() pretty-prints the per-class
                        probabilities and untrusted_train() reports when it
                        actually trains
        """
        # (klass, feature, value) -> int; default 1 gives Laplace smoothing.
        self.value_counts = collections.defaultdict(lambda: 1)
        # klass -> number of training observations with that label.
        self.class_counts = collections.defaultdict(int)
        # feature -> set of values observed for it.
        self.features = collections.defaultdict(set)
        # NOTE: default was a mutable {} shared across calls; use None sentinel.
        if features:
            self.features.update(features)
        # feature -> smoothing denominator accumulated by train().
        self.feature_counts = collections.defaultdict(int)
        self.verbose = verbose

    def train(self, features, expect):
        """Record one labeled observation.

        @param features {feature: value} observed
        @param expect   true class label for this observation
        """
        for feature, value in features.items():
            self.features[feature].add(value)
            self.value_counts[(expect, feature, value)] += 1
            # Grow the denominator by the number of distinct values seen so
            # far for this feature, keeping smoothed probabilities below 1.
            self.feature_counts[feature] += len(self.features[feature])
        self.class_counts[expect] += 1

    def a_priori(self, klass):
        """Return the prior P(klass) from observed class frequencies.

        Raises ZeroDivisionError if no training data has been seen yet.
        """
        return self.class_counts[klass] / float(sum(self.class_counts.values()))

    def p(self, klass, feature, value):
        """Return the smoothed estimate of P(value | klass) for feature.

        Uses .get() with the defaultdicts' default values so that probing an
        unseen combination does not insert spurious entries (the original
        bracket lookups silently grew the tables on every classification).
        """
        counts = self.value_counts.get((klass, feature, value), 1)
        total = float(self.feature_counts.get(feature, 0))
        return counts / total

    def classify(self, features, normalize=True):
        """
        @param features  {feature: value} to classify
        @param normalize divide by total evidence so probabilities sum to 1
        @returns (best_class, probability); (None, 1) if untrained
        """
        probabilities = {}
        for klass in self.class_counts:
            #
            # P(klass|features) is proportional to P(klass) * Product P(value|klass)
            #
            # Explicit product instead of reduce(): handles an empty features
            # dict (reduce over an empty sequence raised TypeError).
            prob = self.a_priori(klass)
            for feature, value in features.items():
                prob *= self.p(klass, feature, value)
            probabilities[klass] = prob
        if normalize:
            evidence = sum(probabilities.values())
            for k in probabilities:
                probabilities[k] /= evidence
        if len(probabilities) == 0:
            return (None, 1)
        if self.verbose:
            pprint.pprint(probabilities)
        klass = max(probabilities, key=lambda k: probabilities[k])
        return (klass, probabilities[klass])

    def most_informative_features(self, n=100):
        """Return up to n (feature, value) pairs ranked by max_p/min_p ratio.

        A pair whose minimum class-conditional probability is 0 is dropped,
        since the ratio would be meaningless.
        """
        pairs = set()
        # (feature, value) -> extreme probabilities seen across classes.
        max_p = collections.defaultdict(lambda: 0.0)
        min_p = collections.defaultdict(lambda: 1.0)
        for triple in self.value_counts:
            pair = triple[1:]  # (feature, value)
            pairs.add(pair)
            prob = self.p(*triple)
            max_p[pair] = max(prob, max_p[pair])
            min_p[pair] = min(prob, min_p[pair])
            if min_p[pair] == 0:
                pairs.discard(pair)
        ranked = [(pair, max_p[pair] / min_p[pair]) for pair in pairs]
        ranked.sort(key=operator.itemgetter(1), reverse=True)
        return ranked[:n]

    def untrusted_train(self, features, expected):
        """Probabilistically train on (features, expected).

        The more confidently the current model already agrees with the
        label, the less likely this example is to be added — a form of
        error-driven sampling.
        """
        classified, p = self.classify(features)
        if classified != expected:
            # Misclassified: confidence in the wrong answer argues FOR training.
            p = 1 - p
        if random.random() >= p:
            self.train(features, expected)
            if self.verbose:
                # Respect the verbose flag (was an unconditional py2 print).
                print("trained", features)

    def dump(self, filename):
        """Write the four count tables to filename as a pprint literal."""
        to_dump = [dict(getattr(self, a)) for a in Classifier.AttrsToDump]
        with open(filename, "w") as f:
            f.write(pprint.pformat(to_dump))

    def load(self, filename):
        """Restore state written by dump(); returns self for chaining.

        SECURITY: this eval()s the file contents, so it must only ever be
        used on trusted files produced by dump() — never on untrusted input.
        """
        with open(filename) as f:
            to_load = eval(f.read())
        for i, a in enumerate(Classifier.AttrsToDump):
            getattr(self, a).update(to_load[i])
        return self
if __name__ == "__main__":
    names = ('outlook', 'temperature', 'humidity', 'wind')
    # Each row: outlook,temperature,humidity,wind,label.
    data = [row.strip().split(",") for row in """
        Sunny,Hot,High,Weak,Yes
        Sunny,Hot,High,Strong,Yes
        Rain,Hot,High,Weak,No
        Rain,Cool,Normal,Weak,No
        Rain,Cool,Normal,Strong,No
        Sunny,Hot,High,Weak,Yes
        Sunny,Cool,Normal,Weak,Yes
        Rain,Mild,Normal,Weak,No
        Sunny,Cool,Normal,Strong,Yes
        Rain,Hot,High,Strong,No
        """.strip().splitlines()]

    def split_row(values):
        """Return (expected_class, {feature: value}) for one data row."""
        return values[-1], dict(zip(names, values[:-1]))

    def evaluate(classifier, rows):
        """Print each prediction and return the accuracy over rows."""
        errors = 0
        for values in rows:
            expect, features = split_row(values)
            klass, p = classifier.classify(features)
            print(values[:-1], (klass, p), "correct? ", expect == klass)
            if expect != klass:
                errors += 1
        return 1 - (errors / float(len(rows)))

    c = Classifier()
    for values in data:
        expect, features = split_row(values)
        c.untrusted_train(features, expect)
    # Duplicate evaluation loops folded into evaluate(); output is equivalent.
    print("rate = ", evaluate(c, data))
    pprint.pprint(c.most_informative_features())
    # Round-trip the model through dump()/load() and re-evaluate.
    c.dump('weather.py')
    c = Classifier().load('weather.py')
    print("rate = ", evaluate(c, data))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment