Last active: September 4, 2018, 14:13
-
-
Save HGuillemet/7806b31234b84f475aebc4833edea79d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package fr.apteryx.pve; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.datavec.api.conf.Configuration; | |
import org.datavec.api.records.reader.BaseRecordReader; | |
import org.datavec.api.records.metadata.RecordMetaData; | |
import org.datavec.api.records.Record; | |
import org.datavec.api.split.InputSplit; | |
import org.datavec.api.util.ndarray.RecordConverter; | |
import org.datavec.api.writable.Writable; | |
import org.datavec.api.writable.IntWritable; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.distribution.GaussianDistribution; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.Convolution3D; | |
import org.deeplearning4j.nn.conf.layers.Subsampling3DLayer; | |
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import java.net.URI; | |
import java.io.DataInputStream; | |
import java.io.IOException; | |
import java.io.File; | |
import java.util.List; | |
import java.util.ArrayList; | |
public class Test { | |
private static final int nX = 32; | |
private static final int nY = 32; | |
private static final int nZ = 28; | |
private static final int batchSize = 10; | |
private static final int epochs = 100; | |
// Hyperparams | |
private static final double L2_REGUL = 0.005; | |
private static final double LEARNING_RATE = 0.0001; | |
private static final int KS = 3; // Kernel Size | |
private static final double DROPOUT = 0.5; | |
static private MultiLayerNetwork model() { | |
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() | |
.l2(L2_REGUL) | |
.activation(Activation.RELU) | |
.weightInit(WeightInit.XAVIER) | |
.updater(new Adam(LEARNING_RATE)) | |
.list() | |
.layer(0, new Convolution3D.Builder() | |
.kernelSize(KS, KS, KS) | |
.stride(1, 1, 1) | |
.padding(0, 0, 0) | |
.name("cnn1") | |
.nIn(1) // For ccn == num of channels | |
.nOut(10) // nom of filters | |
.biasInit(0) | |
.build() | |
) | |
.layer(1, new Convolution3D.Builder() | |
.kernelSize(KS, KS, KS) | |
.stride(1, 1, 1) | |
.padding(0, 0, 0) | |
.name("cnn2") | |
// .nIn(10) // implicit | |
.nOut(20) | |
.biasInit(1.0) | |
.build() | |
) | |
.layer(2, new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX) | |
.name("pool1") | |
.kernelSize(2, 2, 2) | |
.stride(2, 2, 2) | |
.build()) | |
.layer(3, new Convolution3D.Builder() | |
.kernelSize(KS, KS, KS) | |
.stride(1, 1, 1) | |
.padding(0, 0, 0) | |
.name("cnn3") | |
.nOut(30) | |
.biasInit(1.0) | |
.build() | |
) | |
.layer(4, new Subsampling3DLayer.Builder(Subsampling3DLayer.PoolingType.MAX) | |
.name("pool2") | |
.kernelSize(2, 2, 2) | |
.stride(2, 2, 2) | |
.build()) | |
.layer(5, new DenseLayer.Builder() | |
.name("fc") | |
.nOut(4096) | |
.biasInit(1.0) | |
.dropOut(DROPOUT) | |
.dist(new GaussianDistribution(0, 0.005)) | |
.weightInit(WeightInit.DISTRIBUTION) | |
.build()) | |
.layer(6, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) | |
.name("output") | |
.nOut(2) | |
.activation(Activation.SOFTMAX) | |
.build()) | |
.backprop(true) // supervisé | |
.pretrain(false) | |
.setInputType(InputType.convolutional3D(nX, nY, nZ, 1)) | |
.build(); | |
return new MultiLayerNetwork(conf); | |
} | |
public static void main(String[] a) throws IOException { | |
MultiLayerNetwork network = model(); | |
network.init(); | |
CustomRecordReader recordReader = new CustomRecordReader(); | |
DataSetIterator dataIter; | |
dataIter = new RecordReaderDataSetIterator(recordReader, batchSize, | |
1, /* Index of label in records */ | |
2 /* number of different labels */); | |
network.fit(dataIter, epochs); | |
} | |
static class CustomRecordReader extends BaseRecordReader { | |
int n = 0; | |
CustomRecordReader() { | |
} | |
@Override | |
public boolean batchesSupported() { | |
return false; | |
} | |
@Override | |
public List<List<Writable>> next(int num) { | |
throw new RuntimeException("Not implemented"); | |
} | |
@Override | |
public List<Writable> next() { | |
INDArray nd = Nd4j.create(new float[nZ*nY*nX], new int[] { nZ, nY, nX }, 'C'); | |
final List<Writable>res = RecordConverter.toRecord(nd); | |
res.add(new IntWritable(0)); | |
n++; | |
return res; | |
} | |
@Override | |
public boolean hasNext() { | |
return n<10; | |
} | |
final static ArrayList<String> labels = new ArrayList<>(2); | |
static { | |
labels.add("lbl0"); | |
labels.add("lbl1"); | |
} | |
@Override | |
public List<String> getLabels() { | |
return labels; | |
} | |
@Override | |
public void reset() { | |
n = 0; | |
} | |
@Override | |
public boolean resetSupported() { | |
return true; | |
} | |
@Override | |
public List<Writable> record(URI uri, DataInputStream dataInputStream) { | |
return next(); | |
} | |
@Override | |
public Record nextRecord() { | |
List<Writable> r = next(); | |
return new org.datavec.api.records.impl.Record(r, null); | |
} | |
@Override | |
public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { | |
throw new RuntimeException("Not implemented"); | |
} | |
@Override | |
public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) { | |
throw new RuntimeException("Not implemented"); | |
} | |
@Override | |
public void close() { | |
} | |
@Override | |
public void setConf(Configuration conf) { | |
} | |
@Override | |
public Configuration getConf() { | |
return null; | |
} | |
@Override | |
public void initialize(InputSplit split) { | |
n = 0; | |
} | |
@Override | |
public void initialize(Configuration conf, InputSplit split) { | |
n = 0; | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loaded [CpuBackend] backend | |
Number of threads used for NativeOps: 4 | |
Number of threads used for BLAS: 4 | |
Backend used: [CPU]; OS: [Linux] | |
Cores: [8]; Memory: [1,9GB]; | |
Blas vendor: [OPENBLAS] | |
Starting MultiLayerNetwork with WorkspaceModes set to [training: ENABLED; inference: ENABLED], cacheMode set to [NONE] | |
Steps: 4 | |
Exception in thread "ADSI prefetch thread" java.lang.RuntimeException: java.lang.RuntimeException: Error parsing data (writables) from record readers | |
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$AsyncPrefetchThread.run(AsyncDataSetIterator.java:430) | |
Caused by: java.lang.RuntimeException: Error parsing data (writables) from record readers | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritables(RecordReaderMultiDataSetIterator.java:459) | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertFeaturesOrLabels(RecordReaderMultiDataSetIterator.java:363) | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.nextMultiDataSet(RecordReaderMultiDataSetIterator.java:326) | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.next(RecordReaderMultiDataSetIterator.java:212) | |
at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:364) | |
at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:439) | |
at org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator.next(RecordReaderDataSetIterator.java:84) | |
at org.deeplearning4j.datasets.iterator.AsyncDataSetIterator$AsyncPrefetchThread.run(AsyncDataSetIterator.java:404) | |
Caused by: org.nd4j.linalg.exception.ND4JIllegalStateException: X, Y and Z arguments should have the same length for PairwiseTransform. x: length 1024, shape [32, 32]; y: 28672, shape [28, 32, 32]; z: 1024, shape [32, 32] | |
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:771) | |
at org.nd4j.linalg.cpu.nativecpu.ops.NativeOpExecutioner.exec(NativeOpExecutioner.java:113) | |
at org.nd4j.linalg.api.ndarray.BaseNDArray.assign(BaseNDArray.java:1390) | |
at org.nd4j.linalg.api.ndarray.BaseNDArray.put(BaseNDArray.java:2476) | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.putExample(RecordReaderMultiDataSetIterator.java:548) | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritablesHelper(RecordReaderMultiDataSetIterator.java:515) | |
at org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.convertWritables(RecordReaderMultiDataSetIterator.java:453) | |
... 7 more |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment