Created January 21, 2019 02:01
Save bgweber/ffd4603a26918ffd77ef3db7e19a55ee to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# sklearn version
from multiprocessing.pool import ThreadPool

from scipy.stats import pearsonr
from sklearn.ensemble import RandomForestRegressor as RFR
# Hyperparameter grid: the forest sizes (n_estimators) to evaluate.
parameters = [10, 20, 50]

# Worker pool capped at five concurrent threads; each thread trains one model.
pool = ThreadPool(processes=5)
def sklearn_random_forest(trees, X_train, X_test, y_train, y_test, random_state=None):
    """Train a random forest regressor and score it on a held-out set.

    Args:
        trees: Number of trees in the forest (passed as ``n_estimators``).
        X_train: Training features passed to ``fit``.
        X_test: Held-out features used for prediction.
        y_train: Training targets passed to ``fit``.
        y_test: Held-out targets the predictions are correlated against.
        random_state: Optional seed forwarded to the regressor so that runs
            with different ``trees`` values can be compared reproducibly.
            Defaults to ``None``, which matches the regressor's own default
            (unseeded).

    Returns:
        A two-element list ``[trees, r_squared]``. Here ``r_squared`` is the
        squared Pearson correlation between the predictions and ``y_test``.
        Note that this is not sklearn's ``r2_score`` (coefficient of
        determination).
    """
    rf = RFR(n_estimators=trees, random_state=random_state)
    model = rf.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    # pearsonr returns (statistic, p-value); only the correlation is needed.
    r, _ = pearsonr(y_pred, y_test)

    return [trees, r ** 2]
# Train one model per entry in `parameters` on the pool's worker threads.
# NOTE(review): X_train, X_test, y_train and y_test are not defined in this
# snippet; they are assumed to exist at module level before this runs.
try:
    # pool.map blocks until every job finishes and returns results in input
    # order, so results[i] corresponds to parameters[i].
    results = pool.map(
        lambda trees: sklearn_random_forest(
            trees, X_train, X_test, y_train, y_test
        ),
        parameters,
    )
finally:
    # The pool is not reused after this point. Stop accepting work and wait
    # for the worker threads to exit instead of leaking them.
    pool.close()
    pool.join()

# Final bare expression keeps the notebook-style display of the results.
results
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.