package kaggle;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.arbiter.DL4JConfiguration;
import org.deeplearning4j.arbiter.MultiLayerSpace;
import org.deeplearning4j.arbiter.data.DataSetIteratorProvider;
import org.deeplearning4j.arbiter.layers.DenseLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.CandidateGenerator;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataProvider;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultSaver;
import org.deeplearning4j.arbiter.optimize.api.score.ScoreFunction;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.candidategenerator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.listener.runner.LoggingOptimizationRunnerStatusListener;
import org.deeplearning4j.arbiter.optimize.ui.ArbiterUIServer;
import org.deeplearning4j.arbiter.optimize.ui.listener.UIOptimizationRunnerStatusListener;
import org.deeplearning4j.arbiter.saver.local.multilayer.LocalMultiLayerNetworkSaver;
import org.deeplearning4j.arbiter.scoring.multilayer.TestSetAccuracyScoreFunction;
import org.deeplearning4j.arbiter.task.MultiLayerNetworkTaskCreator;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.List;
import java.util.concurrent.TimeUnit;
/**
 * Created by danlin on 2016-07-15.
 */
public class TitanicOptimization {

    /**
     * This is a basic hyperparameter optimization example that uses Arbiter to conduct a random search
     * over one network hyperparameter: the learning rate. The search is run for a simple multi-layer
     * perceptron on a preprocessed Titanic CSV data set. Adapted from Arbiter's basic hyperparameter
     * optimization example on MNIST; the layer-size search from that example is commented out below.
     *
     * Note that this example has a UI, but it (currently) does not start automatically.
     * By default, the UI is accessible at http://localhost:8080/arbiter
     *
     * @author Alex Black (original Arbiter example)
     */
    public static void main(String[] args) throws Exception {
        //First: Set up the hyperparameter configuration space. This is like a MultiLayerConfiguration, but can have either
        // fixed values or values to optimize, for each hyperparameter

        ParameterSpace<Double> learningRateHyperparam = new ContinuousParameterSpace(0.01, 0.1);   //Values will be generated uniformly at random between 0.01 and 0.1 (inclusive)
//        ParameterSpace<Integer> layerSizeHyperparam = new IntegerParameterSpace(2, 12);          //Integer values would be generated uniformly at random between 2 and 12 (inclusive)
//        ParameterSpace<Integer> layerSizeHyperparam2 = new IntegerParameterSpace(2, 12);         //Integer values would be generated uniformly at random between 2 and 12 (inclusive)
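        // If the layer sizes were to be searched as well, the IntegerParameterSpace objects above could
        // be passed to the DenseLayerSpace builders below (e.g. .nOut(layerSizeHyperparam)) in place of
        // the fixed sizes used here.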
        // Unused leftover constants from the example this was adapted from; kept as-is, nothing below reads them
        final int numRows = 4;
        final int numColumns = 1;
        int outputNum = 3;
        int numSamples = 150;
        int iterations = 100;
        int seed = 123;
        int listenerFreq = 1;
        MultiLayerSpace hyperparameterSpace = new MultiLayerSpace.Builder()
            //These next few options: fixed values for all models
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .regularization(true)
            .l2(1e-4)
            .weightInit(WeightInit.XAVIER)
            .activation("relu")
            //Learning rate: this is something we want to test different values for
            .learningRate(learningRateHyperparam)
            .addLayer(new DenseLayerSpace.Builder()
                //Fixed values for this layer:
                .nIn(12)    //Fixed input: 12 feature columns in the preprocessed Titanic CSV
                .nOut(8)
                .build())
            .addLayer(new DenseLayerSpace.Builder()
                //Fixed values for this layer:
                .nIn(8)     //Must match the nOut of the previous layer
                .nOut(4)
                .build())
            .addLayer(new OutputLayerSpace.Builder()
                //nIn: set to the same value as the nOut of the last hidden layer
                .nIn(4)
                //The remaining hyperparameters: fixed for the output layer
                .nOut(2)    //Two classes: survived / did not survive
                .activation("softmax")
                .lossFunction(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .build())
            .pretrain(false).backprop(true).build();
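        // The search space above describes a fixed 12 -> 8 -> 4 -> 2 feed-forward architecture;
        // only the learning rate varies between candidate configurations.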
        //Now: We need to define a few configuration options
        // (a) How are we going to generate candidates? (random search or grid search)
        CandidateGenerator<DL4JConfiguration> candidateGenerator = new RandomSearchGenerator<>(hyperparameterSpace);    //Alternatively: new GridSearchCandidateGenerator<>(hyperparameterSpace, 5, GridSearchCandidateGenerator.Mode.RandomOrder);

        // (b) How are we going to provide data? For now, we'll use a simple built-in data provider for DataSetIterators
        RecordReader recordReader = new CSVRecordReader(0, ",");
        recordReader.initialize(new FileSplit(new File("/Users/danlin/git/dl4j-lab/dl4j/src/main/resources/updated/train0.csv")));

        //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in the neural network
        int labelIndex = 0;     //The class label is the first value (index 0) in each row; the remaining 12 values are input features
        int numClasses = 2;     //2 classes (survived or not) in the Titanic data set, with integer values 0 or 1
        int batchSize = 891;    //Titanic training set: 891 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);   //Use 65% of data for training

        DataSetIterator titanicTrain = new MultipleEpochsIterator(5, new ListDataSetIterator(testAndTrain.getTrain().asList()));
        DataSetIterator titanicTest = new ListDataSetIterator(testAndTrain.getTest().asList());
        DataProvider<DataSetIterator> dataProvider = new DataSetIteratorProvider(titanicTrain, titanicTest);
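        // MultipleEpochsIterator wraps the training split so that each candidate network is fit
        // with 5 passes over the training data.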
        // (c) How are we going to save the models that are generated and tested?
        //     In this example, let's save them to disk in the working directory
        //     This will result in examples being saved to arbiterExample/0/, arbiterExample/1/, arbiterExample/2/, ...
        String baseSaveDirectory = "arbiterExample/";
        File f = new File(baseSaveDirectory);
        if (f.exists()) f.delete();
        f.mkdir();
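        // Note: File.delete() does not remove non-empty directories, so results from a previous
        // run may still be present under arbiterExample/.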
        ResultSaver<DL4JConfiguration, MultiLayerNetwork, Object> modelSaver = new LocalMultiLayerNetworkSaver<>(baseSaveDirectory);
        // (d) What are we actually trying to optimize?
        //     In this example, let's use classification accuracy on the test set
        ScoreFunction<MultiLayerNetwork, DataSetIterator> scoreFunction = new TestSetAccuracyScoreFunction();

        // (e) When should we stop searching? Specify this with termination conditions
        //     For this example, we are stopping the search at 15 minutes or 30 candidates - whichever comes first
        TerminationCondition[] terminationConditions = {new MaxTimeCondition(15, TimeUnit.MINUTES), new MaxCandidatesCondition(30)};
        //Given these configuration options, let's put them all together:
        OptimizationConfiguration<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, Object> configuration
            = new OptimizationConfiguration.Builder<DL4JConfiguration, MultiLayerNetwork, DataSetIterator, Object>()
                .candidateGenerator(candidateGenerator)
                .dataProvider(dataProvider)
                .modelSaver(modelSaver)
                .scoreFunction(scoreFunction)
                .terminationConditions(terminationConditions)
                .build();

        //And set up execution locally on this machine:
        IOptimizationRunner<DL4JConfiguration, MultiLayerNetwork, Object> runner
            = new LocalOptimizationRunner<>(configuration, new MultiLayerNetworkTaskCreator<>());

        //Start the UI
        ArbiterUIServer server = ArbiterUIServer.getInstance();
        runner.addListeners(new UIOptimizationRunnerStatusListener(server));
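        // A LoggingOptimizationRunnerStatusListener (imported above) could also be added here
        // to report search progress to the console alongside the UI.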
        //Start the hyperparameter optimization
        runner.execute();
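        // execute() runs candidate networks until one of the termination conditions above
        // (15 minutes or 30 candidates) is met.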
        //Print out some basic stats regarding the optimization procedure
        StringBuilder sb = new StringBuilder();
        sb.append("Best score: ").append(runner.bestScore()).append("\n")
          .append("Index of model with best score: ").append(runner.bestScoreCandidateIndex()).append("\n")
          .append("Number of configurations evaluated: ").append(runner.numCandidatesCompleted()).append("\n");
        System.out.println(sb.toString());

        //Get all results, and print out details of the best result:
        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference<DL4JConfiguration, MultiLayerNetwork, Object>> allResults = runner.getResults();

        OptimizationResult<DL4JConfiguration, MultiLayerNetwork, Object> bestResult = allResults.get(indexOfBestResult).getResult();
        MultiLayerNetwork bestModel = bestResult.getResult();
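        // The saved copy of this model should also be on disk under arbiterExample/<indexOfBestResult>/,
        // per the model saver configured above.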
System.out.println("\n\nConfiguration of best model:\n"); | |
System.out.println(bestModel.getLayerWiseConfigurations().toJson()); | |
//Note: UI server will shut down once execution is complete, as JVM will exit | |
//So do a Thread.sleep(1 minute) to keep JVM alive, so that network configurations can be viewed | |
Thread.sleep(60000); | |
System.exit(0); | |
} | |
} |