Skip to content

Instantly share code, notes, and snippets.

@yptheangel
Created August 2, 2019 09:18
Show Gist options
  • Save yptheangel/9baff7a2f74d8d6844e79b58d6cd4ca7 to your computer and use it in GitHub Desktop.
package global.skymind.solution.humanactivity;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.AdamSpace;
import org.deeplearning4j.arbiter.layers.LSTMLayerSpace;
import org.deeplearning4j.arbiter.layers.RnnOutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.List;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
/**
 * Hyperparameter tuning example for a human-activity-recognition LSTM,
 * using DL4J Arbiter's random search over learning rate and layer size.
 * Candidates are scored by F1 on the test set; results are saved under
 * arbiterExample/ and visualised at http://localhost:9000/arbiter.
 */
public class HyperParamTuning {

    /** Fixed RNG seed so each candidate network is initialised reproducibly. */
    private static final long seed = 12345;
    /** Number of activity classes in the human-activity dataset. */
    private static final int numClasses = 6;
    /** The CSV sequence files have no header rows to skip. */
    private static final int numSkipLines = 0;
    /** Epochs trained per candidate configuration. */
    private static final int epochs = 2;

    public static void main(String[] args) throws Exception {
        // Hyperparameter search spaces. (The original comments claimed ranges
        // of 0.0001-0.1 and 16-256, contradicting the code; the values below
        // are the actual, corrected ranges.)
        ParameterSpace<Double> learningRateHyperparam =
                new ContinuousParameterSpace(0.001, 0.1); // learning rate sampled uniformly in [0.001, 0.1]
        ParameterSpace<Integer> layerSizeHyperparam =
                new IntegerParameterSpace(5, 200);        // LSTM layer size sampled uniformly in [5, 200]

        // Single-LSTM-layer sequence classifier; the LSTM width and the Adam
        // learning rate are left open for the optimizer to choose.
        ComputationGraphSpace hyperparameterSpace = new ComputationGraphSpace.Builder()
                .seed(seed)
                .weightInit(WeightInit.XAVIER)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .updater(new AdamSpace(learningRateHyperparam))
                .addInputs("trainFeatures")
                .setOutputs("predictClass")
                .addLayer("layer0", new LSTMLayerSpace.Builder()
                                .nIn(9) // 9 sensor channels per time step
                                .nOut(layerSizeHyperparam)
                                .activation(Activation.TANH)
                                .build(),
                        "trainFeatures")
                .addLayer("predictClass", new RnnOutputLayerSpace.Builder()
                                .nIn(layerSizeHyperparam)
                                .nOut(numClasses)
                                .lossFunction(LossFunctions.LossFunction.MCXENT)
                                .activation(Activation.SOFTMAX)
                                .build(),
                        "layer0")
                .numEpochs(epochs)
                .build();

        // Random search over the space. Alternatively:
        // new GridSearchCandidateGenerator(hyperparameterSpace, 5, GridSearchCandidateGenerator.Mode.RandomOrder)
        CandidateGenerator candidateGenerator = new RandomSearchGenerator(hyperparameterSpace, null);

        // Arbiter instantiates the data source reflectively and configures it
        // from these properties.
        Class<? extends DataSource> dataSourceClass = ExampleDataSource.class;
        Properties dataSourceProperties = new Properties();
        dataSourceProperties.setProperty("minibatchSize", "64");

        // Candidate models are saved to arbiterExample/0/, arbiterExample/1/, ...
        String baseSaveDirectory = "arbiterExample/";
        File baseSaveDir = new File(baseSaveDirectory);
        // BUG FIX: the original called f.delete() on an existing directory and
        // ignored the result; File.delete() fails silently on a non-empty
        // directory, and the subsequent mkdir() result was also ignored.
        // Create the directory (including parents) and fail fast if we cannot.
        if (!baseSaveDir.exists() && !baseSaveDir.mkdirs()) {
            throw new IllegalStateException(
                    "Could not create model save directory: " + baseSaveDir.getAbsolutePath());
        }
        ResultSaver modelSaver = new FileModelSaver(baseSaveDirectory);

        // Score candidates by F1 on the test iterator (higher is better).
        // ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.ACCURACY);
        ScoreFunction scoreFunction = new EvaluationScoreFunction(Evaluation.Metric.F1);

        // Stop after 15 minutes of wall-clock time or 20 candidates, whichever comes first.
        TerminationCondition[] terminationConditions = {
                new MaxTimeCondition(15, TimeUnit.MINUTES),
                new MaxCandidatesCondition(20)};

        // Put all the configuration options together.
        OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
                .candidateGenerator(candidateGenerator)
                .dataSource(dataSourceClass, dataSourceProperties)
                .modelSaver(modelSaver)
                .scoreFunction(scoreFunction)
                .terminationConditions(terminationConditions)
                .build();

        // Set up execution locally on this machine.
        IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator());

