@bgweber
Created January 21, 2019 02:01
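from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator

# 10-fold cross-validation over the elastic net mixing parameter, scored by R-squared
# note: regParam defaults to 0.0, so elasticNetParam only has an effect
# if regParam is also set to a positive value
crossval = CrossValidator(
    estimator=LinearRegression(labelCol="target"),
    estimatorParamMaps=ParamGridBuilder().addGrid(
        LinearRegression.elasticNetParam, [0.0, 0.5, 1.0]).build(),
    evaluator=RegressionEvaluator(labelCol="target", metricName="r2"),
    numFolds=10)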
# cross validate the model and select the best fit
cvModel = crossval.fit(boston_train)
model = cvModel.bestModel
# calculate results
boston_pred = model.transform(boston_test)
r = boston_pred.stat.corr("prediction", "target")
print("R-squared: " + str(r**2))