Skip to content

Instantly share code, notes, and snippets.

@bgweber
Created January 21, 2019 02:02
Show Gist options
  • Save bgweber/064e4b872f2d517eed3aa5c024fcdb00 to your computer and use it in GitHub Desktop.
# spark version
from pyspark.ml.regression import RandomForestRegressor
# define a function to train a RF model and return metrics
def mllib_random_forest(trees, boston_train, boston_test, label_col="target"):
    """Train a Spark ML random forest regressor and score it on a test set.

    Args:
        trees: Number of trees in the forest (passed as ``numTrees``).
        boston_train: Spark DataFrame used to fit the model. Presumably it
            carries the estimator's default ``features`` vector column plus
            the label column -- TODO confirm against the data-prep code.
        boston_test: Spark DataFrame scored by the fitted model; it needs the
            same feature/label columns as ``boston_train``.
        label_col: Name of the label column, used both for training and for
            scoring. Defaults to ``"target"``, the previously hard-coded name.

    Returns:
        A two-element list ``[trees, r_squared]``, where ``r_squared`` is the
        square of the Pearson correlation between the ``"prediction"`` column
        and the label column on ``boston_test``.
    """
    # Train a random forest regressor with the requested number of trees.
    rf = RandomForestRegressor(numTrees=trees, labelCol=label_col)
    model = rf.fit(boston_train)

    # Score the held-out data; the fitted model's transform() output is
    # expected to carry a "prediction" column, which is read below.
    boston_pred = model.transform(boston_test)

    # Squared Pearson correlation between predictions and labels. Note this
    # is not the coefficient of determination (R^2) in general.
    r = boston_pred.stat.corr("prediction", label_col)
    return [trees, r ** 2]
# run the tasks: fit and score one forest per tree count in `parameters`.
# NOTE(review): `pool`, `parameters`, `boston_train` and `boston_test` are
# defined outside this view; `pool` is presumably a
# multiprocessing.pool.ThreadPool -- confirm against the setup code.
def _train_and_score(trees):
    """Fit and score a forest with ``trees`` trees on the shared data splits."""
    return mllib_random_forest(trees, boston_train, boston_test)


# Keep the [trees, r_squared] pairs instead of discarding them; with a
# multiprocessing-style pool they come back in the order of `parameters`.
spark_results = pool.map(_train_and_score, parameters)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment