Skip to content

Instantly share code, notes, and snippets.

@MaxLeiter
Created May 24, 2018 23:10
Show Gist options
  • Save MaxLeiter/a939deda099d33e175b9bd8b065336ef to your computer and use it in GitHub Desktop.
package examples;
import org.datavec.api.io.filters.BalancedPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.api.util.ClassPathResource;
import org.datavec.image.loader.BaseImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.*;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.MultipleEpochsIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.*;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
/**
 * Trains an AlexNet-style convolutional network (DL4J) to classify fundus
 * images as glaucoma / no-glaucoma. Images are loaded from the
 * "/glaucoma_test/" classpath directory; each image's label is taken from its
 * parent directory name. Training runs once on the raw images and then once
 * per flip transform as simple data augmentation, followed by evaluation on a
 * held-out split.
 */
public class Glaucoma {
    private static final String[] allowedExtensions = BaseImageLoader.ALLOWED_FORMATS;
    private static final long seed = 12345;     // fixed seed for reproducible runs
    private static final int epochs = 10;       // epochs per training pass
    private static final Random randNumGen = new Random(seed);
    private static final int height = 300;      // input image height (pixels)
    private static final int width = 300;       // input image width (pixels)
    private static final int channels = 1;      // grayscale; RGB would be 3
    private static final int numLabels = 2;     // glaucoma: yes / no
    private static final int batchSize = 200;   // images per mini-batch
    private static final int labelIndex = 1;    // label position in each record

    public static void main(String[] args) throws Exception {
        // Locate the image directory on the classpath; subdirectory names are the labels.
        File parentDir = new ClassPathResource("/glaucoma_test/").getFile();
        FileSplit filesInDir = new FileSplit(parentDir, allowedExtensions, randNumGen);
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        BalancedPathFilter pathFilter = new BalancedPathFilter(randNumGen, allowedExtensions, labelMaker);

        // 50/50 train/test split, balanced across labels by the path filter.
        InputSplit[] filesInDirSplit = filesInDir.sample(pathFilter, 50, 50);
        InputSplit trainData = filesInDirSplit[0];
        InputSplit testData = filesInDirSplit[1];

        ImageRecordReader recordReader = new ImageRecordReader(height, width, channels, labelMaker);

        // Attach the DL4J training UI so score/stats can be inspected in a browser.
        UIServer uiServer = UIServer.getInstance();
        StatsStorage storage = new InMemoryStatsStorage();
        uiServer.attach(storage);

        // Initialize the reader on the training split BEFORE wrapping it in an
        // iterator (the original constructed the iterator on an uninitialized reader).
        recordReader.initialize(trainData);
        DataSetIterator dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels);

        // Data augmentation: re-train on flipped copies of the training set.
        ImageTransform flipTransform1 = new FlipImageTransform();
        ImageTransform flipTransform2 = new FlipImageTransform(new Random(seed));
        List<ImageTransform> transforms = Arrays.asList(flipTransform1, flipTransform2);

        MultiLayerNetwork net = new MultiLayerNetwork(ourNet());
        net.init();
        net.setListeners(new StatsListener(storage), new ScoreIterationListener(10));

        // First pass: train on the unmodified images for the configured epochs.
        net.fit(new MultipleEpochsIterator(epochs, dataIter));

        // Subsequent passes: one full multi-epoch run per transform.
        for (ImageTransform t : transforms) {
            System.out.print("\nTraining on transformation: " + t.getClass().toString() + "\n\n");
            recordReader.initialize(trainData, t);
            dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels);
            net.fit(new MultipleEpochsIterator(epochs, dataIter));
        }

        // Evaluate on the held-out split and print the confusion matrix + stats.
        recordReader.initialize(testData);
        dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numLabels);
        Evaluation eval = net.evaluate(dataIter);
        System.out.println(eval.stats(true));
    }

    /**
     * Builds the network configuration: an AlexNet-like stack of five
     * convolution layers with max-pooling and local response normalization,
     * followed by one dense layer and a softmax output over {@code numLabels}.
     *
     * @return the fully built {@link MultiLayerConfiguration}
     */
    private static MultiLayerConfiguration ourNet() {
        double nonZeroBias = 1;
        // Shape convention here: {channels, height, width}.
        int[] inputShape = new int[]{channels, height, width};
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .weightInit(WeightInit.DISTRIBUTION)
                .dist(new NormalDistribution(0.0, 0.01))
                .activation(Activation.RELU)
                .updater(Updater.NESTEROVS)
                .convolutionMode(ConvolutionMode.Same)
                // Normalize gradients to guard against vanishing/exploding gradients.
                .gradientNormalization(GradientNormalization.RenormalizeL2PerLayer)
                .l2(5 * 1e-4)
                .miniBatch(false)
                .list()
                .layer(0, new ConvolutionLayer.Builder(new int[]{11, 11}, new int[]{4, 4})
                        .name("cnn1")
                        .cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST)
                        .convolutionMode(ConvolutionMode.Truncate)
                        .nIn(inputShape[0])
                        .nOut(96)
                        .build())
                .layer(1, new LocalResponseNormalization.Builder().build())
                .layer(2, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(3, 3)
                        .stride(2, 2)
                        .padding(1, 1)
                        .name("maxpool1")
                        .build())
                .layer(3, new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{2, 2})
                        .name("cnn2")
                        .cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST)
                        .convolutionMode(ConvolutionMode.Truncate)
                        .nOut(256)
                        .biasInit(nonZeroBias)
                        .build())
                .layer(4, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[]{3, 3}, new int[]{2, 2})
                        .convolutionMode(ConvolutionMode.Truncate)
                        .name("maxpool2")
                        .build())
                .layer(5, new LocalResponseNormalization.Builder().build())
                .layer(6, new ConvolutionLayer.Builder()
                        .kernelSize(3, 3)
                        .stride(1, 1)
                        .convolutionMode(ConvolutionMode.Same)
                        .name("cnn3")
                        .cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST)
                        .nOut(384)
                        .build())
                .layer(7, new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1})
                        .name("cnn4")
                        .cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST)
                        .nOut(384)
                        .biasInit(nonZeroBias)
                        .build())
                .layer(8, new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1})
                        .name("cnn5")
                        .cudnnAlgoMode(ConvolutionLayer.AlgoMode.PREFER_FASTEST)
                        .nOut(256)
                        .biasInit(nonZeroBias)
                        .build())
                .layer(9, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX, new int[]{3, 3}, new int[]{2, 2})
                        .name("maxpool3")
                        .convolutionMode(ConvolutionMode.Truncate)
                        .build())
                .layer(10, new DenseLayer.Builder()
                        .name("ffn2")
                        .nOut(4096)
                        .weightInit(WeightInit.DISTRIBUTION).dist(new GaussianDistribution(0, 0.005))
                        .biasInit(nonZeroBias)
                        .dropOut(0.5)
                        .build())
                .layer(11, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .name("output")
                        .nOut(numLabels)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.DISTRIBUTION).dist(new GaussianDistribution(0, 0.005))
                        .biasInit(0.1)
                        .build())
                .backprop(true)
                .pretrain(false)
                // InputType.convolutional(height, width, depth): the original passed
                // (width, height, channels), which only worked because height == width.
                .setInputType(InputType.convolutional(inputShape[1], inputShape[2], inputShape[0]))
                .build();
        return conf;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment