Created
August 2, 2019 09:18
-
-
Save yptheangel/9baff7a2f74d8d6844e79b58d6cd4ca7 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package global.skymind.solution.humanactivity; | |
import org.datavec.api.records.reader.SequenceRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; | |
import org.datavec.api.split.NumberedFileInputSplit; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.arbiter.ComputationGraphSpace; | |
import org.deeplearning4j.arbiter.conf.updater.AdamSpace; | |
import org.deeplearning4j.arbiter.layers.LSTMLayerSpace; | |
import org.deeplearning4j.arbiter.layers.RnnOutputLayerSpace; | |
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator; | |
import org.deeplearning4j.arbiter.optimize.api.data.DataSource; | |
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult; | |
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace; | |
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference; | |
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver; | |
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction; | |
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition; | |
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition; | |
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition; | |
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration; | |
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator; | |
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace; | |
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace; | |
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner; | |
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner; | |
import org.deeplearning4j.arbiter.saver.local.FileModelSaver; | |
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction; | |
import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator; | |
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener; | |
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.storage.FileStatsStorage; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import java.io.File; | |
import java.util.List; | |
import java.util.Properties; | |
import java.util.concurrent.TimeUnit; | |
public class HyperParamTuning { | |
private static long seed = 12345; | |
private static final int numClasses =6; | |
private static final int numSkipLines=0; | |
private static int epochs =2; | |
public static void main(String[] args) throws Exception { | |
ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.001, 0.1); //Values will be generated uniformly at random between 0.0001 and 0.1 (inclusive) | |
ParameterSpace<Integer> layerSizeHyperparam = new IntegerParameterSpace(5, 200); //Integer values will be generated uniformly at random between 16 and 256 (inclusive) | |
ComputationGraphSpace hyperparameterSpace = new ComputationGraphSpace.Builder() | |
.seed(seed) | |
.weightInit(WeightInit.XAVIER) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.updater(new AdamSpace(learningRateHyperparam)) | |
.addInputs("trainFeatures") | |
.setOutputs("predictClass") | |
.addLayer("layer0", new LSTMLayerSpace.Builder() | |
.nIn(9) | |
.nOut(layerSizeHyperparam) | |
.activation(Activation.TANH) | |
.build(), | |
"trainFeatures") | |
.addLayer("predictClass", new RnnOutputLayerSpace.Builder() | |
.nIn(layerSizeHyperparam) | |
.nOut(numClasses) | |
.lossFunction(LossFunctions.LossFunction.MCXENT) | |
.activation(Activation.SOFTMAX) | |
.build(), | |
"layer0") | |
.numEpochs(epochs) | |
.build(); | |
CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperparameterSpace, null); //Alternatively: new GridSearchCandidateGenerator<>(hyperparameterSpace, 5, GridSearchCandidateGenerator.Mode.RandomOrder); | |
Class<? extends DataSource> dataSourceClass = ExampleDataSource.class; | |
Properties dataSourceProperties = new Properties(); | |
dataSourceProperties.setProperty("minibatchSize", "64"); | |
// (c) How we are going to save the models that are generated and tested? | |
// In this example, let's save them to disk the working directory | |
// This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ... | |
String baseSaveDirectory = "arbiterExample/"; | |
File f = new File(baseSaveDirectory); | |
if (f.exists()) f.delete(); | |
f.mkdir(); | |
ResultSaver modelSaver = new FileModelSaver(baseSaveDirectory); | |
// ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY); | |
ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1); | |
TerminationCondition[] terminationConditions = { | |
new MaxTimeCondition(15, TimeUnit.MINUTES), | |
new MaxCandidatesCondition(20)}; | |
//Given these configuration options, let's put them all together: | |
OptimizationConfiguration configuration = new OptimizationConfiguration.Builder() | |
.candidateGenerator(candidateGenerator) | |
.dataSource(dataSourceClass,dataSourceProperties) | |
.modelSaver(modelSaver) | |
.scoreFunction(scoreFunction) | |
.terminationConditions(terminationConditions) | |
.build(); | |
//And set up execution locally on this machine: | |
IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator()); | |
File tmpFile = new File(System.getProperty("java.io.tmpdir"), "arbiterExampleUiStats.dl4j"); | |
if (tmpFile.exists()) tmpFile.delete(); | |
System.out.println(tmpFile); | |
//Start the UI. Arbiter uses the same storage and persistence approach as DL4J's UI | |
//Access at http://localhost:9000/arbiter | |
StatsStorage ss = new FileStatsStorage(new File(System.getProperty("java.io.tmpdir"), "arbiterExampleUiStats.dl4j")); | |
runner.addListeners(new ArbiterStatusListener(ss)); | |
UIServer.getInstance().attach(ss); | |
//Start the hyperparameter optimization | |
runner.execute(); | |
//Print out some basic stats regarding the optimization procedure | |
String s = "Best score: " + runner.bestScore() + "\n" + | |
"Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" + | |
"Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n"; | |
System.out.println(s); | |
//Get all results, and print out details of the best result: | |
int indexOfBestResult = runner.bestScoreCandidateIndex(); | |
List<ResultReference> allResults = runner.getResults(); | |
OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult(); | |
ComputationGraph bestModel = (ComputationGraph) bestResult.getResultReference().getResultModel(); | |
System.out.println("\n\nConfiguration of best model:\n"); | |
System.out.println(bestModel.getConfiguration().toJson()); | |
//Wait a while before exiting | |
Thread.sleep(60000); | |
UIServer.getInstance().stop(); | |
} | |
// | |
public static class ExampleDataSource implements DataSource { | |
private int minibatchSize; | |
// | |
public ExampleDataSource() { | |
} | |
@Override | |
public void configure(Properties properties) { | |
this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "64")); | |
} | |
@Override | |
public Object trainData() { | |
try { | |
File trainBaseDir = new File(System.getProperty("user.home"), ".deeplearning4j/data/humanactivity/train/"); | |
File trainFeaturesDir = new File(trainBaseDir,"features"); | |
File trainLabelsDir = new File(trainBaseDir,"labels"); | |
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(numSkipLines,","); | |
trainFeatures.initialize(new NumberedFileInputSplit(trainFeaturesDir.getAbsolutePath()+"/%d.csv",0,7351)); | |
SequenceRecordReader trainLabels = new CSVSequenceRecordReader(numSkipLines,","); | |
trainLabels.initialize(new NumberedFileInputSplit(trainLabelsDir.getAbsolutePath()+"/%d.csv",0,7351)); | |
DataSetIterator train = new SequenceRecordReaderDataSetIterator(trainFeatures,trainLabels,minibatchSize,numClasses,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
return train; | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
} | |
@Override | |
public Object testData() { | |
try { | |
File testBaseDir = new File(System.getProperty("user.home"), ".deeplearning4j/data/humanactivity/test/"); | |
File testFeaturesDir = new File(testBaseDir,"features"); | |
File testLabelsDir = new File(testBaseDir,"labels"); | |
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(numSkipLines,","); | |
testFeatures.initialize(new NumberedFileInputSplit(testFeaturesDir.getAbsolutePath()+"/%d.csv",0,2946)); | |
SequenceRecordReader testLabels = new CSVSequenceRecordReader(numSkipLines,","); | |
testLabels.initialize(new NumberedFileInputSplit(testLabelsDir.getAbsolutePath()+"/%d.csv",0,2946)); | |
DataSetIterator test = new SequenceRecordReaderDataSetIterator(testFeatures,testLabels,minibatchSize,numClasses,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
return test; | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
} | |
@Override | |
public Class<?> getDataType() { | |
return DataSetIterator.class; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment