Created
July 31, 2018 12:38
-
-
Save reuschling/ba771bb80b609b10cc3c328bdc489086 to your computer and use it in GitHub Desktop.
BatchNormalization: Gist demonstrating an issue with a 1D CNN on time-series data that leads to an IllegalStateException: Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got InputTypeRecurrent(128,timeSeriesLength=23) for layer index 1, layer name = layer1
We can make this file beautiful and searchable if this error is corrected: No commas found in this CSV file in line 0.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 | |
1 |
We can make this file beautiful and searchable if this error is corrected: No commas found in this CSV file in line 0.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package dfki.sds.cnn; | |
import java.io.File; | |
import java.util.List; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.SequenceRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; | |
import org.datavec.api.split.NumberedFileInputSplit; | |
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; | |
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.Builder; | |
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.eval.RegressionEvaluation; | |
import org.deeplearning4j.nn.conf.ConvolutionMode; | |
import org.deeplearning4j.nn.conf.MultiLayerConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder; | |
import org.deeplearning4j.nn.conf.distribution.NormalDistribution; | |
import org.deeplearning4j.nn.conf.inputs.InputType; | |
import org.deeplearning4j.nn.conf.layers.ActivationLayer; | |
import org.deeplearning4j.nn.conf.layers.BatchNormalization; | |
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer; | |
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer; | |
import org.deeplearning4j.nn.conf.layers.OutputLayer; | |
import org.deeplearning4j.nn.conf.layers.PoolingType; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.learning.config.AdaDelta; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
public class RRInputType4CNNGist | |
{ | |
static protected Logger log = LoggerFactory.getLogger(RRInputType4CNNGist.class); | |
public static void main(String[] args) throws Exception | |
{ | |
int miniBatchSize = 16; | |
int numLabelClasses = 8; | |
boolean regression = false; | |
// ----- Load the training data ----- | |
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(0, ";"); | |
trainFeatures.initialize(new NumberedFileInputSplit(new File(".").getAbsolutePath() + "/%d_feature.csv", 0, 0)); | |
RecordReader trainLabels = new CSVRecordReader(0, ';'); | |
trainLabels.initialize(new NumberedFileInputSplit("file://" + new File(".").getAbsolutePath() + "/%d_label.csv", 0, 0)); | |
DataSetIterator trainData = createTimeSeriesDataSetIteratorSeq2Classes(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, regression); | |
// ----- Load the test data: for simplification here, just the train data again----- | |
// Same process as for the training data. | |
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, ";"); | |
testFeatures.initialize(new NumberedFileInputSplit(new File(".").getAbsolutePath() + "/%d_feature.csv", 0, 0)); | |
RecordReader testLabels = new CSVRecordReader(0, ';'); | |
testLabels.initialize(new NumberedFileInputSplit("file://" + new File(".").getAbsolutePath() + "/%d_label.csv", 0, 0)); | |
DataSetIterator testData = createTimeSeriesDataSetIteratorSeq2Classes(testFeatures, testLabels, miniBatchSize, numLabelClasses, regression); | |
int length = 23; | |
int convNIn = 1; | |
// ----- Configure the network ----- | |
// @formatter:off | |
ListBuilder listBuilder = new NeuralNetConfiguration.Builder() | |
.updater(new AdaDelta()) | |
.weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1)) | |
.convolutionMode(ConvolutionMode.Same) | |
.list(); | |
//'FCN' from https://arxiv.org/pdf/1611.06455.pdf | |
listBuilder.layer(new Convolution1DLayer.Builder().kernelSize(8).stride(1).nIn(convNIn).nOut(128).build()); | |
listBuilder.layer(new BatchNormalization.Builder().nIn(128).nOut(128).build()); | |
listBuilder.layer(new ActivationLayer.Builder().activation(Activation.RELU).build()); | |
listBuilder.layer(new Convolution1DLayer.Builder().kernelSize(5).stride(1).nIn(128).nOut(256).build()); | |
listBuilder.layer(new BatchNormalization.Builder().nIn(256).nOut(256).build()); | |
listBuilder.layer(new ActivationLayer.Builder().activation(Activation.RELU).build()); | |
listBuilder.layer(new Convolution1DLayer.Builder().kernelSize(3).stride(1).nIn(256).nOut(128).build()); | |
listBuilder.layer(new BatchNormalization.Builder().nIn(128).nOut(128).build()); | |
listBuilder.layer(new ActivationLayer.Builder().activation(Activation.RELU).build()); | |
listBuilder.layer(new GlobalPoolingLayer(PoolingType.AVG)); | |
if(regression) | |
listBuilder.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nOut(1).build()); | |
else | |
listBuilder.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(numLabelClasses).build()); | |
listBuilder.setInputType(InputType.recurrent(convNIn, length)); | |
MultiLayerConfiguration conf = listBuilder.build(); | |
// @formatter:on | |
MultiLayerNetwork net = new MultiLayerNetwork(conf); | |
net.init(); | |
net.setListeners(new ScoreIterationListener(20)); // Print the score (loss function value) every 20 iterations | |
int nEpochs = 30; | |
for (int i = 0; i < nEpochs; i++) | |
{ | |
net.fit(trainData); | |
if(regression) | |
{ | |
RegressionEvaluation evaluation = net.evaluateRegression(testData); | |
log.info("epoch " + i + ":\n" + evaluation.stats()); | |
} | |
else | |
{ | |
Evaluation evaluation = net.evaluate(testData); | |
log.info("epoch " + i + ":\n" + evaluation.stats()); | |
} | |
testData.reset(); | |
trainData.reset(); | |
} | |
log.info("----- Example Complete -----"); | |
} | |
public static DataSetIterator createTimeSeriesDataSetIteratorSeq2Classes(SequenceRecordReader trainFeatures, RecordReader trainLabels, int miniBatchSize, | |
int numLabelClasses, boolean regression) | |
{ | |
Builder trainDataMultiBuilder = | |
new RecordReaderMultiDataSetIterator.Builder(miniBatchSize).addSequenceReader("in", trainFeatures).addReader("out", trainLabels).addInput("in"); | |
if(!regression) | |
trainDataMultiBuilder.addOutputOneHot("out", 0, numLabelClasses).build(); | |
else | |
trainDataMultiBuilder.addOutput("out").build(); | |
trainDataMultiBuilder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END); | |
RecordReaderMultiDataSetIterator trainDataMulti = trainDataMultiBuilder.build(); | |
// in MultiDataSetWrapperIterator everything is with 'throw new UnsupportedOperationException();' | |
DataSetIterator trainData = new MultiDataSetWrapperIterator(trainDataMulti) | |
{ | |
@Override | |
public List<String> getLabels() | |
{ | |
// from SequenceRecordReaderDataSetIterator we need it, otherwise the Evaluation class gets an UnsupportedOperationException | |
return null; | |
} | |
}; | |
return trainData; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment