Skip to content

Instantly share code, notes, and snippets.

@vaibkumr
Created October 17, 2018 15:58
Show Gist options
  • Save vaibkumr/1451872f561482f2de1299a9baa3d4bf to your computer and use it in GitHub Desktop.
Save vaibkumr/1451872f561482f2de1299a9baa3d4bf to your computer and use it in GitHub Desktop.
class DecisionTree():
    """One node of a regression decision tree.

    Each node owns a subset of the training rows (``idxs``) and predicts the
    mean target value over that subset. ``find_better_split`` is a stub here,
    so ``score`` stays at ``inf`` and every node is currently a leaf.

    Parameters
    ----------
    x : feature matrix (``split_name``/``split_col`` assume a pandas
        DataFrame — TODO confirm against callers; ``__init__`` only needs
        ``x.shape[1]``).
    y : 1-D target array, indexable by ``idxs``.
    idxs : row indices this node is responsible for; defaults to all rows.
    min_leaf : minimum rows allowed in a leaf when splitting is implemented.
    depth : maximum remaining depth when splitting is implemented.
    """

    def __init__(self, x, y, idxs=None, min_leaf=5, depth=10):
        # Default to "bagging" with all the rows.
        if idxs is None:
            idxs = np.arange(len(y))
        self.x, self.y, self.idxs, self.min_leaf, self.depth = x, y, idxs, min_leaf, depth
        self.n, self.c = len(idxs), x.shape[1]  # rows in this node, feature count
        self.val = np.mean(y[idxs])             # node prediction: mean target over idxs
        self.score = float('inf')               # best split score so far; inf = no split
        self.find_varsplit()

    # This just does one decision; we'll make it recursive later.
    def find_varsplit(self):
        for i in range(self.c):
            self.find_better_split(i)

    # We'll write this later! (Intended to set self.score / self.var_idx /
    # self.split when a better split is found.)
    def find_better_split(self, var_idx):
        pass

    @property
    def split_name(self):
        # Column label of the chosen split variable (requires self.var_idx,
        # which find_better_split is expected to set — TODO confirm).
        return self.x.columns[self.var_idx]

    @property
    def split_col(self):
        # Values of the split variable restricted to this node's rows.
        return self.x.values[self.idxs, self.var_idx]

    @property
    def is_leaf(self):
        # Leaf while no split has improved on the initial inf score.
        return self.score == float('inf')

    def __repr__(self):
        s = f'n: {self.n}; val:{self.val}'
        if not self.is_leaf:
            s += f'; score:{self.score}; split:{self.split}; var:{self.split_name}'
        return s
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment