@shenkev · Created December 4, 2016
Configuration and training file
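// Imports assumed for this snippet (an addition, not part of the original gist). Package locations
// follow the DL4J 0.7.x-era API that this code targets (iterations(), regularization(), Updater.ADAM,
// per-layer learningRate/adamMeanDecay); they may differ in newer releases.
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.util.Collections;
import java.util.List;
import java.util.Random;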
//Initialize the user interface backend
UIServer uiServer = UIServer.getInstance();
//Configure where the network information (gradients, score vs. time etc) is to be stored. Here: store in memory.
StatsStorage statsStorage = new InMemoryStatsStorage(); //Alternative: new FileStatsStorage(File), for saving and loading later
//Attach the StatsStorage instance to the UI: this allows the contents of the StatsStorage to be visualized
uiServer.attach(statsStorage);
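// Once the StatsStorage is attached, the training UI is typically served at http://localhost:9000
// (the port is configurable); this note is an addition to the original gist.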
// Load the offline training set: dat[0] holds the input features, dat[1] the regression targets
Object[] dat = offlineTraining.loadOfflineDat();
double[][] Xarr = (double[][]) dat[0];
double[][] yarr = (double[][]) dat[1];
INDArray X = Nd4j.create(Xarr);
INDArray y = Nd4j.create(yarr);
System.out.println(y); // quick look at the targets
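// Optional sanity check (an addition, not in the original gist): the feature matrix should have 13
// columns and the targets 6, matching the nIn/nOut of the network configured below.
if (X.columns() != 13 || y.columns() != 6 || X.rows() != y.rows()) {
    throw new IllegalStateException("Unexpected data shape: X=" + java.util.Arrays.toString(X.shape())
            + ", y=" + java.util.Arrays.toString(y.shape()));
}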
// Network Parameters
int rngSeed = 123; // random number seed for reproducibility
final Random rng = new Random(rngSeed);
OptimizationAlgorithm algo = OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT;
int iterations = 1; //Number of iterations per minibatch
String hiddenAct = "tanh";
String outAct = "identity";
Updater updater = Updater.ADAM;
// Learning Parameters
// double rate = 4.0;
double regularize = 0.00001; // L2 regularization coefficient
double dropOut = 0.0;        // 0.0 effectively disables dropout
int numEpochs = 400;
int batchSize = 18;
int printEvery = Xarr.length / batchSize; // minibatches per epoch, so the score is logged about once per epoch
// Dimensions
int features = 13;
int lay1 = 25;
int lay2 = 50;
int lay3 = 20;
int outs = 6;
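// Resulting architecture: 13 inputs -> 25 -> 50 -> 20 hidden units (tanh) -> 6 linear outputs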
// Wrap the full dataset, shuffle it once with the seeded RNG, and serve it in minibatches of batchSize
final DataSet allData = new DataSet(X, y);
final List<DataSet> list = allData.asList();
Collections.shuffle(list, rng);
DataSetIterator iterator = new ListDataSetIterator(list, batchSize);
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(rngSeed)
        .optimizationAlgo(algo)
        .iterations(iterations)
        .activation(hiddenAct)
        .weightInit(WeightInit.XAVIER_UNIFORM)
        .updater(updater)
        .regularization(true).l2(regularize).dropOut(dropOut)
        .list()
        // Three tanh hidden layers; each overrides the learning rate and sets the Adam decay rates
        .layer(0, new DenseLayer.Builder()
                .learningRate(0.01)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(features)
                .nOut(lay1)
                .build())
        .layer(1, new DenseLayer.Builder()
                .learningRate(0.005)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(lay1)
                .nOut(lay2)
                .build())
        .layer(2, new DenseLayer.Builder()
                .learningRate(0.005)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(lay2)
                .nOut(lay3)
                .build())
        // Linear output layer trained with an L2 (squared-error) loss for regression
        .layer(3, new OutputLayer.Builder()
                .activation(outAct)
                .lossFunction(LossFunctions.LossFunction.L2)
                .learningRate(0.01)
                .adamMeanDecay(0.9)
                .adamVarDecay(0.99)
                .nIn(lay3)
                .nOut(outs)
                .build())
        .pretrain(false).backprop(true)
        .build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
// Print the score every printEvery iterations and stream training stats to the UI.
// Note: setListeners replaces any previously set listeners, so calling it twice would drop the
// ScoreIterationListener; both listeners must be registered in a single call.
model.setListeners(new ScoreIterationListener(printEvery), new StatsListener(statsStorage));
System.out.println("Training model");
for (int i = 0; i < numEpochs; i++) {
    iterator.reset();
    model.fit(iterator);
}
// Predictions for the full training set (second argument false = inference mode)
System.out.println(model.output(X, false));
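// Possible follow-ups (additions, not part of the original gist): summarize the fit with DL4J's
// RegressionEvaluation (per-column MSE/MAE/R^2 over the 6 outputs) and persist the trained weights
// with ModelSerializer. Class and method names follow the 0.7.x-era API; the file name is made up.
// Extra imports assumed: org.deeplearning4j.eval.RegressionEvaluation, org.deeplearning4j.util.ModelSerializer
RegressionEvaluation eval = new RegressionEvaluation(outs);
eval.eval(y, model.output(X, false));
System.out.println(eval.stats());
try {
    ModelSerializer.writeModel(model, new java.io.File("offline-model.zip"), true);
} catch (java.io.IOException e) {
    throw new RuntimeException("Failed to save trained model", e);
}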