Skip to content

Instantly share code, notes, and snippets.

@rikturr
Created July 21, 2020 14:37
Show Gist options
  • Save rikturr/c74d708cb6b1ca6a0679e5bdf88031e8 to your computer and use it in GitHub Desktop.
Save rikturr/c74d708cb6b1ca6a0679e5bdf88031e8 to your computer and use it in GitHub Desktop.
spark grid search
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.pipeline import Pipeline
# --- Feature engineering stages -------------------------------------------
# NOTE(review): `categorical_feat` and `numeric_feat` (lists of column names)
# are defined outside this snippet — confirm they exist before this runs.

# One StringIndexer per categorical column -> '<col>_idx'.
# handleInvalid='keep' asks Spark to put labels unseen at fit time (e.g. in a
# CV validation fold) into an extra index instead of raising.
indexers = [
    StringIndexer(inputCol=c, outputCol=f'{c}_idx', handleInvalid='keep')
    for c in categorical_feat
]

# One-hot encode each indexed column: '<col>_idx' -> '<col>_onehot'.
encoders = [
    OneHotEncoder(inputCol=f'{c}_idx', outputCol=f'{c}_onehot')
    for c in categorical_feat
]

# Pack the numeric columns into a single vector so they can be scaled together.
num_assembler = VectorAssembler(
    inputCols=numeric_feat,
    outputCol='num_features',
)

# Scale only the numeric vector; the one-hot columns are left as 0/1.
# NOTE(review): StandardScaler is used with its defaults (per Spark docs:
# withStd=True, withMean=False, i.e. no centering) — confirm that is intended.
scaler = StandardScaler(inputCol='num_features', outputCol='num_features_scaled')

# Final feature vector: all one-hot blocks followed by the scaled numerics.
assembler = VectorAssembler(
    inputCols=[f'{c}_onehot' for c in categorical_feat] + ['num_features_scaled'],
    outputCol='features',
)

# Scaling is done explicitly above, so LinearRegression's internal
# standardization is turned off.
lr = LinearRegression(standardization=False, maxIter=100)

pipeline = Pipeline(
    stages=indexers + encoders + [num_assembler, scaler, assembler, lr],
)

# --- Hyperparameter grid ----------------------------------------------------
# elasticNetParam: 0.00, 0.01, ..., 1.00 (101 values).
# Built from integer division rather than np.arange(0, 1.01, 0.01):
#   * the original referenced `np` without importing numpy (NameError);
#   * arange with a float step produces values like 0.35000000000000003 and,
#     per NumPy's docs, an element count that is fragile with a non-integer
#     step.
# `i / 100` is correctly rounded, so every value is the nearest float to its
# decimal literal and 1.0 is always included.
elastic_net_values = [i / 100 for i in range(101)]

# 101 x 4 = 404 param maps; with numFolds=3 below that is 1212 pipeline fits
# (plus a final refit on the full data).
# NOTE(review): with regParam=0 no penalty is applied, so the 101
# elasticNetParam values in that row presumably give equivalent models —
# consider pruning them if runtime matters.
grid = (
    ParamGridBuilder()
    .addGrid(lr.elasticNetParam, elastic_net_values)
    .addGrid(lr.regParam, [0, 0.5, 1, 2])
    .build()
)

# 3-fold cross-validation over the whole pipeline. Indexers and scaler are
# refit per fold, so validation folds do not leak into preprocessing.
# RegressionEvaluator() uses its default metric (RMSE per Spark docs).
crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=grid,
    evaluator=RegressionEvaluator(),
    numFolds=3,
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment