Skip to content

Instantly share code, notes, and snippets.

@nyanshell
Last active December 17, 2017 15:42
Show Gist options
  • Save nyanshell/9802314 to your computer and use it in GitHub Desktop.
A Naive Decision Tree Practice, only for discrete values (ID3 algorithm)
# -*- coding: utf-8 -*-
"""
A Naive Decision Tree Practice, Only for discrete values
ID3 algorithm
"""
import math
import copy
from collections import defaultdict
class dec_tree():
    """A naive ID3 decision tree classifier for discrete-valued features.

    The tree is stored as a flat list of node dicts in ``self.tree``:
      * root (index 0):   {"children": [...], "attr_index": -1}
      * inner/value node: {"parent": i, "attr_index": a, "attr_value": v,
                           "children": [...]}
      * leaf node:        {"parent": i, "target": label}
        (empty-branch leaves also carry "attr_index"/"attr_value")
    """

    def _attr_in_value_v(self, v, samples, attr):
        """Return the subset of sample indices whose feature `attr` equals `v`."""
        return [s for s in samples if self.x[s][attr] == v]

    def _entropy(self, s):
        """Shannon entropy (base 2) of the target labels for sample indices `s`.

        An empty index list yields 0.0 (the sum over no classes).
        """
        counts = defaultdict(int)
        for i in s:
            counts[self.y[i]] += 1
        total = len(s)
        return -sum((c / total) * math.log2(c / total)
                    for c in counts.values())

    def _gain(self, samples, attr):
        """Information gain from splitting `samples` on feature index `attr`.

        Bug fix: ID3 partitions by the *attribute's* observed value set;
        the original iterated over the target label set instead.
        """
        remainder = 0.0
        for v in self.attr_values[attr]:
            subset = self._attr_in_value_v(v, samples, attr)
            if subset:
                remainder += (self._entropy(subset)
                              * len(subset) / len(samples))
        return self._entropy(samples) - remainder

    def fit(self, x, y):
        """Build the tree from rows of discrete feature values `x` and the
        matching target labels `y` (one label per row)."""
        assert len(x) == len(y)
        self.x = x
        self.y = y
        # Collect the observed value set of every feature and of the target.
        self.attr_values = [set() for _ in range(len(x[0]))]
        self.target_values = set(y)
        for row in x:
            for idx, val in enumerate(row):
                self.attr_values[idx].add(val)
        # Node 0 is a synthetic root; real nodes are appended while building.
        self.tree = [{"children": [], "attr_index": -1}]
        self.build_tree(list(range(len(self.x))), {}, 0)

    def build_tree(self, samples_list, attr_dict, parent):
        """Recursively grow the tree.

        samples_list: indices of the training rows reaching this node.
        attr_dict: {attr_index: attr_value} of features already used
                   on the path from the root.
        parent: index of the parent node in self.tree.
        """
        example_cnt = defaultdict(int)
        for s in samples_list:
            example_cnt[self.y[s]] += 1
        if len(example_cnt) == 1:
            # Pure node: every remaining sample shares one label -> leaf.
            # Bug fix: dict.keys()[0] is Python-2-only; use next(iter(...)).
            self.tree.append({"parent": parent,
                              "target": next(iter(example_cnt))})
            self.tree[parent]["children"].append(len(self.tree) - 1)
        elif len(attr_dict) == len(self.x[0]):
            # All features used up: fall back to the majority label.
            # Bug fix: max(dict) returned the largest *label*, not the most
            # common one; compare by count via key=example_cnt.get.
            self.tree.append({"parent": parent,
                              "target": max(example_cnt, key=example_cnt.get)})
            self.tree[parent]["children"].append(len(self.tree) - 1)
        elif len(example_cnt) > 1:
            attr_available = set(range(len(self.x[0]))) - set(attr_dict.keys())
            # Bug fix: max over a dict compared attribute indices, not gains;
            # pick the attribute with the highest information gain.
            best_attr = max(attr_available,
                            key=lambda i: self._gain(samples_list, i))
            for v in self.attr_values[best_attr]:
                temp_attr_dict = copy.copy(attr_dict)
                temp_attr_dict[best_attr] = v
                new_samples_list = [i for i in samples_list
                                    if self.x[i][best_attr] == v]
                if new_samples_list:
                    self.tree.append({"parent": parent,
                                      "attr_index": best_attr,
                                      "attr_value": v,
                                      "children": []})
                    self.tree[parent]["children"].append(len(self.tree) - 1)
                    self.build_tree(new_samples_list, temp_attr_dict,
                                    len(self.tree) - 1)
                else:
                    # No training sample takes this value here: leaf carrying
                    # the majority label of the current node's samples.
                    self.tree.append({"parent": parent,
                                      "attr_index": best_attr,
                                      "attr_value": v,
                                      "target": max(example_cnt,
                                                    key=example_cnt.get)})
                    self.tree[parent]["children"].append(len(self.tree) - 1)
        else:
            raise ValueError("cannot build a tree node from an empty sample set")

    def predict(self, x):
        """Classify a single feature row `x`; return the predicted label.

        Raises ValueError when `x` holds a feature value never seen during
        training (the original looped forever in that case).
        """
        p = 0
        while True:
            node = self.tree[p]
            if "target" in node:
                return node["target"]
            if len(node["children"]) == 1:
                p = node["children"][0]
                continue
            for c in node["children"]:
                child = self.tree[c]
                if x[child["attr_index"]] == child["attr_value"]:
                    if "target" in child:
                        return child["target"]
                    p = c
                    break
            else:
                raise ValueError("no branch matches the given feature values")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment