Skip to content

Instantly share code, notes, and snippets.

Created August 6, 2017 14:08
Show Gist options
  • Save anonymous/b288e65d1d99451c467f69efc0137e64 to your computer and use it in GitHub Desktop.
Save anonymous/b288e65d1d99451c467f69efc0137e64 to your computer and use it in GitHub Desktop.
/**
 * Selects hyper-parameters by cross validation, refits the pipeline on the full
 * training set with the winning parameters, and hands the resulting model plus the
 * test data to a [[SubmissionWriter]].
 *
 * @param trainingDataFile location of the training data, passed to `featuresLoader.loadTrainFile`
 * @param testDataFile     location of the test data, passed to `featuresLoader.loadTestFile`
 * @param submissionFile   destination passed to `SubmissionWriter.writeSubmissionFile`
 */
def run(trainingDataFile: String, testDataFile: String, submissionFile: String): Unit = {
  val crossValidator = new QuoraQuestionsPairsCrossValidator
  logger.info(s"Cross validator params:\n${crossValidator.explainParams()}")
  // Every param value is cast to a List of candidate settings; the product of the
  // list lengths is the number of combinations. A non-List value fails the cast.
  val numVariations = crossValidator.extractParamMap().toSeq.map(_.value.asInstanceOf[List[_]].length).product
  logger.info(s"Cross validator for kaggle quora questions pairs will train $numVariations * ${crossValidator.numFolds} models")

  val trainData = featuresLoader.loadTrainFile(spark, trainingDataFile)
  trainData.cache() // reused by the cross validation and again by the final fit

  val model = try {
    // Train with cross validation to get the best params.
    val cvModel = crossValidator.fit(trainData)
    // NOTE(review): maxBy assumes a larger average metric is better — confirm against
    // the evaluator configured inside QuoraQuestionsPairsCrossValidator.
    val bestParams = cvModel.getEstimatorParamMaps.zip(cvModel.avgMetrics).maxBy(_._2)._1

    // Train on all data.
    val estimator = new QuoraQuestionsPairsPipeline().copy(bestParams)
    estimator.fit(trainData)
  } finally {
    // The cached training data is not used after this point; release it whether
    // training succeeded or threw, instead of leaving it cached.
    trainData.unpersist()
  }

  val testData = featuresLoader.loadTestFile(spark, testDataFile)
  val submissionWriter = new SubmissionWriter().setModel(model)
  submissionWriter.writeSubmissionFile(testData, submissionFile)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment