Created
February 9, 2013 22:04
-
-
Save mmourafiq/4747307 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class DecisionTree(object):
    """
    A binary decision tree assembled from static helper methods.

    Training items must expose ``scaled_coords`` (an indexable sequence of
    feature values) and ``value`` (the outcome of the item).

    Tree nodes are ``Node`` objects, constructed here with the keyword
    arguments ``col``, ``value``, ``t_node``, ``f_node`` and ``results``.
    NOTE(review): ``Node``, ``defaultdict``, ``Counter`` and ``log2`` are not
    defined in this class; they are expected to be provided by the enclosing
    module.
    """

    @staticmethod
    def _get_disorder_estimator(disorder_function):
        """
        Map a disorder function name to the matching static method.

        disorder_function : "entropy", "gini_impurity" or "variance"
        Raises ValueError for any other name.
        """
        estimators = {
            "entropy": DecisionTree.entropy,
            "gini_impurity": DecisionTree.gini_impurity,
            "variance": DecisionTree.variance,
        }
        try:
            return estimators[disorder_function]
        except KeyError:
            raise ValueError(
                "unknown disorder_function: %r" % (disorder_function,))

    @staticmethod
    def count_results(data, item=True):
        """
        Count the occurrences of each result in the data set.

        data : iterable of items (item=True) or of raw outcome values
               (item=False)
        item : when True the outcome is read from ``element.value``,
               otherwise every element is itself the outcome
        Returns a mapping {outcome: count} (a defaultdict when item is True,
        a Counter otherwise).
        """
        if not item:
            return Counter(data)
        results_count = defaultdict(int)
        for element in data:
            results_count[element.value] += 1
        return results_count

    @staticmethod
    def divide_data(data, column, value):
        """
        Divide a set of items on a specific column.

        Numeric values (int or float) split on ``>= value``; any other value
        splits on ``== value``.
        Returns the tuple (set_true, set_false), both lists keeping the
        original order of the items.
        """
        # decides whether an item goes to the first (True) or the second
        # (False) group
        if isinstance(value, (int, float)):
            spliter = lambda item: item.scaled_coords[column] >= value
        else:
            spliter = lambda item: item.scaled_coords[column] == value
        set_true = []
        set_false = []
        for item in data:
            if spliter(item):
                set_true.append(item)
            else:
                set_false.append(item)
        return (set_true, set_false)

    @staticmethod
    def gini_impurity(data, item=True):
        """
        Probability that a randomly placed item will be in the wrong category.

        Computed as the sum of p(a) * p(b) over every ordered pair of distinct
        outcomes a != b. Returns 0.0 for an empty data set.
        """
        results_count = DecisionTree.count_results(data, item)
        len_data = len(data)
        imp = 0.0
        for k1, v1 in results_count.items():
            p1 = float(v1) / len_data
            for k2, v2 in results_count.items():
                if k1 == k2:
                    continue
                p2 = float(v2) / len_data
                imp += p1 * p2
        return imp

    @staticmethod
    def entropy(data, item=True):
        """
        Estimate the disorder in the data set : -sum of p(x) * log2(p(x)).

        Returns 0.0 for an empty data set.
        """
        results_count = DecisionTree.count_results(data, item)
        len_data = len(data)
        ent = 0.0
        for count in results_count.values():
            p = float(count) / len_data
            ent -= p * log2(p)
        return ent

    @staticmethod
    def variance(data, item=True):
        """
        Calculate the statistical (population) variance of the outcomes.

        More preferably to be used with numerical outcomes.
        item : when True the outcome is read from ``element.value``,
               otherwise every element is itself the outcome (this is what
               prune needs, since it works on raw outcome values)
        Returns 0 for an empty data set.
        """
        if len(data) == 0:
            return 0
        if item:
            score = [float(element.value) for element in data]
        else:
            score = [float(element) for element in data]
        mean = sum(score) / len(score)
        return sum([(s - mean) ** 2 for s in score]) / len(score)

    @staticmethod
    def build_tree(data, disorder_function="entropy"):
        """
        A recursive function that builds the tree by choosing the best
        dividing criteria.

        disorder_function :
            for data that contains words and booleans; it is recommended to
            use "entropy" or "gini_impurity"
            for data that contains numbers; it is recommended to use
            "variance"
        Returns the root Node; an empty Node() when data is empty.
        Raises ValueError when disorder_function is not a known name.
        """
        len_data = len(data)
        if len_data == 0:
            return Node()
        disorder_estimator = DecisionTree._get_disorder_estimator(
            disorder_function)
        current_disorder_level = disorder_estimator(data)
        # track the best enhancement of the disorder's level found so far
        best_enhancement = 0.0
        best_split = None
        best_split_sets = None
        # number of columns
        nbr_coords = len(data[0].scaled_coords)
        for coord_ind in range(nbr_coords):
            # unique values of the current column; a dict is used so the keys
            # stay in first-seen order where dicts preserve insertion order
            coord_values = {}
            for item in data:
                coord_values[item.scaled_coords[coord_ind]] = 1
            for coord_value in coord_values:
                set1, set2 = DecisionTree.divide_data(
                    data, coord_ind, coord_value)
                # a split that leaves one side empty divides nothing
                if not set1 or not set2:
                    continue
                p1 = float(len(set1)) / len_data
                p2 = 1 - p1
                enhancement = (current_disorder_level
                               - p1 * disorder_estimator(set1)
                               - p2 * disorder_estimator(set2))
                if enhancement > best_enhancement:
                    best_enhancement = enhancement
                    best_split = (coord_ind, coord_value)
                    best_split_sets = (set1, set2)
        if best_enhancement > 0:
            # the children must be grown with the same disorder function as
            # their parent
            t_node = DecisionTree.build_tree(
                best_split_sets[0], disorder_function)
            f_node = DecisionTree.build_tree(
                best_split_sets[1], disorder_function)
            return Node(col=best_split[0], value=best_split[1],
                        t_node=t_node, f_node=f_node)
        return Node(results=DecisionTree.count_results(data))

    @staticmethod
    def prune(tree, min_enhancement, disorder_function="entropy"):
        """
        Check pairs of leaves that have a common parent to see if merging
        them would increase the disorder by less than a specified threshold.

        The tree is modified in place: a merged parent gets its ``t_node`` and
        ``f_node`` set to None and its ``results`` set to the combined counts.
        Leaves and nodes without both children are left untouched.
        Raises ValueError when disorder_function is not a known name.
        """
        disorder_estimator = DecisionTree._get_disorder_estimator(
            disorder_function)
        # nothing to prune on a leaf or on a node that has no children
        if tree.results is not None:
            return
        if tree.t_node is None or tree.f_node is None:
            return
        if tree.t_node.results is None:
            DecisionTree.prune(tree.t_node, min_enhancement, disorder_function)
        if tree.f_node.results is None:
            DecisionTree.prune(tree.f_node, min_enhancement, disorder_function)
        # if both the subbranches are now leaves, see if they should be merged
        if tree.t_node.results is None or tree.f_node.results is None:
            return
        # rebuild flat lists of outcomes from the leaf counts; the outcomes
        # are stored as-is (not wrapped in lists) so they can be counted
        t_values, f_values = [], []
        for key, count in tree.t_node.results.items():
            t_values += [key] * count
        for key, count in tree.f_node.results.items():
            f_values += [key] * count
        merged = t_values + f_values
        # disorder of the merged leaf minus the mean disorder of both leaves
        delta = disorder_estimator(merged, item=False) - (
            disorder_estimator(t_values, item=False)
            + disorder_estimator(f_values, item=False)) / 2
        if delta < min_enhancement:
            # merge the branches
            tree.t_node, tree.f_node = None, None
            tree.results = DecisionTree.count_results(merged, item=False)

    @staticmethod
    def classify(observation, tree):
        """
        Classify a new observation given a decision tree.

        Returns the ``results`` mapping of the leaf that is reached. When the
        observation has no value (None) for the column tested by a node, both
        branches are followed and their results are combined, each one
        weighted by its share of the total count.
        """
        if tree.results is not None:
            return tree.results
        # the observation value for the current criteria column
        observation_value = observation.scaled_coords[tree.col]
        if observation_value is None:
            t_results = DecisionTree.classify(observation, tree.t_node)
            f_results = DecisionTree.classify(observation, tree.f_node)
            t_count = sum(t_results.values())
            f_count = sum(f_results.values())
            t_prob = float(t_count) / (t_count + f_count)
            f_prob = float(f_count) / (t_count + f_count)
            result = {}
            for key, value in t_results.items():
                result[key] = value * t_prob
            # add to (rather than overwrite) classes present in both branches
            for key, value in f_results.items():
                result[key] = result.get(key, 0.0) + value * f_prob
            return result
        # pick the branch to follow
        if isinstance(observation_value, (int, float)):
            follow_true = observation_value >= tree.value
        else:
            follow_true = observation_value == tree.value
        branch = tree.t_node if follow_true else tree.f_node
        return DecisionTree.classify(observation, branch)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment