Skip to content

Instantly share code, notes, and snippets.

@junyongyou
Last active May 30, 2016 14:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save junyongyou/28cbca140337d1e054f05d61f7240084 to your computer and use it in GitHub Desktop.
Save junyongyou/28cbca140337d1e054f05d61f7240084 to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.convolution;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.image.loader.BaseImageLoader;
import org.canova.image.loader.NativeImageLoader;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.GradientNormalization;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.Arrays;
import java.util.List;
/**
* Created by junyong on 30.05.2016.
*/
public class ImageClassificationExample
{
public static void main(String[] args) throws Exception {
int imageWidth = 600;
int imageHeight = 600;
int channels = 3;
// create dataset
// Directory which has 1 sub-directory with images for each category you have.
File directory = new File("D:\\training_set");
int batchSize = 100;
boolean appendLabels = true;
List<String> labels = Arrays.asList(directory.list());
int numLabels = labels.size();
RecordReader recordReader = new ImageRecordReader(imageHeight, imageWidth, channels, appendLabels);
recordReader.initialize(new FileSplit(directory, BaseImageLoader.ALLOWED_FORMATS));
DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, -1, numLabels);
// setup model
MultiLayerNetwork model = new MultiLayerNetwork(buildConfig(imageWidth, imageHeight, channels, numLabels));
model.init();
model.setListeners(new ScoreIterationListener(1));
// train model
int epochs = 5;
for (int i = 0; i < epochs; i++) {
dataIter.reset();
model.fit(dataIter);
}
// test model
dataIter.reset();
Evaluation eval = new Evaluation(dataIter.getLabels());
while (dataIter.hasNext()) {
DataSet next = dataIter.next();
INDArray prediction = model.output(next.getFeatureMatrix());
eval.eval(next.getLabels(), prediction);
}
System.out.println(eval.stats());
// predict new image
File imageToPredict = new File("D:\\classified images\\100079.jpg");
NativeImageLoader imageLoader = new NativeImageLoader(imageHeight, imageWidth, channels);
INDArray imageVector = imageLoader.asRowVector(imageToPredict);
INDArray prediction = model.output(imageVector);
System.out.println("done");
// prediction contains one float for every label you have, sums up to 1
}
static private MultiLayerConfiguration buildConfig(int imageWidth, int imageHeight, int channels, int numOfClasses) {
int seed = 123;
int iterations = 1;
WeightInit weightInit = WeightInit.XAVIER;
String activation = "relu";
Updater updater = Updater.NESTEROVS;
double lr = 1e-3;
double mu = 0.9;
double l2 = 5e-4;
boolean regularization = true;
SubsamplingLayer.PoolingType poolingType = SubsamplingLayer.PoolingType.MAX;
double nonZeroBias = 1;
double dropOut = 0.5;
MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder().seed(seed).iterations(iterations).activation(activation).weightInit(weightInit)
.gradientNormalization(GradientNormalization.RenormalizeL2PerLayer).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).learningRate(lr).momentum(mu)
.regularization(regularization).l2(l2).updater(updater).useDropConnect(true)
// AlexNet
.list().layer(0, new ConvolutionLayer.Builder(new int[] { 11, 11 }, new int[] { 4, 4 }, new int[] { 3, 3 }).name("cnn1").nIn(channels).nOut(96).build())
.layer(1, new LocalResponseNormalization.Builder().name("lrn1").build())
.layer(2, new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 }).name("maxpool1").build())
.layer(3, new ConvolutionLayer.Builder(new int[] { 5, 5 }, new int[] { 1, 1 }, new int[] { 2, 2 }).name("cnn2").nOut(256).biasInit(nonZeroBias).build())
.layer(4, new LocalResponseNormalization.Builder().name("lrn2").k(2).n(5).alpha(1e-4).beta(0.75).build())
.layer(5, new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 }).name("maxpool2").build())
.layer(6, new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 }).name("cnn3").nOut(384).build())
.layer(7, new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 }).name("cnn4").nOut(384).biasInit(nonZeroBias).build())
.layer(8, new ConvolutionLayer.Builder(new int[] { 3, 3 }, new int[] { 1, 1 }, new int[] { 1, 1 }).name("cnn5").nOut(256).biasInit(nonZeroBias).build())
.layer(9, new SubsamplingLayer.Builder(poolingType, new int[] { 3, 3 }, new int[] { 2, 2 }).name("maxpool3").build())
.layer(10, new DenseLayer.Builder().name("ffn1").nOut(4096).biasInit(nonZeroBias).dropOut(dropOut).build())
.layer(11, new DenseLayer.Builder().name("ffn2").nOut(4096).biasInit(nonZeroBias).dropOut(dropOut).build())
.layer(12, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).name("output").nOut(numOfClasses).activation("softmax").build()).backprop(true).pretrain(false)
.cnnInputSize(imageHeight, imageWidth, channels);
return builder.build();
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment