@ty60
Created December 24, 2017 05:40
Homebrew ("ore-ore") implementation of naive Bayes
# ---- main script (original gist filename not preserved) ----
import random
import json
import numpy as np

from vector import translate_vector
from naive_bayes import NaiveBayes
def split_dataset(vectors, answer_classes, num_bucket=10):
    """Shuffle the dataset and yield num_bucket cross-validation splits,
    each using one bucket for validation and the rest for training.
    """
    vec_count = len(vectors)
    ids = range(vec_count)
    random.shuffle(ids)
    hop_len = vec_count / num_bucket
    boundary_id = 0
    # Slide the validation window one bucket at a time so that every
    # bucket is used for validation exactly once.
    for _ in xrange(num_bucket):
        train_ids = ids[:boundary_id] + ids[boundary_id + hop_len:]
        validate_ids = ids[boundary_id:boundary_id + hop_len]
        boundary_id += hop_len

        train_vectors = list()
        train_ans_classes = list()
        for vec_id in train_ids:
            train_vectors.append(vectors[vec_id])
            train_ans_classes.append(answer_classes[vec_id])

        val_vectors = list()
        val_ans_classes = list()
        for vec_id in validate_ids:
            val_vectors.append(vectors[vec_id])
            val_ans_classes.append(answer_classes[vec_id])

        train_pairs = (train_vectors, train_ans_classes)
        val_pairs = (val_vectors, val_ans_classes)
        yield train_pairs, val_pairs
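
# A minimal sketch (toy data, not from the gist) of how split_dataset is
# meant to be driven; each round yields one bucket as the validation fold
# and the remaining buckets as training data:
#
#   vecs = [[i] for i in range(10)]
#   clss = ["a"] * 5 + ["b"] * 5
#   for (tr_v, tr_c), (va_v, va_c) in split_dataset(vecs, clss, num_bucket=5):
#       assert len(va_v) == 2 and len(tr_v) == 8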
def test_nb(dataset_vectors, answer_classes):
    val_round = 0
    correct_rates = list()
    for train_pairs, val_pairs in split_dataset(dataset_vectors, answer_classes):
        print "[+] Round {}".format(val_round)
        print "[+] Training..."
        nb = NaiveBayes()
        nb.train(*train_pairs)
        #print json.dumps(nb.features, indent=4)

        print "[+] Validating..."
        correct_count = 0
        val_vectors, val_classes = val_pairs
        for vec, ans_cls in zip(val_vectors, val_classes):
            #print "{}: {};".format(vec, ans_cls),
            predicted_class = nb.predict(vec)
            is_correct = predicted_class == ans_cls
            if is_correct:
                correct_count += 1
        correct_rate = float(correct_count) / len(val_vectors)
        correct_rates.append(correct_rate)
        print "[+] Correct rate = {}".format(correct_rate)
        print ""
        val_round += 1
    print "[+] Average correct rate = {}".format(np.average(correct_rates))
if __name__ == "__main__":
    with open("../data/iris.data") as f:
        lines = f.readlines()

    vectors = list()
    answers = list()
    round_to_int = lambda x: int(round(x))
    for line in lines:
        line = line.rstrip("\n")
        if not line:
            continue
        data = line.split(",")
        # Discretize the four real-valued measurements to ints so they
        # can be counted as categorical feature values.
        sizes = map(float, data[:4])
        sizes = map(round_to_int, sizes)
        iris_type = data[4]
        vectors.append(sizes)
        answers.append(iris_type)
    test_nb(vectors, answers)
# ---- naive_bayes.py ----
# References:
# http://aidiary.hatenablog.com/entry/20100613/1276389337
# http://sonickun.hatenablog.com/entry/2014/06/13/224501
import math
from collections import defaultdict
from functools import partial

from vector import translate_vector


class NaiveBayes(object):
    LAPLACE_EPS = 1

    @classmethod
    def gen_counter(cls, depth):
        """gen_counter
        Generate and return a multi-dimensional defaultdict whose
        innermost values are ints.
        """
        if depth <= 1:
            return defaultdict(int)
        else:
            t = partial(NaiveBayes.gen_counter, depth - 1)
            return defaultdict(t)
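
    # A quick illustration (not part of the gist) of what gen_counter
    # builds: every key path springs into existence on first access and
    # ends in an int, so nested counts need no initialization:
    #
    #   c = NaiveBayes.gen_counter(3)
    #   c["Iris-setosa"][2][10] += 1   # features[class][elem_id][elem_value]
    #   c["Iris-setosa"][2][10]        # -> 1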
    def __init__(self):
        self.classes = list()
        self.classes_count = NaiveBayes.gen_counter(1)
        self.features = NaiveBayes.gen_counter(3)
        self.elements = set()

    def train(self, train_vectors, classes):
        """train
        Count the number of occurrences of each element value,
        per vector element, per class.
        """
        translated_vecs = map(translate_vector, train_vectors)
        for vector, class_ in zip(translated_vecs, classes):
            self.classes_count[class_] += 1
            # Count the number of observed specific values
            # for each vector element.
            # e.g. the count for observed value 10 in the 2nd element of a
            # vector classified as 'classname' is stored at:
            #   features['classname'][2][10]
            feature_counters = self.features[class_]
            for elem_id, elem_value in vector.elements():
                feature_counters[elem_id][elem_value] += 1
                self.elements.add(elem_id)
        self.classes = self.classes_count.keys()
    def _likelihood(self, elem_id, elem_value, class_):
        """_likelihood
        Likelihood of vector element elem_id taking elem_value in class
        class_, with Laplace (add-one) smoothing so that unseen values
        never get zero probability.
        """
        feature_counters = self.features[class_]
        numerator = feature_counters[elem_id][elem_value] + NaiveBayes.LAPLACE_EPS
        denominator = self.classes_count[class_] + len(self.elements)
        return float(numerator) / float(denominator)
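
    # Worked example with made-up counts, just to show the smoothing:
    # if value 3 was seen 4 times for element 0 in a class that has 15
    # training vectors, and 4 distinct element ids were recorded, then
    #   _likelihood(0, 3, class_) = (4 + 1) / (15 + 4) = 5/19 ~= 0.26
    # and an unseen value still gets (0 + 1) / 19 instead of zero.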
    def _score(self, vector, class_):
        """_score
        Calculate a score that is used in place of the posterior
        probability. Logs are taken to prevent the product of likelihoods
        from underflowing, so the log-probabilities of the (assumed
        independent) elements are ADDED instead of multiplied.
        """
        total = sum(self.classes_count.values())
        score = math.log(self.classes_count[class_] / float(total))
        translated_vec = translate_vector(vector)
        for elem_id, elem_value in translated_vec.elements():
            score += math.log(self._likelihood(elem_id, elem_value, class_))
        return score
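
    # In equation form, the quantity computed above is
    #   score(c) = log P(c) + sum_i log P(x_i = v_i | c)
    # i.e. the log of the naive Bayes posterior numerator
    #   P(c) * prod_i P(x_i = v_i | c).
    # log is monotonic, so ranking classes by score ranks them exactly
    # as the underlying probabilities would.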
    def predict(self, vector):
        """predict
        Predict the class the vector most likely belongs to,
        i.e. the class with the highest score.
        """
        tag_score_to_class = lambda cls: (cls, self._score(vector, cls))
        scores = map(tag_score_to_class, self.classes)
        get_tagged_score = lambda t: t[1]
        predicted_class = max(scores, key=get_tagged_score)[0]
        return predicted_class
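
# A minimal self-test sketch (toy vectors, not part of the gist): train on
# two clearly separated classes and check that predict() recovers them.
if __name__ == "__main__":
    nb = NaiveBayes()
    nb.train([[0, 0], [0, 1], [5, 5], [5, 6]],
             ["low", "low", "high", "high"])
    print nb.predict([0, 1])   # expected: low
    print nb.predict([5, 5])   # expected: high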
# ---- vector.py ----
import abc


class Vector(object):
    __metaclass__ = abc.ABCMeta

    def __init__(self, vec):
        self._vec = vec

    @abc.abstractmethod
    def get_val(self, elem_id):
        """Return the value of element elem_id."""

    @abc.abstractmethod
    def elements(self):
        """Yield (elem_id, elem_value) pairs."""
class DenseVector(Vector):
    """DenseVector
    Vectors should be represented as a list.
    e.g. [1, 0, 0, 2]
    """
    def __init__(self, vec):
        super(DenseVector, self).__init__(vec)

    def get_val(self, elem_id):
        return self._vec[elem_id]

    def elements(self):
        for elem_id, elem_val in enumerate(self._vec):
            yield elem_id, elem_val
class SparseVector(Vector):
    """SparseVector
    Vectors should be represented as a dict mapping element ids to values.
    e.g. [1, 0, 0, 2] -> {0: 1, 3: 2}
    """
    def __init__(self, vec):
        super(SparseVector, self).__init__(vec)

    def get_val(self, elem_id):
        return self._vec[elem_id]

    def elements(self):
        # key: elem_id, val: elem_val
        for item in self._vec.iteritems():
            yield item
def translate_vector(vector):
    if isinstance(vector, list):
        return DenseVector(vector)
    elif isinstance(vector, dict):
        return SparseVector(vector)
    else:
        # Fail loudly on unsupported input instead of silently returning None.
        raise TypeError("Unsupported vector type: {}".format(type(vector)))
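
# A small demonstration (not part of the gist) that both representations
# expose the same (elem_id, elem_value) view through elements():
if __name__ == "__main__":
    dense = translate_vector([1, 0, 0, 2])
    sparse = translate_vector({0: 1, 3: 2})
    print list(dense.elements())    # [(0, 1), (1, 0), (2, 0), (3, 2)]
    print list(sparse.elements())   # [(0, 1), (3, 2)] (dict order)
    print dense.get_val(3), sparse.get_val(3)   # 2 2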