Skip to content

Instantly share code, notes, and snippets.

@vaibkumr
Created October 22, 2018 12:28
Show Gist options
  • Save vaibkumr/c43c0707740eac4fc0d4433ded1d3cdf to your computer and use it in GitHub Desktop.
Save vaibkumr/c43c0707740eac4fc0d4433ded1d3cdf to your computer and use it in GitHub Desktop.
class RandomForest:
    """Ensemble of ``DecisionTree`` models whose predictions are averaged.

    Every tree is built from a random subset of rows (at most ``sample_sz``)
    and a random subset of column indices (``n_features`` of them).

    Args:
        x: Feature table. Must expose ``.shape`` and ``.iloc`` (rows are
            selected with ``x.iloc[...]``), e.g. a pandas DataFrame.
        y: Targets; must support ``len()`` and indexing with an integer
            array (``y[idxs]``).
            NOTE(review): presumably a NumPy array -- a pandas Series with a
            non-default index would be looked up by label here; confirm.
        n_trees: Number of trees to build.
        n_features: ``'sqrt'`` or ``'log2'`` to derive the per-tree feature
            count from the number of columns of ``x`` (never less than 1);
            any other value is stored unchanged.
        sample_sz: Number of rows drawn (without replacement) per tree.
        depth: Forwarded to every ``DecisionTree`` as ``depth``.
        min_leaf: Forwarded to every ``DecisionTree`` as ``min_leaf``.
    """

    def __init__(self, x, y, n_trees, n_features, sample_sz, depth=10, min_leaf=5):
        # Reseeds NumPy's *global* RNG on every construction; kept as-is so
        # existing results stay reproducible.
        np.random.seed(12)

        n_cols = x.shape[1]
        if n_features == 'sqrt':
            self.n_features = max(1, int(np.sqrt(n_cols)))
        elif n_features == 'log2':
            # int(log2(1)) == 0 would leave a tree with no feature to split
            # on, so never go below one feature.
            self.n_features = max(1, int(np.log2(n_cols)))
        else:
            self.n_features = n_features
        # Diagnostic output retained from the original implementation.
        print(self.n_features, "sha: ", n_cols)

        self.x = x
        self.y = y
        self.sample_sz = sample_sz
        self.depth = depth
        self.min_leaf = min_leaf
        self.trees = [self.create_tree() for _ in range(n_trees)]

    def create_tree(self):
        """Build one ``DecisionTree`` on a random row and column subset."""
        # Row sample first, then column sample: the order of the two RNG
        # draws is part of the reproducible behaviour.
        idxs = np.random.permutation(len(self.y))[:self.sample_sz]
        f_idxs = np.random.permutation(self.x.shape[1])[:self.n_features]
        # The tree receives only the sampled rows, so its local row indices
        # must span exactly that many rows. Using ``sample_sz`` directly
        # would overshoot whenever it exceeds the number of available rows.
        local_idxs = np.arange(len(idxs))
        return DecisionTree(
            self.x.iloc[idxs],
            self.y[idxs],
            self.n_features,
            f_idxs,
            idxs=local_idxs,
            depth=self.depth,
            min_leaf=self.min_leaf,
        )

    def predict(self, x):
        """Return the element-wise mean of every tree's ``predict(x)``."""
        return np.mean([t.predict(x) for t in self.trees], axis=0)
def std_agg(cnt, s1, s2):
    """Return ``sqrt(s2/cnt - (s1/cnt)**2)``.

    When ``s1`` is the sum and ``s2`` the sum of squares of ``cnt`` values,
    this is their population standard deviation.

    Args:
        cnt: Number of aggregated values.
        s1: Aggregated first-order sum.
        s2: Aggregated second-order sum.

    Returns:
        The standard deviation as a float (never negative).

    Raises:
        ZeroDivisionError: If ``cnt`` is 0.
    """
    variance = (s2 / cnt) - (s1 / cnt) ** 2
    # Rounding error can push the variance of (near-)constant data a hair
    # below zero, which would make math.sqrt raise ValueError; clamp it.
    return math.sqrt(max(0.0, variance))
@guaychou
Copy link

Excuse me, sir, where is the main class?
Thanks, sir.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment