Last active
June 6, 2021 16:50
-
-
Save AdroitAnandAI/85fdc05603c1f578126ba1d07c87eb96 to your computer and use it in GitHub Desktop.
RandomForest in Spark
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pyspark | |
from pyspark.ml import Pipeline | |
from pyspark.ml.regression import RandomForestRegressor | |
from pyspark.ml.evaluation import RegressionEvaluator | |
from pyspark.context import SparkContext | |
from pyspark.sql.session import SparkSession | |
def trainRFmodel(trainingData):
    """Fit a random-forest regression pipeline on *trainingData*.

    The estimator reads its inputs from the ``"features"`` column; the label
    column is left at the estimator's default. The forest is wrapped in a
    single-stage ``Pipeline`` so the fitted result exposes the forest as
    ``stages[0]``.

    Args:
        trainingData: DataFrame to fit on; must contain a ``"features"``
            column.

    Returns:
        The fitted pipeline model.
    """
    forest = RandomForestRegressor(featuresCol="features")
    # The pipeline has exactly one stage: the forest itself.
    return Pipeline(stages=[forest]).fit(trainingData)
def testRFmodel(model, testData, isLabelled):
    """Score *testData* with *model* and print a short report.

    Shows the first five scored rows. When *isLabelled* is truthy the
    ``"label"`` column is shown too and the RMSE between ``"label"`` and
    ``"prediction"`` is printed. In both cases the first pipeline stage is
    printed afterwards.

    Args:
        model: Fitted pipeline model exposing ``transform`` and ``stages``.
        testData: DataFrame to score.
        isLabelled: Whether *testData* carries a ``"label"`` column with the
            true values.

    Returns:
        The scored DataFrame restricted to ``"prediction"`` and
        ``"features"``.
    """
    scored = model.transform(testData)
    output_columns = ("prediction", "features")

    if not isLabelled:
        # No ground truth available: just preview the predictions.
        scored.select(*output_columns).show(5)
    else:
        scored.select("prediction", "label", "features").show(5)
        # Compare predictions against the true labels.
        rmse_evaluator = RegressionEvaluator(
            labelCol="label",
            predictionCol="prediction",
            metricName="rmse",
        )
        rmse = rmse_evaluator.evaluate(scored)
        print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

    # Printing the stage gives only a one-line summary of it.
    first_stage = model.stages[0]
    print(first_stage)

    return scored.select(*output_columns)
# Driver script: build a Spark session, assemble features, train a random
# forest on a 90/10 split, evaluate it, then score the unlabelled data.

# Only create a SparkContext when no global ``sc`` exists yet.
if 'sc' not in globals():
    sc = SparkContext('local')
spark = SparkSession(sc)

# NOTE(review): `df`, `assembleFeatures` and `transformed_data_2020` are not
# defined in this snippet; they must already exist in the surrounding
# session/module before this code runs.
# Assigning to `df.columns` suggests `df` is a pandas DataFrame -- TODO confirm.
df.columns = ['week', 'temp', 'wind', 'rainfall', 'day', 'humScale', 'label']
dfRDD = spark.createDataFrame(df)

transformed_data = assembleFeatures(dfRDD)
transformed_data.show()

# Split the data: 90% for training, 10% held out for testing.
# No seed is passed, so the split differs from run to run.
(trainingData, testData) = transformed_data.randomSplit([0.9, 0.1])

model = trainRFmodel(trainingData)

# True: actual labels are present, so RMSE is reported.
testRFmodel(model, testData, True)

# False: actual labels are not present; keep the returned predictions.
predictions = testRFmodel(model, transformed_data_2020, False)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment