-
-
Save sid-sahani/c62cca0db02d491d9e16a33bb5dfbe4c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.convolution;

import com.google.common.collect.Lists;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.LearningRatePolicy;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
public class Imp { | |
public static void main(String[] args) throws Exception { | |
//DLParamsSpecifier dlParamSpecifier = new DLParamsSpecifier(); | |
int nChannels = 1; | |
int outputNum = 2; | |
int batchSize = 20; //64 | |
int nEpochs = 100; | |
int iterations = 1; | |
int seed = 123; | |
int fullSize=408; | |
System.out.println("Load data...."); | |
RecordReader rr = new CSVRecordReader(); | |
rr.initialize(new FileSplit(new File("src\\main\\resources\\classification\\trainableData.csv"))); | |
DataSetIterator LSTrain = new RecordReaderDataSetIterator(rr,fullSize,0,2); | |
RecordReader rrTest = new CSVRecordReader(); | |
rrTest.initialize(new FileSplit(new File("src\\main\\resources\\classification\\testableData.csv"))); | |
DataSetIterator LSTest = new RecordReaderDataSetIterator(rrTest,batchSize,0,2); | |
System.out.println("Build model...."); | |
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.iterations(iterations) | |
.regularization(true).l2(0.0005) | |
.learningRate(0.01)//.biasLearningRate(0.02) | |
.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) | |
.weightInit(WeightInit.XAVIER) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.updater(Updater.NESTEROVS).momentum(0.9) | |
.list() | |
.layer(0, new ConvolutionLayer.Builder(4,300) | |
//nIn and nOut specify depth. nIn here is the nChannels and nOut is the number of filters to be applied | |
.nIn(nChannels) | |
.padding(1,1) | |
.stride(1,1) | |
.nOut(100) | |
.activation("relu") | |
.build()) | |
.layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX) | |
.kernelSize(2,2) | |
.stride(1,1) | |
.build()) | |
.layer(2, new DenseLayer.Builder().activation("relu") | |
.nOut(500).dropOut(0.5).build()) | |
.layer(3, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.nOut(outputNum) | |
.activation("softmax") | |
.build()) | |
.backprop(true).pretrain(false); | |
new ConvolutionLayerSetup(builder,199,300,1); | |
MultiLayerConfiguration conf = builder.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
System.out.println("Train model...."); | |
model.setListeners(new ScoreIterationListener(1)); | |
for( int i=0; i<nEpochs; i++ ) { | |
System.out.println("\nTrying epoch "+ (i+1)); | |
ArrayList<DataSet> myList = Lists.newArrayList(LSTrain); | |
Collections.shuffle(myList); | |
DataSetIterator shuffledLSTrainIterator = new ListDataSetIterator(myList,batchSize); | |
while (shuffledLSTrainIterator.hasNext()) { | |
model.fit(shuffledLSTrainIterator); | |
} | |
System.out.println("*** Completed epoch {} ***"+ (i+1)); | |
System.out.println("Evaluate model...."); | |
Evaluation eval = new Evaluation(outputNum); | |
while(LSTest.hasNext()){ | |
DataSet ds = LSTest.next(); | |
INDArray output = model.output(ds.getFeatureMatrix(), false); | |
eval.eval(ds.getLabels(), output); | |
} | |
System.out.println(eval.stats()); | |
LSTest.reset(); | |
} | |
System.out.println("****************Example finished********************"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment