Created
December 4, 2016 22:48
-
-
Save shenkev/2d95e8aeb313dc9289ebc63a19fce1cf to your computer and use it in GitHub Desktop.
Deeplearning4j network configuration and training script (feed-forward regression network with UI monitoring)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ---------------------------------------------------------------------------
// DL4J training script: attaches the training UI, loads an offline dataset,
// builds a 4-layer feed-forward regression network (tanh hidden layers,
// identity output with L2 loss), and trains it for a fixed number of epochs.
// ---------------------------------------------------------------------------

// Initialize the user interface backend.
UIServer uiServer = UIServer.getInstance();

// Configure where the network information (gradients, score vs. time etc.) is
// to be stored. Here: in memory.
StatsStorage statsStorage = new InMemoryStatsStorage(); // Alternative: new FileStatsStorage(File), for saving and loading later

// Attach the StatsStorage instance to the UI so its contents can be visualized.
uiServer.attach(statsStorage);

// Load the offline training data. dat[0] is the feature matrix, dat[1] the
// label matrix (both double[][] — see offlineTraining.loadOfflineDat).
Object[] dat = offlineTraining.loadOfflineDat();
double[][] Xarr = (double[][]) dat[0];
double[][] yarr = (double[][]) dat[1];
INDArray X = Nd4j.create(Xarr);
INDArray y = Nd4j.create(yarr);
System.out.println(y);

// Network parameters.
int rngSeed = 123; // random number seed for reproducibility
final Random rng = new Random(rngSeed);
OptimizationAlgorithm algo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
int iterations = 1; // number of iterations per minibatch
String hiddenAct = "tanh";
String outAct = "identity";
Updater updater = Updater.ADAM;

// Learning parameters.
// double rate = 4.0;
double regularize = 0.00001;
double dropOut = 0.0;
int numEpochs = 400;
int batchSize = 18;
// Log the score once per epoch (one iteration per minibatch, so this many
// iterations make up one pass over the data).
int printEvery = Xarr.length / batchSize;

// Layer dimensions.
int features = 13;
int lay1 = 25;
int lay2 = 50;
int lay3 = 20;
int outs = 6;

// Wrap the full dataset, shuffle it reproducibly, and iterate in minibatches.
final DataSet allData = new DataSet(X, y);
final List<DataSet> list = allData.asList();
Collections.shuffle(list, rng);
// NOTE: generic type added — the raw ListDataSetIterator caused an unchecked warning.
DataSetIterator iterator = new ListDataSetIterator<>(list, batchSize);

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(rngSeed)
        .optimizationAlgo(algo)
        .iterations(iterations)
        .activation(hiddenAct)
        .weightInit(WeightInit.XAVIER_UNIFORM)
        .updater(updater)
        .regularization(true).l2(regularize).dropOut(dropOut)
        .list()
        .layer(0, new DenseLayer.Builder()
                .learningRate(0.01)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(features)
                .nOut(lay1)
                .build())
        .layer(1, new DenseLayer.Builder()
                .learningRate(0.005)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(lay1)
                .nOut(lay2)
                .build())
        .layer(2, new DenseLayer.Builder()
                .learningRate(0.005)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(lay2)
                .nOut(lay3)
                .build())
        .layer(3, new OutputLayer.Builder()
                .activation(outAct)
                .lossFunction(LossFunctions.LossFunction.L2)
                .learningRate(0.01)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(lay3)
                .nOut(outs)
                .build())
        .pretrain(false).backprop(true)
        .build();

MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();

// BUG FIX: setListeners REPLACES the current listener set, so the original
// two consecutive calls discarded the ScoreIterationListener. Register both
// the console score listener and the UI stats listener in a single call.
model.setListeners(new ScoreIterationListener(printEvery), new StatsListener(statsStorage));

System.out.println("Training model");
for (int i = 0; i < numEpochs; i++) {
    iterator.reset();
    model.fit(iterator);
}

// Print the trained network's predictions on the full training set.
System.out.println(model.output(X, false));
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment