@zaleslaw
Created October 19, 2020 14:45
brute_force_parallelism
```java
// Cross-validation of a decision tree trainer (trainerCV) combined with a normalization
// preprocessing step. ignite, dataCache, split, trainerCV, normalizationTrainer and
// normalizationPreprocessor are assumed to be defined by surrounding code not shown here.
CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
    = new CrossValidation<>();

// Brute-force search: every combination of the values below is evaluated.
ParamGrid paramGrid = new ParamGrid()
    .withParameterSearchStrategy(new BruteForceStrategy())
    .addHyperParam("p", normalizationTrainer::withP,
        new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0})
    .addHyperParam("maxDeep", trainerCV::withMaxDeep,
        new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0})
    .addHyperParam("minImpurityDecrease", trainerCV::withMinImpurityDecrease,
        new Double[] {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0});

scoreCalculator
    .withIgnite(ignite)
    .withUpstreamCache(dataCache)
    .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder()
        // Set up the ParallelismStrategy: run learning tasks on the default pool.
        .withParallelismStrategyTypeDependency(ParallelismStrategy.ON_DEFAULT_POOL)
        .withLoggingFactoryDependency(ConsoleLogger.Factory.LOW))
    .withTrainer(trainerCV)
    .isRunningOnPipeline(false)
    .withMetric(MetricName.ACCURACY)
    .withFilter(split.getTrainFilter())
    .withPreprocessor(normalizationPreprocessor)
    .withAmountOfFolds(3)
    .withParamGrid(paramGrid);
```
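Brute-force search over this grid evaluates every combination. That is 10 values of `p` × 10 of `maxDeep` × 11 of `minImpurityDecrease` = 1,100 candidates, and with 3 folds it comes to 3,300 trainer runs. This workload is the reason a parallelism strategy matters here. The sketch below is plain Java and is not Ignite's `BruteForceStrategy` or `ParallelismStrategy`. It only illustrates the idea: expand the grid into its Cartesian product, then score the independent candidates concurrently on a shared pool. The class `BruteForceGridSketch` and the toy score function are hypothetical stand-ins; the score replaces cross-validated accuracy. The sketch uses a parallel stream, which runs on the common `ForkJoinPool`.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.ToDoubleFunction;

/** Hypothetical plain-Java sketch of brute-force grid search with parallel scoring (not Ignite code). */
public class BruteForceGridSketch {

    /** Expands named value lists into every combination (the Cartesian product). */
    public static List<Map<String, Double>> cartesianProduct(Map<String, Double[]> grid) {
        List<Map<String, Double>> combos = new ArrayList<>();
        combos.add(new LinkedHashMap<>()); // start from one empty combination
        for (Map.Entry<String, Double[]> param : grid.entrySet()) {
            List<Map<String, Double>> next = new ArrayList<>();
            for (Map<String, Double> partial : combos) {
                for (Double value : param.getValue()) {
                    Map<String, Double> extended = new LinkedHashMap<>(partial);
                    extended.put(param.getKey(), value);
                    next.add(extended);
                }
            }
            combos = next; // each parameter multiplies the number of combinations
        }
        return combos;
    }

    /** Scores every combination via a parallel stream (common ForkJoinPool) and returns the best one. */
    public static Map<String, Double> bestCombination(Map<String, Double[]> grid,
                                                      ToDoubleFunction<Map<String, Double>> score) {
        return cartesianProduct(grid).parallelStream()
            .max(Comparator.comparingDouble(score))
            .orElseThrow();
    }

    public static void main(String[] args) {
        // Same value lists as the ParamGrid in the gist.
        Map<String, Double[]> grid = new LinkedHashMap<>();
        grid.put("p", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0});
        grid.put("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0});
        grid.put("minImpurityDecrease",
            new Double[] {0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0});

        int folds = 3;
        int combos = cartesianProduct(grid).size();
        System.out.println("Combinations: " + combos);         // 10 * 10 * 11 = 1100
        System.out.println("Trainer runs: " + combos * folds); // 1100 * 3 = 3300

        // Hypothetical toy score standing in for cross-validated accuracy.
        Map<String, Double> best = bestCombination(grid, c ->
            -Math.abs(c.get("p") - 2.0) - Math.abs(c.get("maxDeep") - 5.0) - c.get("minImpurityDecrease"));
        System.out.println("Best (toy score): " + best);
    }
}
```

Because each candidate's score is independent of the others, the search is embarrassingly parallel. The cost grows multiplicatively with every hyperparameter added, so adding one more 10-value parameter here would mean 11,000 candidates.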