        // UI stats file in the system temp dir; remove any stale copy first.
        File statsFile = new File(System.getProperty("java.io.tmpdir"), "arbiterExampleUiStats.dl4j");
        if (statsFile.exists() && !statsFile.delete()) {
            System.out.println("Warning: could not delete stale stats file: " + statsFile);
        }
        System.out.println(statsFile);

        // Start the UI. Arbiter uses the same storage and persistence approach
        // as DL4J's UI. Access at http://localhost:9000/arbiter
        StatsStorage ss = new FileStatsStorage(statsFile); // reuse statsFile instead of rebuilding the same path
        runner.addListeners(new ArbiterStatusListener(ss));
        UIServer.getInstance().attach(ss);

        // Start the hyperparameter optimization (blocks until termination).
        runner.execute();

        // Print out some basic stats regarding the optimization procedure.
        String s = "Best score: " + runner.bestScore() + "\n" +
                "Index of model with best score: " + runner.bestScoreCandidateIndex() + "\n" +
                "Number of configurations evaluated: " + runner.numCandidatesCompleted() + "\n";
        System.out.println(s);

        // Get all results, and print out details of the best result.
        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference> allResults = runner.getResults();
        OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
        ComputationGraph bestModel = (ComputationGraph) bestResult.getResultReference().getResultModel();
        System.out.println("\n\nConfiguration of best model:\n");
        System.out.println(bestModel.getConfiguration().toJson());

        // Keep the UI reachable for a minute before shutting down.
        Thread.sleep(60000);
        UIServer.getInstance().stop();
    }

    /**
     * Arbiter {@link DataSource} that loads the human-activity dataset from
     * numbered CSV sequence files under ~/.deeplearning4j/data/humanactivity/.
     * Instantiated reflectively by Arbiter, so it must keep a public no-arg
     * constructor followed by a {@link #configure(Properties)} call.
     */
    public static class ExampleDataSource implements DataSource {
        // Minibatch size; set via configure(Properties), defaulting to 64.
        private int minibatchSize;

        public ExampleDataSource() {
            // No-arg constructor required for reflective instantiation.
        }

        @Override
        public void configure(Properties properties) {
            this.minibatchSize = Integer.parseInt(properties.getProperty("minibatchSize", "64"));
        }

        /**
         * Builds a sequence iterator over {subset}/features/%d.csv and
         * {subset}/labels/%d.csv, for file indices 0..maxFileIdx inclusive.
         * Extracted to remove the near-duplicate trainData()/testData() bodies.
         */
        private DataSetIterator makeIterator(String subset, int maxFileIdx) {
            try {
                File baseDir = new File(System.getProperty("user.home"),
                        ".deeplearning4j/data/humanactivity/" + subset + "/");
                File featuresDir = new File(baseDir, "features");
                File labelsDir = new File(baseDir, "labels");
                SequenceRecordReader features = new CSVSequenceRecordReader(numSkipLines, ",");
                features.initialize(new NumberedFileInputSplit(
                        featuresDir.getAbsolutePath() + "/%d.csv", 0, maxFileIdx));
                SequenceRecordReader labels = new CSVSequenceRecordReader(numSkipLines, ",");
                labels.initialize(new NumberedFileInputSplit(
                        labelsDir.getAbsolutePath() + "/%d.csv", 0, maxFileIdx));
                // ALIGN_END: label sequences are aligned to the end of each feature sequence.
                return new SequenceRecordReaderDataSetIterator(features, labels, minibatchSize, numClasses,
                        false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
            } catch (Exception e) {
                // DataSource methods cannot declare checked exceptions; wrap
                // with context while preserving the cause.
                throw new RuntimeException("Failed to build '" + subset + "' data iterator", e);
            }
        }

        @Override
        public Object trainData() {
            return makeIterator("train", 7351); // files 0.csv .. 7351.csv
        }

        @Override
        public Object testData() {
            return makeIterator("test", 2946); // files 0.csv .. 2946.csv
        }

        @Override
        public Class<?> getDataType() {
            return DataSetIterator.class;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment