Skip to content

Instantly share code, notes, and snippets.

@yptheangel
Last active August 7, 2019 05:37
Show Gist options
  • Save yptheangel/65f8a99da6076a34f2d8b5f66354595c to your computer and use it in GitHub Desktop.
Attempt to fix the error that occurs when input features and labels are supplied via a SequenceRecordReader and the last layer of the network is an OutputLayer.
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.datasets.datavec.RecordReaderMultiDataSetIterator;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.ReshapeVertex;
import org.deeplearning4j.nn.conf.layers.*;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.File;
/**
 * Trains and evaluates a 1D-CNN classifier for ECG sequence data using DL4J's
 * {@link ComputationGraph} API.
 *
 * <p>Features are read as sequences from numbered CSV files ({@code %d.csv}) via a
 * {@link CSVSequenceRecordReader}; training labels are single-row CSVs read with a plain
 * {@link CSVRecordReader} and one-hot encoded by the {@link RecordReaderMultiDataSetIterator}.
 * If a serialized model already exists on disk (and retraining is not forced), it is
 * reloaded instead of retrained; either way the model is evaluated against the test set.
 */
public class ECG_CNNClassifier {
    private static final Logger log = LoggerFactory.getLogger(ECG_CNNClassifier.class);

    private static final long seed = 123;
    private static ComputationGraph model = null;
    private static final String modelFilename =
            new File(".").getAbsolutePath() + "/ECGCNN_Classifier.zip";
    // Number of ECG beat classes.
    private static final int numClasses = 5;
    // CSV files have no header rows to skip.
    private static final int numSkipLines = 0;

    // Hyperparameters
    private static final int epochs = 25;
    private static final int batchSize = 500;
    private static final double learningRate = 0.00001;
    // When true, always retrain even if a saved model file exists on disk.
    private static final boolean dontretrain = true;

    public static void main(String[] args) throws Exception {
        // --- Load data ---
        File trainBaseDir = new File("C:\\Users\\choowilson\\Downloads\\ecg\\train_tmp");
        File trainFeaturesDir = new File(trainBaseDir, "features");
        File trainLabelsDir = new File(trainBaseDir, "labels");
        File testBaseDir = new File("C:\\Users\\choowilson\\Downloads\\ecg\\test");
        File testFeaturesDir = new File(testBaseDir, "features");
        File testLabelsDir = new File(testBaseDir, "labels");

        // Training features are time series; labels are a single class index per example,
        // so a MultiDataSetIterator mixing a sequence reader and a flat reader is used.
        SequenceRecordReader trainFeatures = new CSVSequenceRecordReader(numSkipLines, ",");
        trainFeatures.initialize(new NumberedFileInputSplit(
                trainFeaturesDir.getAbsolutePath() + "/%d.csv", 66471, 87553));
        RecordReader trainLabels = new CSVRecordReader();
        trainLabels.initialize(new NumberedFileInputSplit(
                trainLabelsDir.getAbsolutePath() + "/%d.csv", 66471, 87553));
        MultiDataSetIterator train = new RecordReaderMultiDataSetIterator.Builder(batchSize)
                .addSequenceReader("trainFeatures", trainFeatures)
                .addReader("trainLabels", trainLabels)
                // Align the single label with the end of each feature sequence.
                .sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END)
                .addInput("trainFeatures")
                // One-hot encode column 0 of the label CSV into numClasses outputs.
                .addOutputOneHot("trainLabels", 0, numClasses)
                .build();

        // Test set: labels are also stored as sequences, so the simpler
        // SequenceRecordReaderDataSetIterator applies here.
        SequenceRecordReader testFeatures = new CSVSequenceRecordReader(numSkipLines, ",");
        testFeatures.initialize(new NumberedFileInputSplit(
                testFeaturesDir.getAbsolutePath() + "/%d.csv", 0, 21891));
        SequenceRecordReader testLabels = new CSVSequenceRecordReader(numSkipLines, ",");
        testLabels.initialize(new NumberedFileInputSplit(
                testLabelsDir.getAbsolutePath() + "/%d.csv", 0, 21891));
        DataSetIterator test = new SequenceRecordReaderDataSetIterator(
                testFeatures, testLabels, batchSize, numClasses, false,
                SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);

        if (new File(modelFilename).exists() && !dontretrain) {
            // Reuse the previously trained model rather than training from scratch.
            model = ModelSerializer.restoreComputationGraph(modelFilename);
        } else {
            // --- Build the model ---
            // Three stacked 1D convolutions, flattened via a ReshapeVertex, followed by
            // a dense layer and a softmax OutputLayer over the ECG classes.
            ComputationGraphConfiguration config = new NeuralNetConfiguration.Builder()
                    .seed(seed)
                    .weightInit(WeightInit.XAVIER)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(new Adam(learningRate))
                    .graphBuilder()
                    .addInputs("trainFeatures")
                    .setOutputs("A7")
                    .addLayer("C", new Convolution1DLayer.Builder(5, 1)
                            // Single ECG channel in.
                            .nIn(1)
                            .nOut(32)
                            .activation(Activation.RELU)
                            .build(), "trainFeatures")
                    .addLayer("C11", new Convolution1DLayer.Builder(5, 1)
                            .nIn(32)
                            .nOut(32)
                            .activation(Activation.RELU)
                            .build(), "C")
                    .addLayer("C12", new Convolution1DLayer.Builder(5, 1)
                            .nIn(32)
                            .nOut(32)
                            .activation(Activation.RELU)
                            .build(), "C11")
                    // Flatten [minibatch, 32, 175] activations to [minibatch, 32*175]
                    // for the feed-forward layers. NOTE(review): 175 assumes a fixed
                    // post-convolution sequence length — confirm against the input data.
                    .addVertex("Reshape", new ReshapeVertex(-1, 32 * 175), "C12")
                    .addLayer("D1", new DenseLayer.Builder()
                            .nIn(32 * 175)
                            .nOut(32)
                            .build(), "Reshape")
                    .addLayer("A7", new OutputLayer.Builder()
                            .nIn(32)
                            .nOut(numClasses)
                            .lossFunction(LossFunctions.LossFunction.MCXENT)
                            .activation(Activation.SOFTMAX)
                            .build(), "D1")
                    // Tolerate graph components with no path to an output.
                    .allowDisconnected(true)
                    .build();
            model = new ComputationGraph(config);
            model.init();
            log.info(model.summary());

            // --- Training GUI ---
            UIServer server = UIServer.getInstance();
            StatsStorage storage = new InMemoryStatsStorage();
            server.attach(storage);
            model.setListeners(new StatsListener(storage), new ScoreIterationListener());

            // --- Train ---
            for (int i = 1; i < epochs + 1; i++) {
                model.fit(train);
                train.reset();
                log.info("Completed epoch {}", i);
            }
            ModelSerializer.writeModel(model, modelFilename, true);
            log.info("Model saved at {} - Done", modelFilename);
        }

        // --- Evaluate against the test dataset (unseen data) ---
        Evaluation eval = model.evaluate(test);
        log.info(eval.stats());
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment