Skip to content

Instantly share code, notes, and snippets.

@JRuumis
Created August 18, 2016 18:24
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save JRuumis/5f7600e54eece42dd0c4c1e1543df84e to your computer and use it in GitHub Desktop.
import org.apache.log4j.{Level, Logger}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.{DataFrame, SparkSession}
/**
* Created by Janis Rumnieks on 15/08/2016.
*/
object DigitRecognizer4 {

  /**
   * Loads a labelled Kaggle digit-recogniser CSV (header row, then a label
   * followed by pixel values per line) into a DataFrame with columns
   * "label" (Double) and "features" (dense vector of the remaining values).
   *
   * @param sparkSession active session used to build the DataFrame
   * @param csvPath      path to the labelled training CSV file
   * @return DataFrame with "label" and "features" columns
   */
  def labelAndFeaturesFromCsv(sparkSession: SparkSession, csvPath: String): DataFrame = {
    val labelled = readCsvRows(csvPath).map(row => (row.head, Vectors.dense(row.tail)))
    sparkSession.createDataFrame(labelled).toDF("label", "features")
  }

  /**
   * Loads an unlabelled CSV (header row, then pixel values only) into a
   * DataFrame shaped like the training data. Every row gets a dummy label
   * of 0.0 so the schema matches what the fitted model expects.
   *
   * @param sparkSession active session used to build the DataFrame
   * @param csvPath      path to the unlabelled test CSV file
   * @return DataFrame with a constant 0.0 "label" column and "features"
   */
  def justFeaturesFromCsv(sparkSession: SparkSession, csvPath: String): DataFrame = {
    val unlabelled = readCsvRows(csvPath).map(row => (0.0, Vectors.dense(row)))
    sparkSession.createDataFrame(unlabelled).toDF("label", "features")
  }

  /**
   * Parses a headered CSV of numeric values into a list of rows.
   * Shared by both loaders; closes the file handle (the previous
   * inline version leaked the `Source`).
   */
  private def readCsvRows(csvPath: String): List[Array[Double]] = {
    val source = scala.io.Source.fromFile(csvPath)
    try {
      source.getLines()
        .drop(1)                               // skip the header row
        .map(_.split(",").map(_.toDouble))
        .toList                                // materialise before the handle closes
    } finally {
      source.close()
    }
  }

  def main(args: Array[String]): Unit = {
    //val trainDataFile = """C:\Developer\Kaggle\DigitRecogniser\train.csv"""
    val trainDataFile = """C:\Developer\Kaggle\DigitRecogniser\train_small_3.csv"""
    val testDataFile = """C:\Developer\Kaggle\DigitRecogniser\test.csv"""

    // Silence Spark/Akka INFO chatter so the prediction output stays readable.
    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)

    val sparkSession = SparkSession
      .builder()
      .appName("Spark SQL - Digit Recognition")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", ".")
      .getOrCreate()

    try {
      val trainDataFrame: DataFrame = labelAndFeaturesFromCsv(sparkSession, trainDataFile)
      val testDataFrame: DataFrame = justFeaturesFromCsv(sparkSession, testDataFile)

      // 10-tree random forest trained directly on the raw pixel features.
      val randomForestEstimator = new RandomForestClassifier()
        .setLabelCol("label")
        .setFeaturesCol("features")
        .setNumTrees(10)

      val model = randomForestEstimator.fit(trainDataFrame)
      val predictions = model.transform(testDataFrame)
      predictions.show(20)

      // Kaggle submission format: 1-based ImageId paired with the predicted digit.
      val predictionLabels = predictions.select("prediction").collect() map (_.getDouble(0).toInt)
      val imageIdAndPredictedLabel = (1 to predictionLabels.length) zip predictionLabels

      import java.io._
      val pw = new PrintWriter(new File("""C:\Developer\Kaggle\DigitRecogniser\janis_output.csv"""))
      try {
        pw.write("ImageId,Label\n")
        imageIdAndPredictedLabel foreach (row => pw.write(s"${row._1},${row._2}\n"))
      } finally {
        pw.close()                             // always release the writer, even on failure
      }

      // Use the already-collected array length instead of predictions.count(),
      // which would re-run the whole Spark job just to print the same number.
      println(s"Predictions: ${predictionLabels.length}")
    } finally {
      sparkSession.stop()                      // release the local Spark context on any exit path
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment