Created
August 6, 2017 22:06
-
-
Save anonymous/8c10042c4b94c881926f22ca55bedf1b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Builds the [[CrossValidator]] used to tune the Quora question-pairs pipeline.
 *
 * The estimator is a `QuoraQuestionsPairsPipeline` wired with a stop-words remover,
 * a count vectorizer, an LDA model and a logistic regression. The hyperparameter
 * grid is the cartesian product of the values held by this instance's params
 * (`stopwords`, `vocabularySize`, `numTopics`, `minDF`, `ldaMaxIter`,
 * `logisticRegressionMaxIter`). Candidates are scored with a
 * `LogLossBinaryClassificationEvaluator` reading the label from
 * `"isDuplicateLabel"` and the probability from `"p"`.
 *
 * @return a cross-validator configured with the estimator, evaluator, parameter
 *         grid and number of folds
 */
private def assembleCrossValidator(): CrossValidator = {
  val stopwordsRemover = new StopWordsRemover()
  val countVectorizer = new CountVectorizer()
  val logisticRegression = new LogisticRegression()
  val lda = new LDA()

  // The stage instances are kept in local vals so the grid below can refer to
  // the params of exactly the stages the pipeline was given.
  val estimator = new QuoraQuestionsPairsPipeline()
    .setStopwordsRemover(stopwordsRemover)
    .setCountVectorizer(countVectorizer)
    .setLogisticRegression(logisticRegression)
    .setLDA(lda)

  // Grid search on hyperparameter space.
  // Each entry of `stopwords` is read as a file holding one candidate stop-word list.
  // Lines are joined with commas and re-split, so words may be separated by
  // newlines, commas, or both. The source is closed once the file has been read.
  val stopwordsLists = $(stopwords).map { path =>
    val source = Source.fromFile(path)
    try source.getLines().mkString(",").split(",")
    finally source.close()
  }

  val paramGrid = new ParamGridBuilder()
    .addGrid(stopwordsRemover.stopWords, stopwordsLists)
    .addGrid(countVectorizer.vocabSize, $(vocabularySize))
    .addGrid(lda.k, $(numTopics))
    .addGrid(countVectorizer.minDF, $(minDF))
    .addGrid(lda.maxIter, $(ldaMaxIter))
    .addGrid(logisticRegression.maxIter, $(logisticRegressionMaxIter))
    .build()

  val evaluator = new LogLossBinaryClassificationEvaluator()
    .setLabelCol("isDuplicateLabel")
    .setProbabilityCol("p")

  // Cross-validation setup.
  // NOTE(review): every other setting here is read through `$(...)`; confirm that
  // `numFolds` is a plain Int in scope rather than a Param that needs `$(numFolds)`.
  new CrossValidator()
    .setEstimator(estimator)
    .setEvaluator(evaluator)
    .setEstimatorParamMaps(paramGrid)
    .setNumFolds(numFolds)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment