Created June 10, 2016 19:16
-
-
Save eraly/14e0b5296936c3627fe15bb553ff8905 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public class ImageCNNDummy { | |
private static class SimplePreProcessor implements DataSetPreProcessor { | |
@Override | |
public void preProcess(org.nd4j.linalg.dataset.api.DataSet toPreProcess) { | |
toPreProcess.getFeatureMatrix().divi(255); //[0,255] -> [0,1] for input pixel values | |
} | |
} | |
// Images are of format given by allowedExtension - | |
protected static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS; | |
protected static final long seed = 12345; | |
public static final Random randNumGen = new Random(seed); | |
public static void main(String[] args) throws Exception { | |
int nChannels = 1; | |
int outputNum = 3; | |
int batchSize = 64; | |
int nEpochs = 10; | |
int iterations = 1; | |
int seed = 123; | |
int width = 100; | |
int height = 100; | |
int channels = 3; | |
System.out.println("####"); | |
System.out.println("Hello World!"); | |
System.out.println("Load data...."); | |
File parentDir = new File(System.getProperty("user.dir"), "CNNTest/"); | |
FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen); | |
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); | |
BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker); | |
InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 100, 0); | |
InputSplit trainData = filesInDirSplit[0]; | |
InputSplit testData = filesInDirSplit[1]; | |
ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker); | |
recordReader.initialize(trainData); | |
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, -1, outputNum); | |
SimplePreProcessor SP = new SimplePreProcessor(); | |
//NOTE: | |
//The normalization was not done correctly. | |
//The net is fit with an iterator. You have to set the preprocessor you want to use to iterator as below. | |
//There are in an example that explains use of the preprocessor with both datasets and the iterator. | |
//Take a looksie: | |
//https://github.com/deeplearning4j/dl4j-0.4-examples/blob/master/src/main/java/org/deeplearning4j/examples/dataExamples/PreprocessNormalizerExample.java | |
dataIter.setPreProcessor(SP); | |
//In the while loop you were create a dataset, ds and then modifying it. This does not have any impact on the iterator. | |
//and this is what gets passed to you net. | |
/*while (dataIter.hasNext()) { | |
DataSet ds = dataIter.next(); | |
SP.preProcess(ds); | |
}*/ | |
dataIter.reset(); | |
System.out.println("Build model...."); | |
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations) | |
.regularization(true).l2(0.0005).learningRate(0.0000000000001)// .biasLearningRate(0.02) | |
// .learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75) | |
.weightInit(WeightInit.XAVIER).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.updater(Updater.NESTEROVS).momentum(0.9).list() | |
.layer(0, | |
new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(20).activation("identity") | |
.build()) | |
.layer(1, | |
new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2) | |
.build()) | |
.layer(2, | |
new ConvolutionLayer.Builder(5, 5).nIn(nChannels).stride(1, 1).nOut(50).activation("identity") | |
.build()) | |
.layer(3, | |
new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX).kernelSize(2, 2).stride(2, 2) | |
.build()) | |
.layer(4, new DenseLayer.Builder().activation("relu").nOut(500).build()) | |
.layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nOut(outputNum) | |
.activation("softmax").build()) | |
.backprop(true).pretrain(false); | |
new ConvolutionLayerSetup(builder, width, height, channels); | |
MultiLayerConfiguration conf = builder.build(); | |
MultiLayerNetwork model = new MultiLayerNetwork(conf); | |
model.init(); | |
System.out.println("Train model...."); | |
model.setListeners(new ScoreIterationListener(1)); | |
for (int i = 0; i < nEpochs; i++) { | |
model.fit(dataIter); | |
System.out.println("*** Completed epoch {" + i + "} ***"); | |
System.out.println("Evaluate model...."); | |
//NOTE: After the fit, you need to reset the iterator if you want to | |
//get data from it for eval | |
dataIter.reset(); | |
Evaluation eval = new Evaluation(outputNum); | |
while (dataIter.hasNext()) { | |
DataSet ds = dataIter.next(); | |
INDArray output = model.output(ds.getFeatureMatrix(), false); | |
eval.eval(ds.getLabels(), output); | |
} | |
System.out.println(eval.stats()); | |
//NOTE: You have to reset the iterator again since you want to run multiple epochs | |
dataIter.reset(); | |
} | |
System.out.println("****************Example finished********************"); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment