Skip to content

Instantly share code, notes, and snippets.

@HGuillemet
Last active September 4, 2018 14:13
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save HGuillemet/7806b31234b84f475aebc4833edea79d to your computer and use it in GitHub Desktop.
Save HGuillemet/7806b31234b84f475aebc4833edea79d to your computer and use it in GitHub Desktop.
package fr.apteryx.pve;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.datavec.api.conf.Configuration;
import org.datavec.api.records.reader.BaseRecordReader;
import org.datavec.api.records.metadata.RecordMetaData;
import org.datavec.api.records.Record;
import org.datavec.api.split.InputSplit;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.api.writable.IntWritable;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.net.URI;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.File;
import java.util.List;
import java.util.ArrayList;
public class Test {
private static final int nX = 32;
private static final int nY = 32;
private static final int nZ = 28;
private static final int batchSize = 10;
private static final int epochs = 100;
// Hyperparams
private static final double L2_REGUL = 0.005;
private static final double LEARNING_RATE = 0.0001;
private static final int KS = 3; // Kernel Size
private static final double DROPOUT = 0.5;
static private MultiLayerNetwork model() {
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
.l2(L2_REGUL)
.activation(Activation.RELU)
.weightInit(WeightInit.XAVIER)
.updater(new Adam(LEARNING_RATE))
.list()
.layer(0, new Convolution3D.Builder()
.kernelSize(KS, KS, KS)
.stride(1, 1, 1)
.padding(0, 0, 0)
.name("cnn1")
.nIn(1) // For ccn == num of channels
.nOut(10) // nom of filters
.biasInit(0)
.build()
)
.layer(1, new Convolution3D.Builder()
.kernelSize(KS, KS, KS)
.stride(1, 1, 1)
.padding(0, 0, 0)
.name("cnn2")
// .nIn(10) // implicit
.nOut(20)
.biasInit(1.0)
.build()
)
.layer(2, new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX)
.name("pool1")
.kernelSize(2, 2, 2)
.stride(2, 2, 2)
.build())
.layer(3, new Convolution3D.Builder()
.kernelSize(KS, KS, KS)
.stride(1, 1, 1)
.padding(0, 0, 0)
.name("cnn3")
.nOut(30)
.biasInit(1.0)
.build()
)
.layer(4, new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX)
.name("pool2")
.kernelSize(2, 2, 2)
.stride(2, 2, 2)
.build())
.layer(5, new DenseLayer.Builder()
.name("fc")
.nOut(4096)
.biasInit(1.0)
.dropOut(DROPOUT)
.dist(new GaussianDistribution(0, 0.005))
.weightInit(WeightInit.DISTRIBUTION)
.build())
.layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
.name("output")
.nOut(2)
.activation(Activation.SOFTMAX)
.build())
.backprop(true) // supervisé
.pretrain(false)
.setInputType(InputType.convolutional3D(nX, nY, nZ, 1))
.build();
return new MultiLayerNetwork(conf);
}
public static void main(String[] a) throws IOException {
MultiLayerNetwork network = model();
network.init();
CustomRecordReader recordReader = new CustomRecordReader();
DataSetIterator dataIter;
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize,
1, /* Index of label in records */
2 /* number of different labels */);
network.fit(dataIter, epochs);
}
static class CustomRecordReader extends BaseRecordReader {
int n = 0;
CustomRecordReader() {
}
@Override
public boolean batchesSupported() {
return false;
}
@Override
public List<List<Writable>> next(int num) {
throw new RuntimeException("Not implemented");
}
@Override
public List<Writable> next() {
INDArray nd = Nd4j.create(new float[nZ*nY*nX], new int[] { nZ, nY, nX }, 'C');
final List<Writable>res = RecordConverter.toRecord(nd);
res.add(new IntWritable(0));
n++;
return res;
}
@Override
public boolean hasNext() {
return n<10;
}
final static ArrayList<String> labels = new ArrayList<>(2);
static {
labels.add("lbl0");
labels.add("lbl1");
}
@Override
public List<String> getLabels() {
return labels;
}
@Override
public void reset() {
n = 0;
}
@Override
public boolean resetSupported() {
return true;
}
@Override
public List<Writable> record(URI uri, DataInputStream dataInputStream) {
return next();
}
@Override
public Record nextRecord() {
List<Writable> r = next();
return new org.datavec.api.records.impl.Record(r, null);
}
@Override
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException {
throw new RuntimeException("Not implemented");
}
@Override
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) {
throw new RuntimeException("Not implemented");
}
@Override
public void close() {
}
@Override
public void setConf(Configuration conf) {
}
@Override
public Configuration getConf() {
return null;
}
@Override
public void initialize(InputSplit split) {
n = 0;
}
@Override
public void initialize(Configuration conf, InputSplit split) {
n = 0;
}
}
}
Loaded [CpuBackend] backend
Number of threads used for NativeOps: 4
Number of threads used for BLAS: 4
Backend used: [CPU]; OS: [Linux]
Cores: [8]; Memory: [1,9GB];
Blas vendor: [OPENBLAS]
Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE]
Steps: 4
Exception in thread "ADSI prefetch thread" java.lang.RuntimeException: java.lang.RuntimeException: Error parsing data (writables) from record readers
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$AsyncPrefetchThread.run(AsyncDataSetIterator.java:430)
Caused by: java.lang.RuntimeException: Error parsing data (writables) from record readers
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritables(RecordReaderMultiDataSetIterator.java:459)
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertFeaturesOrLabels(RecordReaderMultiDataSetIterator.java:363)
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.nextMultiDataSet(RecordReaderMultiDataSetIterator.java:326)
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.next(RecordReaderMultiDataSetIterator.java:212)
at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:364)
at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:439)
at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:84)
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$AsyncPrefetchThread.run(AsyncDataSetIterator.java:404)
Caused by: org.nd4j.linalg.exception.ND4JIllegalStateException: X, Y and Z arguments should have the same length for PairwiseTransform. x: length 1024, shape [32, 32]; y: 28672, shape [28, 32, 32]; z: 1024, shape [32, 32]
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:771)
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:113)
at org.nd4j.linalg.api.ndarray.BaseNDArray.assign(BaseNDArray.java:1390)
at org.nd4j.linalg.api.ndarray.BaseNDArray.put(BaseNDArray.java:2476)
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.putExample(RecordReaderMultiDataSetIterator.java:548)
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritablesHelper(RecordReaderMultiDataSetIterator.java:515)
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritables(RecordReaderMultiDataSetIterator.java:453)
... 7 more
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment