Skip to content

Instantly share code, notes, and snippets.

Created August 6, 2017 22:06
Show Gist options
  • Save anonymous/8c10042c4b94c881926f22ca55bedf1b to your computer and use it in GitHub Desktop.
Save anonymous/8c10042c4b94c881926f22ca55bedf1b to your computer and use it in GitHub Desktop.
/**
 * Builds the [[CrossValidator]] used to tune the Quora question-pairs pipeline.
 *
 * The estimator is a `QuoraQuestionsPairsPipeline` wired with a stop-words remover,
 * a count vectorizer, an LDA stage and a logistic regression stage. The parameter
 * grid is built from the values held in this instance's params (`stopwords`,
 * `vocabularySize`, `numTopics`, `minDF`, `ldaMaxIter`, `logisticRegressionMaxIter`),
 * and candidates are scored with a `LogLossBinaryClassificationEvaluator`.
 *
 * @return a configured, not yet fitted, cross-validator
 */
private def assembleCrossValidator(): CrossValidator = {
  val stopwordsRemover = new StopWordsRemover()
  val countVectorizer = new CountVectorizer()
  val logisticRegression = new LogisticRegression()
  val lda = new LDA()
  val estimator = new QuoraQuestionsPairsPipeline()
    .setStopwordsRemover(stopwordsRemover)
    .setCountVectorizer(countVectorizer)
    .setLogisticRegression(logisticRegression)
    .setLDA(lda)

  // Grid search on hyperparameter space.
  // Each entry of `stopwords` is read as a file and becomes one candidate stop-word list.
  // Lines are joined with "," before splitting on ",", so words may be separated by
  // commas, by newlines, or by both.
  val stopwordsLists = $(stopwords).map { path =>
    val source = Source.fromFile(path)
    // Close the handle once the lines have been consumed; getLines() is lazy, so the
    // read must complete inside the try.
    try source.getLines().mkString(",").split(",")
    finally source.close()
  }
  val paramGrid = new ParamGridBuilder()
    .addGrid(stopwordsRemover.stopWords, stopwordsLists)
    .addGrid(countVectorizer.vocabSize, $(vocabularySize))
    .addGrid(lda.k, $(numTopics))
    .addGrid(countVectorizer.minDF, $(minDF))
    .addGrid(lda.maxIter, $(ldaMaxIter))
    .addGrid(logisticRegression.maxIter, $(logisticRegressionMaxIter))
    .build()

  val evaluator = new LogLossBinaryClassificationEvaluator()
    .setLabelCol("isDuplicateLabel")
    .setProbabilityCol("p")

  // Cross-validation setup.
  // NOTE(review): `numFolds` is passed directly while every other setting is read via
  // `$(...)`; confirm it is a plain Int and not a Param that needs `$(numFolds)`.
  new CrossValidator()
    .setEstimator(estimator)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(numFolds)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment