Skip to content

Instantly share code, notes, and snippets.

Created August 6, 2017 14:01
Show Gist options
  • Save anonymous/6348dde1987df092b156f4622d119f76 to your computer and use it in GitHub Desktop.
Save anonymous/6348dde1987df092b156f4622d119f76 to your computer and use it in GitHub Desktop.
/** Scores `df` with the configured model and reshapes the result into a
  * two-column submission frame.
  *
  * The wrapped model's output is narrowed to the columns `p` and `id`, read
  * as `(DenseVector, String)` pairs, and each pair is mapped to the last
  * element of the vector together with the id.
  *
  * @param df input dataset handed to the model's `transform`
  * @return a DataFrame with the columns `is_duplicate` (last entry of `p`)
  *         and `test_id` (the `id` value)
  */
def transform(df: Dataset[_]): DataFrame = {
  // Supplies the tuple encoders required by `.as[...]` and `.map` below.
  import df.sparkSession.implicits.newProductEncoder

  val scored = $(model).transform(df)
  val typedPairs = scored.select("p", "id").as[(DenseVector, String)]

  // NOTE(review): presumably `p` is the class-probability vector and its last
  // entry is the positive-class probability -- confirm against the model.
  val submissionRows = typedPairs.map { case (probabilities, rowId) =>
    (probabilities.values.last, rowId)
  }

  submissionRows.toDF("is_duplicate", "test_id")
}
/** Runs [[transform]] on `features` and writes the result as CSV.
  *
  * The frame is repartitioned to a single partition before writing. The CSV
  * writer is configured to emit a header row and to use `"` as the escape
  * character.
  *
  * @param features           dataset to score via `transform`
  * @param submissionFilePath destination path passed to the CSV writer
  */
def writeSubmissionFile(features: Dataset[Features], submissionFilePath: String): Unit = {
  val writerSettings = Map(
    "header" -> "true",
    "escape" -> "\""
  )

  transform(features)
    .repartition(1) // one partition, presumably so the output is a single part file
    .write
    .options(writerSettings)
    .csv(submissionFilePath)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment