-
-
Save kennyballou/e26ddeb469509f059b70 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
'''Sample Naive Bayes Classifier | |
''' | |
import collections | |
import math | |
import sys | |
__author__ = 'Krishnamurthy Koduvayur Viswanathan' | |
__credits__ = ['Kenny Ballou',] | |
class Model(object):
    '''Categorical Naive Bayes classifier trained from a nominal ARFF file.

    Typical use is ``get_values()``, then ``train_classifier()``, then
    ``classify()`` or ``test_classifier()``.  The last attribute declared in
    the ARFF file is treated as the class label.
    '''

    def __init__(self, arff_file):
        '''Create an untrained model.

        :param arff_file: path of the ARFF file holding the training data
        '''
        self.training_file = arff_file
        # feature name -> list of the nominal values it may take
        self.features = {}
        # feature names in declaration order; the last one is the label
        self.feature_name_list = []
        # keys are tuples of the form (label, feature_name, feature_value);
        # a missing key reads as 1, which gives add-one smoothing
        self.feature_counts = collections.defaultdict(lambda: 1)
        # training rows as lists of strings, with the label as last entry
        self.feature_vectors = []
        # label -> count; train_classifier() adds the smoothing mass later
        self.label_counts = collections.defaultdict(lambda: 0)

    def get_values(self):
        '''Parse the training file into feature definitions and data rows.

        Every line is lower-cased.  Blank lines and ``%`` comment lines are
        skipped, and surrounding whitespace is removed from each nominal
        value and each data cell.

        :raises ValueError: if an ``@attribute`` line has no name or no
            ``{...}`` list of nominal values
        '''
        with open(self.training_file, 'r') as training_file:
            for raw_line in training_file:
                line = raw_line.strip().lower()
                # indexing line[0] on an empty line used to raise IndexError
                if not line or line.startswith('%'):
                    continue
                if not line.startswith('@'):
                    # start of actual data
                    self.feature_vectors.append(
                        [cell.strip() for cell in line.split(',')])
                elif line.startswith('@attribute'):
                    # feature definitions
                    opening = line.find('{')
                    closing = line.find('}', opening + 1)
                    header = line[:opening].split() if opening != -1 else []
                    if len(header) < 2 or closing == -1:
                        raise ValueError(
                            'unsupported attribute declaration: %r' % line)
                    name = header[1]
                    self.feature_name_list.append(name)
                    self.features[name] = [
                        value.strip()
                        for value in line[opening + 1:closing].split(',')]
                # any other '@' line (@relation, @data, ...) carries no data

    def train_classifier(self):
        '''Count labels and (label, feature, value) triples from the rows.

        Afterwards every label count is increased by the total number of
        values of all non-label features (the smoothing mass).
        '''
        feature_names = self.feature_name_list[:-1]
        for feature_vector in self.feature_vectors:
            label = feature_vector[-1]
            # update count for label
            self.label_counts[label] += 1
            for feature_name, feature_value in zip(feature_names,
                                                   feature_vector[:-1]):
                self.feature_counts[
                    (label, feature_name, feature_value)] += 1
        # increase label counts (smoothing).  Recall, last name is the label.
        # NOTE(review): this one denominator is shared by all features
        # instead of being count(label) + len(values) per feature; the
        # original scheme is kept unchanged here.
        smoothing = sum(len(self.features[name]) for name in feature_names)
        for label in self.label_counts:
            self.label_counts[label] += smoothing

    def classify(self, feature_vector):
        '''Classify features given by feature_vector

        Values are matched to feature names by position.  A trailing label
        (as found in a training row) is ignored.  The per-label
        probabilities are printed before the best label is returned.

        :param feature_vector: simple list similar to ones given for training
        :returns: the label with the highest score
        :raises ValueError: if the model has not been trained
        '''
        if not self.label_counts:
            raise ValueError('model has not been trained')
        feature_names = self.feature_name_list[:-1]
        # float() keeps the divisions exact under a Python 2 interpreter too
        total = float(sum(self.label_counts.values()))
        prob_per_label = {}
        log_score_per_label = {}
        for label, label_count in self.label_counts.items():
            log_prob = 0.0
            # zip() pairs by position and stops before any trailing label
            for feature_name, feature_value in zip(feature_names,
                                                   feature_vector):
                # .get() instead of [] so unseen triples are not inserted
                # into the defaultdict while predicting
                count = self.feature_counts.get(
                    (label, feature_name, feature_value), 1)
                log_prob += math.log(count / float(label_count))
            prior = label_count / total
            prob_per_label[label] = prior * math.exp(log_prob)
            log_score_per_label[label] = math.log(prior) + log_prob
        print(prob_per_label)
        # compare in log space: exp() may underflow to 0.0 for long vectors
        return max(log_score_per_label,
                   key=lambda c: log_score_per_label[c])

    def test_classifier(self, arff_file):
        '''Test our model

        Classifies every data row of arff_file and prints the prediction
        next to the label given in the row.

        :param arff_file: path of the ARFF file to evaluate on
        '''
        with open(arff_file, 'r') as arff:
            for raw_line in arff:
                line = raw_line.strip().lower()
                # skip blank lines, comments and header declarations
                if not line or line[0] in '%@':
                    continue
                vector = [cell.strip() for cell in line.split(',')]
                print("classifier: %s given %s" % (
                    self.classify(vector), vector[-1]))
def main(arff_file):
    '''Build a model from arff_file, train it, and evaluate it on that file.

    :param arff_file: path of the ARFF file used for training and testing
    '''
    classifier = Model(arff_file)
    # parsing must precede training, so the order of the steps matters
    for step in (classifier.get_values, classifier.train_classifier):
        step()
    classifier.test_classifier(arff_file)
if __name__ == '__main__':
    # Validate explicitly: an assert is stripped when Python runs with -O,
    # which would turn a missing argument into a bare IndexError.
    if len(sys.argv) < 2:
        sys.exit('usage: %s <arff_file>' % sys.argv[0])
    main(sys.argv[1])
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment