Skip to content

Instantly share code, notes, and snippets.

class RandomForest():
def __init__(self, x, y, n_trees, n_features, sample_sz, depth=10, min_leaf=5):
np.random.seed(12)
if n_features == 'sqrt':
self.n_features = int(np.sqrt(x.shape[1]))
elif n_features == 'log2':
self.n_features = int(np.log2(x.shape[1]))
else:
self.n_features = n_features
print(self.n_features, "sha: ",x.shape[1])