Skip to content

Instantly share code, notes, and snippets.

@amaarora
Created December 27, 2018 00:34
Show Gist options
  • Save amaarora/a266368fcd9fb5d6fd3cfbbe8bdabc0f to your computer and use it in GitHub Desktop.
Save amaarora/a266368fcd9fb5d6fd3cfbbe8bdabc0f to your computer and use it in GitHub Desktop.
class DecisionTree:
    """One node of a decision tree over the rows of ``x`` selected by ``idxs``.

    The node stores the mean of the selected targets as its prediction
    (``val``) and then scans every column for a split.  The actual split
    search (``find_better_split``) is still a stub, so with the code as
    written every node remains a leaf.

    Parameters
    ----------
    x :
        Feature table.  The code reads ``x.shape[1]``, ``x.columns`` and
        ``x.values`` -- presumably a pandas ``DataFrame`` (confirm with
        the caller).
    y :
        Targets, indexed as ``y[idxs]`` -- presumably a NumPy array
        aligned row-for-row with ``x`` (confirm with the caller).
    idxs :
        Row positions belonging to this node.  Defaults to every row,
        ``np.arange(len(y))``.
    min_leaf :
        Stored on the instance; not consulted by any code in this class yet.
    """

    def __init__(self, x, y, idxs=None, min_leaf=5):
        if idxs is None:
            idxs = np.arange(len(y))
        self.x = x
        self.y = y
        self.idxs = idxs
        self.min_leaf = min_leaf
        self.n = len(idxs)       # number of rows in this node
        self.c = x.shape[1]      # number of candidate columns
        self.val = np.mean(y[idxs])
        # An infinite score is the "no split found" marker read by is_leaf.
        self.score = float('inf')
        self.find_varsplit()

    # This just does one decision; we'll make it recursive later
    def find_varsplit(self):
        """Try each column in turn as a split candidate."""
        for col in range(self.c):
            self.find_better_split(col)

    # We'll write this later!
    def find_better_split(self, var_idx):
        """Placeholder for the per-column split search; currently a no-op.

        ``split_name``, ``split_col`` and ``__repr__`` read ``self.var_idx``
        and ``self.split``, which nothing in this class assigns yet --
        presumably the finished implementation will set them together with
        ``self.score``.
        """
        pass

    @property
    def split_name(self):
        """Column label of the split variable (requires ``self.var_idx``)."""
        return self.x.columns[self.var_idx]

    @property
    def split_col(self):
        """Values of the split variable restricted to this node's rows."""
        return self.x.values[self.idxs, self.var_idx]

    @property
    def is_leaf(self):
        """True while ``score`` is still the initial infinity."""
        return self.score == float('inf')

    def __repr__(self):
        text = f'n: {self.n}; val:{self.val}'
        if self.is_leaf:
            return text
        # Split details are only available once a split has been recorded.
        return text + f'; score:{self.score}; split:{self.split}; var:{self.split_name}'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment