@reuschling
Created July 31, 2018 12:38
BatchNormalization: Gist showing an issue with a 1D CNN on time series data that leads to an IllegalStateException: Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got InputTypeRecurrent(128,timeSeriesLength=23) for layer index 1, layer name = layer1
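The exception is thrown during input-type validation when the configuration below is built: Convolution1DLayer reports a recurrent output type, which the BatchNormalization layer rejects in this DL4J version. A minimal probe of the type propagation (a sketch assuming the beta-era getOutputType() API; InputTypeProbe is just a hypothetical name):

import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;

public class InputTypeProbe
{
    public static void main(String[] args)
    {
        // one input channel and 23 time steps, matching the feature CSV below
        InputType in = InputType.recurrent(1, 23);
        Convolution1DLayer conv = new Convolution1DLayer.Builder()
                .kernelSize(8).stride(1).convolutionMode(ConvolutionMode.Same)
                .nIn(1).nOut(128).build();
        // should print InputTypeRecurrent(128,timeSeriesLength=23) - exactly the type
        // that the batch norm layer at index 1 then refuses to accept
        System.out.println(conv.getOutputType(0, in));
    }
}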
0_feature.csv (univariate time series, 23 time steps):
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
1
0_label.csv (a single class label):
0
RRInputType4CNNGist.java:

package dfki.sds.cnn;

import java.io.File;
import java.util.List;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator.Builder;
import org.deeplearning4j.datasets.iterator.MultiDataSetWrapperIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.eval.RegressionEvaluation;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.Convolution1DLayer;
import org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.AdaDelta;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class RRInputType4CNNGist
{
    static protected Logger log = LoggerFactory.getLogger(RRInputType4CNNGist.class);

    public static void main(String[] args) throws Exception
    {
        int miniBatchSize = 16;
        int numLabelClasses = 8;
        boolean regression = false;

        // ----- Load the training data -----
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(0, ";");
        trainFeatures.initialize(new NumberedFileInputSplit(new File(".").getAbsolutePath() + "/%d_feature.csv", 0, 0));
        RecordReader trainLabels = new CSVRecordReader(0, ';');
        trainLabels.initialize(new NumberedFileInputSplit("file://" + new File(".").getAbsolutePath() + "/%d_label.csv", 0, 0));
        DataSetIterator trainData = createTimeSeriesDataSetIteratorSeq2Classes(trainFeatures, trainLabels, miniBatchSize, numLabelClasses, regression);

        // ----- Load the test data: for simplicity, just the training data again -----
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader(0, ";");
        testFeatures.initialize(new NumberedFileInputSplit(new File(".").getAbsolutePath() + "/%d_feature.csv", 0, 0));
        RecordReader testLabels = new CSVRecordReader(0, ';');
        testLabels.initialize(new NumberedFileInputSplit("file://" + new File(".").getAbsolutePath() + "/%d_label.csv", 0, 0));
        DataSetIterator testData = createTimeSeriesDataSetIteratorSeq2Classes(testFeatures, testLabels, miniBatchSize, numLabelClasses, regression);

        int length = 23;  // number of time steps in the feature CSV
        int convNIn = 1;  // one input channel (univariate series)

        // ----- Configure the network -----
        // @formatter:off
        ListBuilder listBuilder = new NeuralNetConfiguration.Builder()
                .updater(new AdaDelta())
                .weightInit(WeightInit.DISTRIBUTION).dist(new NormalDistribution(0, 1))
                .convolutionMode(ConvolutionMode.Same)
                .list();
        // 'FCN' architecture from https://arxiv.org/pdf/1611.06455.pdf
        listBuilder.layer(new Convolution1DLayer.Builder().kernelSize(8).stride(1).nIn(convNIn).nOut(128).build());
        // Layer index 1: this batch norm layer triggers the IllegalStateException, because the
        // preceding Convolution1DLayer reports a recurrent output type (see gist description)
        listBuilder.layer(new BatchNormalization.Builder().nIn(128).nOut(128).build());
        listBuilder.layer(new ActivationLayer.Builder().activation(Activation.RELU).build());
        listBuilder.layer(new Convolution1DLayer.Builder().kernelSize(5).stride(1).nIn(128).nOut(256).build());
        listBuilder.layer(new BatchNormalization.Builder().nIn(256).nOut(256).build());
        listBuilder.layer(new ActivationLayer.Builder().activation(Activation.RELU).build());
        listBuilder.layer(new Convolution1DLayer.Builder().kernelSize(3).stride(1).nIn(256).nOut(128).build());
        listBuilder.layer(new BatchNormalization.Builder().nIn(128).nOut(128).build());
        listBuilder.layer(new ActivationLayer.Builder().activation(Activation.RELU).build());
        listBuilder.layer(new GlobalPoolingLayer(PoolingType.AVG));
        if (regression)
            listBuilder.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nOut(1).build());
        else
            listBuilder.layer(new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).activation(Activation.SOFTMAX).nOut(numLabelClasses).build());
        listBuilder.setInputType(InputType.recurrent(convNIn, length));
        MultiLayerConfiguration conf = listBuilder.build();
        // @formatter:on

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        net.setListeners(new ScoreIterationListener(20)); // print the score (loss function value) every 20 iterations

        int nEpochs = 30;
        for (int i = 0; i < nEpochs; i++)
        {
            net.fit(trainData);
            if (regression)
            {
                RegressionEvaluation evaluation = net.evaluateRegression(testData);
                log.info("epoch " + i + ":\n" + evaluation.stats());
            }
            else
            {
                Evaluation evaluation = net.evaluate(testData);
                log.info("epoch " + i + ":\n" + evaluation.stats());
            }
            testData.reset();
            trainData.reset();
        }
        log.info("----- Example Complete -----");
    }
    public static DataSetIterator createTimeSeriesDataSetIteratorSeq2Classes(SequenceRecordReader trainFeatures, RecordReader trainLabels, int miniBatchSize,
            int numLabelClasses, boolean regression)
    {
        Builder trainDataMultiBuilder =
                new RecordReaderMultiDataSetIterator.Builder(miniBatchSize).addSequenceReader("in", trainFeatures).addReader("out", trainLabels).addInput("in");
        if (!regression)
            trainDataMultiBuilder.addOutputOneHot("out", 0, numLabelClasses);
        else
            trainDataMultiBuilder.addOutput("out");
        trainDataMultiBuilder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END);
        RecordReaderMultiDataSetIterator trainDataMulti = trainDataMultiBuilder.build();

        // Nearly every method in MultiDataSetWrapperIterator just does 'throw new UnsupportedOperationException();',
        // so getLabels() is overridden here. As in SequenceRecordReaderDataSetIterator it may return null;
        // without the override the Evaluation class would hit an UnsupportedOperationException.
        DataSetIterator trainData = new MultiDataSetWrapperIterator(trainDataMulti)
        {
            @Override
            public List<String> getLabels()
            {
                return null;
            }
        };
        return trainData;
    }
}
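A possible workaround (a sketch only, not verified against this exact DL4J version): run each batch norm layer in feed-forward mode by flattening the time dimension into the minibatch dimension with the stock DL4J input preprocessors, and restore the recurrent shape before the next 1D convolution. Inserted before listBuilder.build() in main() above (batch norm sits at layer indices 1, 4 and 7), this would look like:

import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;

// flatten [miniBatch, channels, timeSteps] to [miniBatch * timeSteps, channels]
// before each batch norm layer ...
listBuilder.inputPreProcessor(1, new RnnToFeedForwardPreProcessor());
listBuilder.inputPreProcessor(4, new RnnToFeedForwardPreProcessor());
listBuilder.inputPreProcessor(7, new RnnToFeedForwardPreProcessor());
// ... and restore the recurrent shape before the next Convolution1DLayer
// (indices 3 and 6) and before the global pooling layer (index 9)
listBuilder.inputPreProcessor(3, new FeedForwardToRnnPreProcessor());
listBuilder.inputPreProcessor(6, new FeedForwardToRnnPreProcessor());
listBuilder.inputPreProcessor(9, new FeedForwardToRnnPreProcessor());

Batch norm over the flattened [miniBatch * timeSteps, channels] matrix normalizes each of the 128 (or 256) channels over both batch and time, which is what per-channel batch norm on a 1D CNN would compute anyway; the intermediate ActivationLayers are element-wise and accept the feed-forward shape unchanged.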