// NOTE: the imports below assume the DL4J 0.4.x / Canova APIs this gist was written against
// (see the dl4j-0.4-examples link further down); package names differ in later DataVec-based releases.
import java.io.File;
import java.util.Random;

import org.canova.api.io.filters.BalancedPathFilter;
import org.canova.api.io.labels.ParentPathLabelGenerator;
import org.canova.api.split.FileSplit;
import org.canova.api.split.InputSplit;
import org.canova.image.loader.BaseImageLoader;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.setup.ConvolutionLayerSetup;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class ImageCNNDummy {

    private static class SimplePreProcessor implements DataSetPreProcessor {
        @Override
        public void preProcess(org.nd4j.linalg.dataset.api.DataSet toPreProcess) {
            toPreProcess.getFeatureMatrix().divi(255); // [0,255] -> [0,1] for input pixel values
        }
    }

    // Images are of the formats given by allowedExtensions
    protected static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
    protected static final long seed = 12345;
    public static final Random randNumGen = new Random(seed);

    public static void main(String[] args) throws Exception {
        int nChannels = 1;
        int outputNum = 3;
        int batchSize = 64;
        int nEpochs = 10;
        int iterations = 1;
        int seed = 123;
        int width = 100;
        int height = 100;
        int channels = 3;
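        // Note: nChannels (= 1) is what the conv layers' nIn(...) calls below use, while the images themselves
        // are loaded with channels (= 3); ConvolutionLayerSetup further down is handed width/height/channels
        // and derives the net's actual input shape from those three values.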
System.out.println("####");
System.out.println("Hello World!");
System.out.println("Load data....");
File parentDir = new File(System.getProperty("user.dir"), "CNNTest/");
FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);
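        // ParentPathLabelGenerator uses each image's parent directory name as its label;
        // BalancedPathFilter shuffles the paths and balances the number of examples per label.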
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);
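        // sample(pathFilter, 100, 0) puts all of the filtered paths into the first split, so testData below
        // ends up empty. It is also never used; the evaluation further down runs on the training iterator.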
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 100, 0);
        InputSplit trainData = filesInDirSplit[0];
        InputSplit testData = filesInDirSplit[1];

        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);
        recordReader.initialize(trainData);
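        // Wrap the record reader in a DataSetIterator that serves minibatches of batchSize examples
        // with outputNum label classes.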
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, -1, outputNum);
        SimplePreProcessor SP = new SimplePreProcessor();
        // NOTE:
        // The normalization was not done correctly.
        // The net is fit with an iterator. You have to set the preprocessor you want to use on the iterator, as below.
        // There is an example that explains use of the preprocessor with both datasets and the iterator.
        // Take a looksie:
        // https://github.com/deeplearning4j/dl4j-0.4-examples/blob/master/src/main/java/org/deeplearning4j/examples/dataExamples/PreprocessNormalizerExample.java
        dataIter.setPreProcessor(SP);
        // In the while loop you were creating a dataset, ds, and then modifying it. That does not have any
        // impact on the iterator, and the iterator is what gets passed to your net.
        /*while (dataIter.hasNext()) {
            DataSet ds = dataIter.next();
            SP.preProcess(ds);
        }*/
        dataIter.reset();
System.out.println("Build model....");
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations)
.regularization(true).l2(0.0005).learningRate(0.0000000000001)// .biasLearningRate(0.02)
// .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
.weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
.updater(Updater.NESTEROVS).momentum(0.9).list()
.layer(0,
new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation("identity")
.build())
.layer(1,
new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
.build())
.layer(2,
new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(50).activation("identity")
.build())
.layer(3,
new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2)
.build())
.layer(4, new DenseLayer.Builder().activation("relu").nOut(500).build())
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum)
.activation("softmax").build())
.backprop(true).pretrain(false);
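        // ConvolutionLayerSetup configures the builder for image input of the given width/height/channels
        // (adding the preprocessors needed between the convolutional and dense layers and working out the
        // layers' input sizes), so the net's input shape comes from these three values rather than nChannels.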
        new ConvolutionLayerSetup(builder, width, height, channels);
        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();

        System.out.println("Train model....");
        model.setListeners(new ScoreIterationListener(1));
        for (int i = 0; i < nEpochs; i++) {
            model.fit(dataIter);
            System.out.println("*** Completed epoch " + i + " ***");
            System.out.println("Evaluate model....");
            // NOTE: After the fit, you need to reset the iterator if you want to
            // get data from it for eval
            dataIter.reset();
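            // Evaluation accumulates a confusion matrix over the outputNum classes; eval.stats() below prints
            // accuracy, precision, recall, and F1. Since this reuses the training iterator, these are training
            // metrics rather than held-out test metrics.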
            Evaluation eval = new Evaluation(outputNum);
            while (dataIter.hasNext()) {
                DataSet ds = dataIter.next();
                INDArray output = model.output(ds.getFeatureMatrix(), false);
                eval.eval(ds.getLabels(), output);
            }
            System.out.println(eval.stats());
            // NOTE: You have to reset the iterator again since you want to run multiple epochs
            dataIter.reset();
        }
        System.out.println("****************Example finished********************");
    }
}