Skip to content

Instantly share code, notes, and snippets.

@AdroitAnandAI
Last active June 6, 2021 16:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save AdroitAnandAI/85fdc05603c1f578126ba1d07c87eb96 to your computer and use it in GitHub Desktop.
Save AdroitAnandAI/85fdc05603c1f578126ba1d07c87eb96 to your computer and use it in GitHub Desktop.
# RandomForest in Spark
import pyspark
from pyspark.ml import Pipeline
from pyspark.ml.regression import RandomForestRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
def trainRFmodel(trainingData):
    """Fit a random-forest regression pipeline on ``trainingData``.

    Args:
        trainingData: Spark DataFrame to fit on. The regressor reads its
            inputs from the ``"features"`` column; the label column is not
            overridden here, so the estimator's default applies
            (``"label"`` per the Spark ML docs -- confirm for your version).

    Returns:
        The fitted pipeline model returned by ``Pipeline.fit``. The fitted
        forest is its only stage (``model.stages[0]``).
    """
    # Only the feature column is set explicitly; every other
    # hyper-parameter is left at the estimator's default.
    rf = RandomForestRegressor(featuresCol="features")

    # The forest is the sole pipeline stage -- no indexer or other
    # transformer runs before it. Wrapping it in a Pipeline keeps the
    # ``model.stages[...]`` access pattern available to callers.
    pipeline = Pipeline(stages=[rf])

    # Fitting the pipeline trains the forest.
    return pipeline.fit(trainingData)
def testRFmodel(model, testData, isLabelled):
    """Run ``model`` on ``testData``, print a preview, and return predictions.

    Args:
        model: Fitted pipeline model; must provide ``transform`` and a
            ``stages`` sequence whose first element is the fitted forest.
        testData: Spark DataFrame with a ``"features"`` column, plus a
            ``"label"`` column when ``isLabelled`` is true.
        isLabelled: When true, the true labels are shown next to the
            predictions and the RMSE against ``"label"`` is printed.
            When false, only predictions and features are shown.

    Returns:
        A DataFrame holding just the ``"prediction"`` and ``"features"``
        columns of the scored data.
    """
    # NOTE(review): the original indentation was lost; the summary print and
    # the return are placed after the if/else because the caller relies on
    # the return value when isLabelled is False -- confirm against the gist.

    # Score the data; this appends a "prediction" column.
    predictions = model.transform(testData)
    result = predictions.select("prediction", "features")

    if isLabelled:
        # Preview a few rows with the ground truth alongside.
        predictions.select("prediction", "label", "features").show(5)

        # Compare (prediction, true label) pairs to compute the test error.
        evaluator = RegressionEvaluator(
            labelCol="label", predictionCol="prediction", metricName="rmse")
        rmse = evaluator.evaluate(predictions)
        print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)
    else:
        # No ground truth available: preview predictions only.
        result.show(5)

    rfModel = model.stages[0]
    print(rfModel)  # summary only

    return result
# Create a local SparkContext only when none is already bound to ``sc``
# (e.g. when re-running in a notebook); a second context would fail.
if 'sc' not in globals():
    sc = SparkContext('local')
# Always (re)build the session from the context so ``spark`` is defined
# even when ``sc`` already existed.
spark = SparkSession(sc)

# ``df`` is defined outside this section (presumably a pandas DataFrame,
# given the ``.columns`` assignment -- confirm). The last column is the
# regression target and is named "label" to match the evaluator.
df.columns = ['week', 'temp', 'wind', 'rainfall', 'day', 'humScale', 'label']

# Despite its name, ``dfRDD`` is a Spark DataFrame, not an RDD.
dfRDD = spark.createDataFrame(df)

# ``assembleFeatures`` is defined elsewhere; it is expected to add the
# "features" column the model reads -- verify against its definition.
transformed_data = assembleFeatures(dfRDD)
transformed_data.show()

# Split the data: 90% for training, 10% held out for testing.
# No seed is given, so the split differs from run to run.
(trainingData, testData) = transformed_data.randomSplit([0.9, 0.1])

model = trainRFmodel(trainingData)

# True: actual labels are present, so the RMSE is reported.
testRFmodel(model, testData, True)

# False: actual labels are not present; only predictions are produced.
# ``transformed_data_2020`` is defined outside this section.
predictions = testRFmodel(model, transformed_data_2020, False)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment