Created
August 18, 2016 18:24
Star
You must be signed in to star a gist
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.log4j.{Level, Logger} | |
import org.apache.spark.ml.classification.RandomForestClassifier | |
import org.apache.spark.ml.linalg.Vectors | |
import org.apache.spark.sql.{DataFrame, SparkSession} | |
/** | |
* Created by Janis Rumnieks on 15/08/2016. | |
*/ | |
object DigitRecognizer4 {

  /** Reads a CSV file of numeric values, skipping the header row.
    *
    * The underlying file handle is always closed, even when parsing fails
    * (the original code leaked the `Source` in both loaders).
    *
    * @param csvPath path to a CSV file whose first line is a header and
    *                whose remaining lines contain only numeric fields
    * @return one `Array[Double]` per data row, fully materialized before
    *         the source is closed (the iterator must not escape the try)
    */
  private def readCsvRows(csvPath: String): List[Array[Double]] = {
    val source = scala.io.Source.fromFile(csvPath)
    try {
      val lines = source.getLines().map(_.split(","))
      lines.next() // discard the header row
      lines.map(row => row.map(_.toDouble)).toList
    } finally {
      source.close()
    }
  }

  /** Loads labelled training data from CSV.
    *
    * The first column of each row becomes the `label`; the remaining
    * columns become a dense `features` vector.
    *
    * @param sparkSession active session used to build the DataFrame
    * @param csvPath      path to the training CSV (header + label + pixels)
    * @return DataFrame with columns ("label", "features")
    */
  def labelAndFeaturesFromCsv(sparkSession: SparkSession, csvPath: String): DataFrame = {
    val labelled = readCsvRows(csvPath).map(row => (row.head, Vectors.dense(row.tail)))
    sparkSession.createDataFrame(labelled).toDF("label", "features")
  }

  /** Loads unlabelled test data from CSV.
    *
    * Every column becomes part of the `features` vector; a dummy 0.0
    * label is attached so the schema matches the training DataFrame.
    *
    * @param sparkSession active session used to build the DataFrame
    * @param csvPath      path to the test CSV (header + pixels only)
    * @return DataFrame with columns ("label", "features"), label always 0.0
    */
  def justFeaturesFromCsv(sparkSession: SparkSession, csvPath: String): DataFrame = {
    val unlabelled = readCsvRows(csvPath).map(row => (0.0, Vectors.dense(row)))
    sparkSession.createDataFrame(unlabelled).toDF("label", "features")
  }

  /** Trains a random forest on the Kaggle digit-recognizer training CSV,
    * predicts labels for the test CSV, and writes a Kaggle submission file
    * (ImageId,Label) to disk.
    */
  def main(args: Array[String]): Unit = {
    //val trainDataFile = """C:\Developer\Kaggle\DigitRecogniser\train.csv"""
    val trainDataFile = """C:\Developer\Kaggle\DigitRecogniser\train_small_3.csv"""
    val testDataFile = """C:\Developer\Kaggle\DigitRecogniser\test.csv"""

    // disable INFO messages
    Logger.getLogger("org").setLevel(Level.ERROR)
    Logger.getLogger("akka").setLevel(Level.ERROR)

    val sparkSession = SparkSession
      .builder()
      .appName("Spark SQL - Digit Recognition")
      .master("local[*]")
      .config("spark.sql.warehouse.dir", ".")
      .getOrCreate()

    try {
      val trainDataFrame: DataFrame = labelAndFeaturesFromCsv(sparkSession, trainDataFile)
      val testDataFrame: DataFrame = justFeaturesFromCsv(sparkSession, testDataFile)

      val randomForestEstimator = new RandomForestClassifier()
        .setLabelCol("label")
        .setFeaturesCol("features")
        .setNumTrees(10)

      val model = randomForestEstimator.fit(trainDataFrame)
      val predictions = model.transform(testDataFrame)
      predictions.show(20)

      // Predicted class comes back as a Double column; Kaggle wants ints.
      val predictionLabels = predictions.select("prediction").collect().map(_.getDouble(0).toInt)
      //println ( predictionLabels.take(20) mkString(",") )

      // Kaggle ImageIds are 1-based and in file order.
      val imageIdAndPredictedLabel = (1 to predictionLabels.length) zip predictionLabels

      import java.io._
      val pw = new PrintWriter(new File("""C:\Developer\Kaggle\DigitRecogniser\janis_output.csv"""))
      try {
        pw.write("ImageId,Label\n")
        imageIdAndPredictedLabel.foreach(row => pw.write(s"${row._1},${row._2}\n"))
      } finally {
        pw.close() // side-effecting: keep the parens, and close even on failure
      }

      println(s"Predictions: ${predictions.count()}")
    } finally {
      sparkSession.stop()
    }
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment