Skip to content

Instantly share code, notes, and snippets.

@sid-sahani
Created September 29, 2016 10:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save sid-sahani/c62cca0db02d491d9e16a33bb5dfbe4c to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.convolution;
import com.google.common.collect.Lists;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;

import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
 * Trains a small convolutional network for binary classification on CSV data.
 *
 * <p>The training set is read in one full-size batch so it can be shuffled
 * in memory once per epoch; the test set is evaluated after every epoch.
 * Input CSV rows are interpreted by {@code ConvolutionLayerSetup} as
 * single-channel 199x300 "images".
 */
public class Imp {

    public static void main(String[] args) throws Exception {
        int nChannels = 1;   // input depth (single channel)
        int outputNum = 2;   // number of output classes (binary classification)
        int batchSize = 20;  // mini-batch size used for training and evaluation
        int nEpochs = 100;
        int iterations = 1;  // gradient updates per mini-batch
        int seed = 123;      // fixed RNG seed for reproducible weight init
        int fullSize = 408;  // total training examples; loaded as one batch so we can shuffle per epoch

        System.out.println("Load data....");
        // CSV layout: label in column 0, followed by features; 2 classes.
        RecordReader rr = new CSVRecordReader();
        rr.initialize(new FileSplit(new File("src\\main\\resources\\classification\\trainableData.csv")));
        DataSetIterator lsTrain = new RecordReaderDataSetIterator(rr, fullSize, 0, 2);

        RecordReader rrTest = new CSVRecordReader();
        rrTest.initialize(new FileSplit(new File("src\\main\\resources\\classification\\testableData.csv")));
        DataSetIterator lsTest = new RecordReaderDataSetIterator(rrTest, batchSize, 0, 2);

        System.out.println("Build model....");
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
            .seed(seed)
            .iterations(iterations)
            .regularization(true).l2(0.0005)
            .learningRate(0.01)
            .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
            .weightInit(WeightInit.XAVIER)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.NESTEROVS).momentum(0.9)
            .list()
            // Convolution over 4x300 windows: spans the full row width, slides over rows.
            .layer(0, new ConvolutionLayer.Builder(4, 300)
                // nIn/nOut specify depth: nIn is nChannels, nOut the number of filters.
                .nIn(nChannels)
                .padding(1, 1)
                .stride(1, 1)
                .nOut(100)
                .activation("relu")
                .build())
            .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(1, 1)
                .build())
            .layer(2, new DenseLayer.Builder().activation("relu")
                .nOut(500).dropOut(0.5).build())
            .layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation("softmax")
                .build())
            .backprop(true).pretrain(false);
        // Computes nIn for each layer given the 199x300x1 input shape.
        new ConvolutionLayerSetup(builder, 199, 300, 1);

        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        System.out.println("Train model....");
        model.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            System.out.println("\nTrying epoch " + (i + 1));

            // Materialize the full training set, shuffle it, and re-batch it for this epoch.
            List<DataSet> epochData = Lists.newArrayList(lsTrain);
            Collections.shuffle(epochData);
            DataSetIterator shuffledTrainIterator = new ListDataSetIterator(epochData, batchSize);

            // fit(DataSetIterator) consumes the whole iterator internally, so a
            // single call trains over every mini-batch; wrapping it in a
            // while(hasNext()) loop (as before) was redundant.
            model.fit(shuffledTrainIterator);

            // Original used an SLF4J-style "{}" placeholder with println, which
            // printed the braces literally; plain concatenation is correct here.
            System.out.println("*** Completed epoch " + (i + 1) + " ***");

            System.out.println("Evaluate model....");
            Evaluation eval = new Evaluation(outputNum);
            while (lsTest.hasNext()) {
                DataSet ds = lsTest.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);
            }
            System.out.println(eval.stats());
            lsTest.reset(); // rewind so the next epoch can evaluate again
        }
        System.out.println("****************Example finished********************");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment