Skip to content

Instantly share code, notes, and snippets.

@vaibkumr
Created October 17, 2018 15:37
Show Gist options
  • Save vaibkumr/c2a493864a78de738ed84543d77b11dd to your computer and use it in GitHub Desktop.
Save vaibkumr/c2a493864a78de738ed84543d77b11dd to your computer and use it in GitHub Desktop.
class RandomForest():
    """Ensemble of ``DecisionTree`` models whose predictions are averaged.

    Every tree is fitted on its own random subset of ``sample_sz`` rows,
    drawn without replacement from ``x`` / ``y``. ``DecisionTree`` is not
    defined in this block and must be available in the enclosing module.
    """

    def __init__(self, x, y, n_trees, sample_sz, min_leaf=5, depth=10):
        """Store the training data and settings, then build ``n_trees`` trees.

        Args:
            x: Feature table; rows are selected with ``x.iloc[...]``.
            y: Targets, one per row of ``x``; rows are selected by position.
            n_trees: Number of trees to build.
            sample_sz: Number of rows sampled for each tree.
            min_leaf: Forwarded to every ``DecisionTree`` as ``min_leaf``.
            depth: Forwarded to every ``DecisionTree`` as ``depth``.
        """
        # NOTE: this reseeds NumPy's *global* generator, so every forest is
        # built from the same sample sequence and any other code using
        # np.random is affected as well. Kept for reproducibility.
        np.random.seed(42)
        self.x, self.y = x, y
        self.sample_sz, self.min_leaf, self.depth = sample_sz, min_leaf, depth
        self.trees = [self.create_tree() for _ in range(n_trees)]

    def create_tree(self):
        """Fit and return one ``DecisionTree`` on a random subset of rows."""
        # First `sample_sz` entries of a shuffled index range: a sample
        # without replacement (at most len(y) rows).
        rnd_idxs = np.random.permutation(len(self.y))[:self.sample_sz]
        # Select targets by position so they stay aligned with `x.iloc`
        # even when `y` is a pandas object with a non-default index; plain
        # arrays have no `.iloc` and are indexed directly.
        y = self.y
        y_sample = y.iloc[rnd_idxs] if hasattr(y, "iloc") else y[rnd_idxs]
        # Pass the configured depth through (it used to be hard-coded to 10,
        # which silently ignored the constructor argument).
        return DecisionTree(self.x.iloc[rnd_idxs], y_sample,
                            min_leaf=self.min_leaf, depth=self.depth)

    def predict(self, x):
        """Return the element-wise mean of every tree's ``predict(x)``."""
        return np.mean([t.predict(x) for t in self.trees], axis=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment