Skip to content

Instantly share code, notes, and snippets.

@Alescontrela
Created January 1, 2019 02:20
Show Gist options
  • Save Alescontrela/f914f12147421ddaa5719304c8a2feee to your computer and use it in GitHub Desktop.
class DecisionTree(object):
    """A regression decision tree.

    Fits on a pandas DataFrame ``x`` (features) and a 1-D numpy array ``y``
    (targets).  Each node greedily picks the (feature, value) split that
    minimizes the count-weighted standard deviation of the targets on the two
    sides, and recurses until no split with at least ``min_leaf`` samples per
    side exists.  Requires ``numpy as np``, ``math``, and a DataFrame-like
    ``x`` exposing ``.values``, ``.shape`` and ``.columns``.
    """

    def __init__(self, x, y, idxs=None, min_leaf=5):
        """Build the (sub)tree rooted at this node.

        Parameters
        ----------
        x : pandas.DataFrame
            All observations; rows are samples, columns are features.
        y : numpy.ndarray
            Target value per row of ``x``.
        idxs : numpy.ndarray or None
            Row indices of ``x``/``y`` owned by this node (all rows if None).
        min_leaf : int
            Minimum number of samples allowed on each side of a split.
        """
        if idxs is None:
            idxs = np.arange(len(y))
        self.x, self.y, self.idxs, self.min_leaf = x, y, idxs, min_leaf
        # n: samples at this node; c: number of features per observation.
        self.n, self.c = len(idxs), x.shape[1]
        # Node prediction: mean target of the samples it owns.
        self.val = np.mean(y[idxs])
        # inf means "no split found yet"; stays inf for leaves.
        self.score = float('inf')
        self.find_varsplit()

    def find_varsplit(self):
        """Search all features for the best split; recurse into children.

        If no feature admits a valid split (score stays inf) the node is a
        leaf.  Otherwise two child trees are built: ``lhs`` for samples with
        feature value <= split, ``rhs`` for the rest.
        """
        for i in range(self.c):
            self.find_better_split(i)
        if self.score == float('inf'):
            return  # leaf: no split improved on inf
        x = self.split_col
        lhs = np.nonzero(x <= self.split)[0]
        rhs = np.nonzero(x > self.split)[0]
        # Bug fix: propagate min_leaf to the children.  The original dropped
        # it, silently resetting every non-root node to the default of 5.
        self.lhs = DecisionTree(self.x, self.y, self.idxs[lhs],
                                min_leaf=self.min_leaf)
        self.rhs = DecisionTree(self.x, self.y, self.idxs[rhs],
                                min_leaf=self.min_leaf)

    def find_better_split(self, var_idx):
        """Test feature ``var_idx``; record it if it beats the current score.

        Scans the samples sorted by this feature, maintaining running
        count/sum/sum-of-squares for both sides so each candidate split's
        weighted standard deviation is O(1) to evaluate.
        """
        x, y = self.x.values[self.idxs, var_idx], self.y[self.idxs]
        sort_idx = np.argsort(x)  # x is 1-D here, so plain argsort suffices
        sort_y, sort_x = y[sort_idx], x[sort_idx]
        rhs_cnt, rhs_sum, rhs_sum2 = self.n, np.sum(sort_y), (sort_y ** 2).sum()
        lhs_cnt, lhs_sum, lhs_sum2 = 0., 0., 0.
        for i in range(0, self.n - self.min_leaf - 1):
            xi, yi = sort_x[i], sort_y[i]
            # Move sample i from the right side to the left side.
            lhs_cnt += 1; rhs_cnt -= 1
            lhs_sum += yi; rhs_sum -= yi
            lhs_sum2 += yi ** 2; rhs_sum2 -= yi ** 2
            # Skip splits that leave too few samples on the left, or that
            # would cut between equal feature values (not a real boundary).
            if i < self.min_leaf or xi == sort_x[i + 1]:
                continue
            lhs_std = DecisionTree.std_agg(lhs_cnt, lhs_sum, lhs_sum2)
            rhs_std = DecisionTree.std_agg(rhs_cnt, rhs_sum, rhs_sum2)
            curr_score = lhs_std * lhs_cnt + rhs_std * rhs_cnt
            if curr_score < self.score:
                self.var_idx, self.score, self.split = var_idx, curr_score, xi

    @property
    def split_name(self):
        """Column name of the feature this node splits on."""
        return self.x.columns[self.var_idx]

    @property
    def split_col(self):
        """Values of the split feature for the samples owned by this node."""
        return self.x.values[self.idxs, self.var_idx]

    @property
    def is_leaf(self):
        """True when no split was found (score never improved on inf)."""
        return self.score == float('inf')

    def __repr__(self):
        # Bug fix: removed a stray print(self.is_leaf) — __repr__ must be a
        # pure string conversion with no console side effects.
        s = f'n: {self.n}--val: {self.val}'
        if not self.is_leaf:
            s += f'--score:{self.score}--split: {self.split}--var: {self.split_name}--var_idx: {self.var_idx}'
        return s

    def predict(self, x, debug=False):
        """Predict a value for each row of ``x``.

        Parameters
        ----------
        x : numpy.ndarray
            2-D array of observations (rows).
        debug : bool
            When True, print each split decision taken.  Now defaults to
            False (backward-compatible: callers passing it still work).
        """
        return np.array([self.predict_row(xi, debug) for xi in x])

    def predict_row(self, xi, debug=False):
        """Recurse down the tree for one observation; leaf value wins."""
        if self.is_leaf:
            return self.val
        if debug:
            print(self.split_name, end=" ")
        t = self.lhs if xi[self.var_idx] <= self.split else self.rhs
        if debug:
            if t == self.lhs:
                print("less than", end=" ")
            else:
                print("greater than", end=" ")
            print(self.split)
        return t.predict_row(xi, debug)

    @staticmethod
    def std_agg(cnt, s1, s2):
        """Standard deviation from count, sum, and sum of squares.

        Uses Var = E[y^2] - E[y]^2; abs() guards against tiny negative
        values caused by floating-point cancellation.
        """
        return math.sqrt(np.abs((s2 / cnt) - (s1 / cnt) ** 2))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment