Last active
August 7, 2019 05:37
-
-
Save yptheangel/65f8a99da6076a34f2d8b5f66354595c to your computer and use it in GitHub Desktop.
Attempt to fix the error raised when both the input features and the labels are read with a SequenceRecordReader while the last layer of the network is a (non-recurrent) OutputLayer.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.SequenceRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader; | |
import org.datavec.api.split.NumberedFileInputSplit; | |
import org.deeplearning4j.api.storage.StatsStorage; | |
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator; | |
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.api.OptimizationAlgorithm; | |
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex; | |
import org.deeplearning4j.nn.conf.graph.ReshapeVertex; | |
import org.deeplearning4j.nn.conf.layers.*; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.deeplearning4j.nn.weights.WeightInit; | |
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; | |
import org.deeplearning4j.ui.api.UIServer; | |
import org.deeplearning4j.ui.stats.StatsListener; | |
import org.deeplearning4j.ui.storage.InMemoryStatsStorage; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.nd4j.evaluation.classification.Evaluation; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; | |
import org.nd4j.linalg.learning.config.Adam; | |
import org.nd4j.linalg.lossfunctions.LossFunctions; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import java.io.File; | |
public class ECG_CNNClassifier { | |
private static final Logger log = LoggerFactory.getLogger(ECG_CNNClassifier.class); | |
private static long seed = 123; | |
private static ComputationGraph model = null; | |
private static String modelFilename = new File(".").getAbsolutePath() + "/ECGCNN_Classifier.zip"; | |
private static final int numClasses =5; | |
private static final int numSkipLines=0; | |
// Hyperparameters | |
private static int epochs =25; | |
private static int batchSize = 500 ; | |
private static final double learningRate = 0.00001; | |
private static boolean dontretrain = true; | |
public static void main(String[] args) throws Exception{ | |
//Load data | |
File trainBaseDir = new File("C:\\Users\\choowilson\\Downloads\\ecg\\train_tmp"); | |
File trainFeaturesDir = new File(trainBaseDir,"features"); | |
File trainLabelsDir = new File(trainBaseDir,"labels"); | |
File testBaseDir = new File("C:\\Users\\choowilson\\Downloads\\ecg\\test"); | |
File testFeaturesDir = new File(testBaseDir,"features"); | |
File testLabelsDir = new File(testBaseDir,"labels"); | |
SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(numSkipLines,","); | |
trainFeatures.initialize(new NumberedFileInputSplit(trainFeaturesDir.getAbsolutePath()+"/%d.csv",66471,87553)); | |
// SequenceRecordReader trainLabels = new CSVSequenceRecordReader(numSkipLines,","); | |
// trainLabels.initialize(new NumberedFileInputSplit(trainLabelsDir.getAbsolutePath()+"/%d.csv",66471,87553)); | |
// DataSetIterator train = new SequenceRecordReaderDataSetIterator(trainFeatures,trainLabels,batchSize,numClasses,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
RecordReader trainLabels = new CSVRecordReader(); | |
trainLabels.initialize(new NumberedFileInputSplit(trainLabelsDir.getAbsolutePath()+"/%d.csv",66471,87553)); | |
MultiDataSetIterator train = new RecordReaderMultiDataSetIterator.Builder(batchSize) | |
.addSequenceReader("trainFeatures",trainFeatures) | |
.addReader("trainLabels",trainLabels) | |
.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END) | |
.addInput("trainFeatures") | |
.addOutputOneHot("trainLabels",0,numClasses) | |
// .addOutput("trainLabels") | |
.build(); | |
// DataNormalization normalizer = new NormalizerStandardize(); | |
// normalizer.fit(train); | |
// train.reset(); | |
// train.setPreProcessor(normalizer); | |
SequenceRecordReader testFeatures = new CSVSequenceRecordReader(numSkipLines,","); | |
testFeatures.initialize(new NumberedFileInputSplit(testFeaturesDir.getAbsolutePath()+"/%d.csv",0,21891)); | |
SequenceRecordReader testLabels = new CSVSequenceRecordReader(numSkipLines,","); | |
testLabels.initialize(new NumberedFileInputSplit(testLabelsDir.getAbsolutePath()+"/%d.csv",0,21891)); | |
DataSetIterator test = new SequenceRecordReaderDataSetIterator(testFeatures,testLabels,batchSize,numClasses,false, SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END); | |
// test.setPreProcessor(normalizer); | |
if (new File(modelFilename).exists() && dontretrain == false ) | |
{ | |
model = ModelSerializer.restoreComputationGraph(modelFilename); | |
} | |
else { | |
//Init Model | |
// log.info("Number of columns: {}",train.inputColumns()); | |
// int numInput = train.inputColumns(); | |
ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder() | |
.seed(seed) | |
.weightInit(WeightInit.XAVIER) | |
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT) | |
.updater(new Adam(learningRate)) | |
.graphBuilder() | |
.addInputs("trainFeatures") | |
.setOutputs("A7") | |
// .addLayer("C", new Convolution1DLayer.Builder(5,1) | |
.addLayer("C", new Convolution1DLayer.Builder(5,1) | |
.nIn(1) | |
// .nIn(numInput) | |
.nOut(32) | |
.activation(Activation.RELU) | |
.build(),"trainFeatures") | |
// .addLayer("P11",new ZeroPadding1DLayer.Builder(2).build(),"C" ) | |
.addLayer("C11", new Convolution1DLayer.Builder(5,1) | |
.nIn(32) | |
.nOut(32) | |
.activation(Activation.RELU) | |
.build(),"C") | |
// .addLayer("P12",new ZeroPadding1DLayer.Builder(2).build(),"C11" ) | |
.addLayer("C12", new Convolution1DLayer.Builder(5,1) | |
.nIn(32) | |
.nOut(32) | |
.activation(Activation.RELU) | |
.build(),"C11") | |
// .addVertex("S11", new ElementWiseVertex(ElementWiseVertex.Op.Add),"C12","C") | |
// .addLayer("A11",new Convolution1DLayer.Builder(5,1) | |
// .nIn(32) | |
// .nOut(32) | |
// .activation(Activation.RELU) | |
// .build(),"S11") | |
// .addLayer("M11",new Subsampling1DLayer.Builder(5,2) | |
//// .build(),"A11") | |
//// .build(),"S11") | |
// .build(),"C12") | |
.addVertex("Reshape", new ReshapeVertex(-1,32*175),"C12") | |
//Comment this out from here to | |
.addLayer("D1",new DenseLayer.Builder() | |
.nIn(32*175) | |
.nOut(32) | |
.build(),"Reshape") | |
//here | |
//comment this out temporarily | |
// .addLayer("D2",new DenseLayer.Builder() | |
// .nIn(32) | |
// .nOut(32) | |
// .activation(Activation.RELU) | |
// .build() | |
// ,"D1") | |
// .addLayer("D3",new DenseLayer.Builder() | |
// .nIn(32) | |
// .nOut(5) | |
// .activation(Activation.RELU) | |
// .build(),"D2") | |
.addLayer("A7",new OutputLayer.Builder() | |
.nIn(32) | |
.nOut(numClasses) | |
.lossFunction(LossFunctions.LossFunction.MCXENT) | |
.activation(Activation.SOFTMAX) | |
.build(),"D1") | |
// .inputPreProcessor("D1",new CnnToFeedForwardPreProcessor()) | |
// .inputPreProcessor("D1",new RnnToFeedForwardPreProcessor()) | |
.allowDisconnected(true) | |
.build(); | |
model = new ComputationGraph(config); | |
model.init(); | |
log.info(model.summary()); | |
//Setup Training GUI | |
UIServer server = UIServer.getInstance(); | |
StatsStorage storage = new InMemoryStatsStorage(); | |
server.attach(storage); | |
// model.setListeners(new ScoreIterationListener(), new StatsListener(storage), new TimeIterationListener(10) , new PerformanceListener(10)); | |
// model.setListeners(new StatsListener(storage), new ScoreIterationListener(), new PerformanceListener(10)); | |
model.setListeners(new StatsListener(storage), new ScoreIterationListener()); | |
//Train Model | |
for (int i = 1; i < epochs + 1; i++) { | |
model.fit(train); | |
train.reset(); | |
log.info("Completed epoch {}", i); | |
} | |
ModelSerializer.writeModel(model, modelFilename, true); | |
log.info("Model saved at {} - Done", modelFilename); | |
} | |
//Evaluate the model against test dataset(unseen data) | |
Evaluation eval = model.evaluate(test); | |
log.info(eval.stats()); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment