Skip to content

Instantly share code, notes, and snippets.

@vaibkumr
Created October 17, 2018 19:11
Show Gist options
  • Save vaibkumr/f55cae505b253192c905142f023ed7ec to your computer and use it in GitHub Desktop.
Save vaibkumr/f55cae505b253192c905142f023ed7ec to your computer and use it in GitHub Desktop.
class DecisionTree():
    """Decision tree node built over a subset of the rows of ``x``/``y``.

    On construction the node stores its data, records the mean of ``y``
    over its rows as ``val``, and scans every column for a split via
    ``find_varsplit``. For simplicity it does a single split; recursion
    into child nodes is left for later.

    ``x`` is accessed through ``.shape``, ``.columns`` and ``.values``
    (presumably a pandas DataFrame -- confirm against callers), and ``y``
    is indexed with the ``idxs`` array (presumably a NumPy array).
    """

    def __init__(self, x, y, idxs=None, min_leaf=5, depth=10):
        """Store the node's data and look for a split.

        Args:
            x: Feature table with one column per candidate split variable.
            y: Target values, indexable by ``idxs``.
            idxs: Row positions belonging to this node. ``None`` means
                every row of ``y``.
            min_leaf: Stored on the node; not used by the code in this
                class yet.
            depth: Remaining depth budget; the node is a leaf once it
                reaches zero.
        """
        if idxs is None:
            # No subset given: use all the rows.
            idxs = np.arange(len(y))
        self.x, self.y, self.idxs = x, y, idxs
        self.min_leaf, self.depth = min_leaf, depth
        self.n, self.c = len(idxs), x.shape[1]
        self.val = np.mean(y[idxs])
        # Stays infinite until a split search records a better score.
        self.score = float('inf')
        self.find_varsplit()

    def find_varsplit(self):
        """Try every column of ``x`` as a split candidate."""
        for i in range(self.c):
            self.find_better_split(i)

    def find_better_split(self, var_idx):
        """Evaluate column ``var_idx`` as a split candidate.

        A black box for now (no-op); to be written later. Because nothing
        here assigns ``self.var_idx``, ``split_name`` and ``split_col``
        raise ``AttributeError`` until a real implementation sets it.
        """
        pass

    @property
    def split_name(self):
        """Column label of the chosen split variable (needs ``var_idx``)."""
        return self.x.columns[self.var_idx]

    @property
    def split_col(self):
        """Values of the split column restricted to this node's rows."""
        return self.x.values[self.idxs, self.var_idx]

    @property
    def is_leaf(self):
        """Whether this node is terminal.

        A node is a leaf when no split was found (``score`` is still
        infinite) OR when the depth budget is exhausted. Either condition
        alone is sufficient; requiring both would mark an unsplittable
        node with remaining depth as an internal node.
        """
        return self.score == float('inf') or self.depth <= 0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment