Skip to content

Instantly share code, notes, and snippets.

@vaibkumr
Created October 22, 2018 12:52
Show Gist options
  • Save vaibkumr/a78d26cf5047190a562736ba58264d3f to your computer and use it in GitHub Desktop.
Save vaibkumr/a78d26cf5047190a562736ba58264d3f to your computer and use it in GitHub Desktop.
def find_better_split(self, var_idx):
    """Scan one feature column for a split that lowers ``self.score``.

    The rows selected by ``self.idxs`` are put in ascending order of their
    value in column ``var_idx``. Rows are then moved one at a time from the
    right-hand partition to the left-hand one while a count, a sum and a sum
    of squares of the targets are kept for each side, so every candidate can
    be scored through ``std_agg`` without re-reading the data.

    A candidate's score is ``lhs_std * lhs_cnt + rhs_std * rhs_cnt``. When it
    is strictly below ``self.score``, the attributes ``self.var_idx``,
    ``self.score`` and ``self.split`` are overwritten; the stored split is
    the feature value of the last row placed on the left.

    Args:
        var_idx: Column position in ``self.x.values`` to evaluate.

    Returns:
        None. The result is recorded on ``self``.
    """
    feature = self.x.values[self.idxs, var_idx]
    target = self.y[self.idxs]

    order = np.argsort(feature)
    ys = target[order]
    xs = feature[order]

    # Before any row is moved, the right-hand side owns every row.
    right_n = self.n
    right_sum = ys.sum()
    right_sq = (ys ** 2).sum()
    left_n, left_sum, left_sq = 0, 0.0, 0.0

    min_leaf = self.min_leaf
    stop = self.n - min_leaf - 1

    for pos in range(stop):
        value = xs[pos]
        y_val = ys[pos]
        y_sq = y_val ** 2

        # Shift this row's contribution from the right side to the left.
        left_n += 1
        right_n -= 1
        left_sum += y_val
        right_sum -= y_val
        left_sq += y_sq
        right_sq -= y_sq

        # A position is only scored once `pos` has reached `min_leaf` and the
        # next sorted feature value differs from this one (presumably because
        # a threshold cannot separate equal values).
        # NOTE(review): `pos >= min_leaf` means the left side already holds
        # min_leaf + 1 rows, and the loop bound likewise leaves at least
        # min_leaf + 1 rows on the right -- confirm that is the intent.
        if pos >= min_leaf and value != xs[pos + 1]:
            left_std = std_agg(left_n, left_sum, left_sq)
            right_std = std_agg(right_n, right_sum, right_sq)
            candidate = left_std * left_n + right_std * right_n

            if candidate < self.score:
                self.var_idx = var_idx
                self.score = candidate
                self.split = value
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment