Skip to content

Instantly share code, notes, and snippets.

@okanyenigun
Created August 18, 2022 15:48
Show Gist options
  • Save okanyenigun/7087651be34b3ee5460bb501a612e043 to your computer and use it in GitHub Desktop.
Save okanyenigun/7087651be34b3ee5460bb501a612e043 to your computer and use it in GitHub Desktop.
linear regression scala
// to start a spark session
import org.apache.spark.sql.SparkSession
// to use the linear regression model
import org.apache.spark.ml.regression.LinearRegression
//set logging to level of ERROR
import org.apache.log4j._
// Raise the level of every logger under the "org" namespace to ERROR so that
// lower-severity log lines do not drown out the script's own output.
Logger.getLogger("org").setLevel(Level.ERROR)
// Obtain a SparkSession: reuses an existing one if present, otherwise creates it.
val spark = SparkSession.builder().getOrCreate()
// Read the CSV file; the first line supplies the column names and the column
// types are inferred from the data instead of defaulting to strings.
val data = spark.read.option("header","true").option("inferSchema","true").format("csv").load("Clean_Ecommerce.csv")
// Print the inferred schema.
data.printSchema()
// First row of the data (its value is only echoed when run interactively in the REPL).
data.head(1)
// Column names, in schema order.
val columnNames = data.columns
// First row of the data as a Row; throws if the file contains no data rows.
val firstRow = data.head(1)(0)
// Print each column name next to its value in the first row.
// `indices` starts at 0, so the first column is included as well
// (the previous Range(1, length) silently skipped it).
for (i <- columnNames.indices) {
  println(s"Column: ${columnNames(i)} | Data: ${firstRow(i)}")
}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.Vectors
// Build the modelling frame: the regression target "Yearly Amount Spent" is
// renamed to "label", followed by the remaining columns of interest.
// Columns are referenced through `data(...)` instead of the `$"..."`
// interpolator, so this statement does not depend on `spark.implicits._`
// being in scope. "Email" and "Avatar" are selected here but are not passed
// to the assembler below.
val df = data.select(
  data("Yearly Amount Spent").as("label"),
  data("Email"), data("Avatar"),
  data("Avg Session Length"), data("Time on App"),
  data("Time on Website"), data("Length of Membership"))
//df.printSchema
// Assembler that packs the four predictor columns into a single vector
// column named "features". The outer parentheses keep the multi-line builder
// chain a single expression when the script is fed to the REPL line by line.
val assembler = (new VectorAssembler()
  .setInputCols(Array(
    "Avg Session Length", "Time on App",
    "Time on Website", "Length of Membership"))
  .setOutputCol("features"))
// Apply the assembler and keep only the two columns used for training.
val output = assembler.transform(df).select("label", "features")
// Linear regression estimator: at most 100 iterations, regularization
// parameter 0.3 and elastic-net mixing parameter 0.8.
val lr = (new LinearRegression()
  .setMaxIter(100)
  .setRegParam(0.3)
  .setElasticNetParam(0.8))
// Fit the estimator on the assembled (label, features) frame.
val model = lr.fit(output)
// Coefficients and intercept of the fitted line.
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")
// Summary of the fitted model, fetched once and reused for every report below.
val trainingSummary = model.summary
// Predictions
trainingSummary.predictions.show()
// Residuals
trainingSummary.residuals.show()
// Root mean squared error
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
// Mean squared error
println(s"MSE: ${trainingSummary.meanSquaredError}")
// R2 coefficient
println(s"R2: ${trainingSummary.r2}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment