Skip to content

Instantly share code, notes, and snippets.

@crockpotveggies
Created June 6, 2019 03:51
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save crockpotveggies/83d4c9ee2aca7c39d8d77ece13935f0d to your computer and use it in GitHub Desktop.
import org.apache.commons.io.FileUtils;
import org.datavec.api.records.reader.SequenceRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVSequenceRecordReader;
import org.datavec.api.split.NumberedFileInputSplit;
import org.deeplearning4j.api.storage.StatsStorage;
import org.deeplearning4j.arbiter.ComputationGraphSpace;
import org.deeplearning4j.arbiter.conf.updater.SgdSpace;
import org.deeplearning4j.arbiter.layers.LSTMLayerSpace;
import org.deeplearning4j.arbiter.layers.OutputLayerSpace;
import org.deeplearning4j.arbiter.optimize.api.OptimizationResult;
import org.deeplearning4j.arbiter.optimize.api.ParameterSpace;
import org.deeplearning4j.arbiter.optimize.api.data.DataSource;
import org.deeplearning4j.arbiter.optimize.api.saving.ResultReference;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxCandidatesCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.MaxTimeCondition;
import org.deeplearning4j.arbiter.optimize.api.termination.TerminationCondition;
import org.deeplearning4j.arbiter.optimize.config.OptimizationConfiguration;
import org.deeplearning4j.arbiter.optimize.generator.RandomSearchGenerator;
import org.deeplearning4j.arbiter.optimize.parameter.continuous.ContinuousParameterSpace;
import org.deeplearning4j.arbiter.optimize.parameter.integer.IntegerParameterSpace;
import org.deeplearning4j.arbiter.optimize.runner.IOptimizationRunner;
import org.deeplearning4j.arbiter.optimize.runner.LocalOptimizationRunner;
import org.deeplearning4j.arbiter.saver.local.FileModelSaver;
import org.deeplearning4j.arbiter.scoring.impl.EvaluationScoreFunction;
import org.deeplearning4j.arbiter.task.ComputationGraphTaskCreator;
import org.deeplearning4j.arbiter.ui.listener.ArbiterStatusListener;
import org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.EarlyStoppingResult;
import org.deeplearning4j.earlystopping.saver.LocalFileGraphSaver;
import org.deeplearning4j.earlystopping.scorecalc.DataSetLossCalculator;
import org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition;
import org.deeplearning4j.earlystopping.termination.ScoreImprovementEpochTerminationCondition;
import org.deeplearning4j.earlystopping.trainer.EarlyStoppingGraphTrainer;
import org.deeplearning4j.earlystopping.trainer.IEarlyStoppingTrainer;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.CacheMode;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.rnn.LastTimeStepVertex;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.InvocationType;
import org.deeplearning4j.optimize.listeners.EvaluativeListener;
import org.deeplearning4j.optimize.listeners.PerformanceListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.FileStatsStorage;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.nio.file.Paths;
import java.util.*;
import java.util.concurrent.TimeUnit;
/**
 * Trains an LSTM classifier on PhysioNet ICU sequence data (per-patient CSV time series
 * under {@code data/physionet/sequence}, binary mortality labels under
 * {@code data/physionet/mortality}).
 *
 * <p>Two modes, selected by the first program argument:
 * <ul>
 *   <li>{@code "O"} — Arbiter random hyperparameter search over l2 / learning rate / LSTM width;</li>
 *   <li>anything else — train a single fixed graph with early stopping.</li>
 * </ul>
 */
public class ICUScenario {

    /** RNG seed, used for reproducible model init and hyperparameter search. */
    private static final long SEED = 83;
    /** Number of input features per time step in the sequence CSVs. */
    private static final int NB_INPUTS = 86;
    private static final int NB_EPOCHS = 10;
    private static final double LEARNING_RATE = 0.005;
    private static final int BATCH_SIZE = 32;
    /** Binary outcome (mortality yes/no). */
    private static final int NUM_LABELS = 2;
    private static final int NB_TRAIN_EXAMPLES = 2000;
    private static final int NB_TEST_EXAMPLES = 800;

    public static void main(String[] args) throws Exception {
        // Recreate the output directory so each run starts from a clean slate.
        FileUtils.deleteQuietly(new File("out"));
        FileUtils.forceMkdir(new File("out"));
        if (args.length > 0 && args[0].equals("O")) {
            process(optimize());
        } else {
            process(build());
        }
    }

    /**
     * Trains a fixed {@link ComputationGraph} with early stopping and prints the
     * termination summary plus the per-epoch score table.
     *
     * @param graph an un-initialized graph (this method calls {@code init()})
     * @throws Exception if the data source cannot read the CSV files
     */
    public static void process(ComputationGraph graph) throws Exception {
        ICUDataSource ds = new ICUDataSource();
        DataSetIterator trainIter = (DataSetIterator) ds.trainData();
        DataSetIterator testIter = (DataSetIterator) ds.testData();
        DataSetIterator testIter2 = (DataSetIterator) ds.testData2();

        graph.init();
        graph.setListeners(
                new ScoreIterationListener(50),
                new PerformanceListener.Builder()
                        .reportSample(true)
                        .reportScore(true)
                        .reportTime(true)
                        .reportETL(true)
                        .reportBatch(true)
                        .reportIteration(true)
                        .setFrequency(50).build(),
                new EvaluativeListener(testIter, 1, InvocationType.EPOCH_END));
        System.out.println(graph.summary());

        EarlyStoppingConfiguration<ComputationGraph> eac =
                new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                        // Stop once the score fails to improve for 2 consecutive epochs.
                        .epochTerminationConditions(new ScoreImprovementEpochTerminationCondition(2))
                        .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(5, TimeUnit.MINUTES))
                        // Separate iterator so scoring does not consume the EvaluativeListener's testIter.
                        .scoreCalculator(new DataSetLossCalculator(testIter2, true))
                        .evaluateEveryNEpochs(1)
                        .modelSaver(new LocalFileGraphSaver("out"))
                        .build();
        IEarlyStoppingTrainer<ComputationGraph> trainer = new EarlyStoppingGraphTrainer(eac, graph, trainIter);

        System.out.println("Training model....");
        EarlyStoppingResult<ComputationGraph> result = trainer.fit();
        System.out.println("Termination reason: " + result.getTerminationReason());
        System.out.println("Termination details: " + result.getTerminationDetails());
        System.out.println("Total epochs: " + result.getTotalEpochs());
        System.out.println("Best epoch number: " + result.getBestModelEpoch());
        System.out.println("Score at best epoch: " + result.getBestModelScore());

        // Print the epoch -> score table in epoch order.
        Map<Integer, Double> epochVsScore = result.getScoreVsEpoch();
        List<Integer> epochs = new ArrayList<Integer>(epochVsScore.keySet());
        Collections.sort(epochs);
        System.out.println("Epoch\tScore");
        for (Integer epoch : epochs) {
            System.out.println(epoch + "\t" + epochVsScore.get(epoch));
        }
    }

    /**
     * Runs an Arbiter random search over the given space (at most 6 candidates or 96 hours),
     * streams progress to the DL4J UI, and prints the best candidate's configuration.
     *
     * @param space the hyperparameter search space (see {@link #optimize()})
     * @throws Exception if the runner or data source fails
     */
    public static void process(ComputationGraphSpace space) throws Exception {
        TerminationCondition[] terminationConditions = {
                new MaxTimeCondition(96, TimeUnit.HOURS),
                new MaxCandidatesCondition(6)
        };
        @SuppressWarnings("deprecation")
        OptimizationConfiguration configuration = new OptimizationConfiguration.Builder()
                .candidateGenerator(new RandomSearchGenerator(space, null))
                .dataSource(ICUDataSource.class, new Properties())
                .modelSaver(new FileModelSaver("out"))
                .scoreFunction(new EvaluationScoreFunction(
                        org.deeplearning4j.eval.Evaluation.Metric.ACCURACY))
                .terminationConditions(terminationConditions)
                .build();

        IOptimizationRunner runner = new LocalOptimizationRunner(configuration, new ComputationGraphTaskCreator());
        StatsStorage ss = new FileStatsStorage(new File("out/icuStats.dl4j"));
        runner.addListeners(new ArbiterStatusListener(ss));
        UIServer.getInstance().attach(ss);
        runner.execute();

        String s = "Best score: " + runner.bestScore() + "\n" + "Index of model with best score: "
                + runner.bestScoreCandidateIndex() + "\n" + "Number of configurations evaluated: "
                + runner.numCandidatesCompleted() + "\n";
        System.out.println(s);

        int indexOfBestResult = runner.bestScoreCandidateIndex();
        List<ResultReference> allResults = runner.getResults();
        OptimizationResult bestResult = allResults.get(indexOfBestResult).getResult();
        // BUG FIX: candidates are built by ComputationGraphTaskCreator, so the stored model
        // is a ComputationGraph — the previous cast to MultiLayerNetwork (and the call to
        // getLayerWiseConfigurations()) threw ClassCastException at runtime.
        ComputationGraph bestModel = (ComputationGraph) bestResult.getResultReference().getResultModel();
        System.out.println("\n\nConfiguration of best model:\n");
        System.out.println(bestModel.getConfiguration());
        UIServer.getInstance().stop();
    }

    /**
     * Builds the fixed baseline graph: in -> LSTM(128) -> last-time-step -> softmax output.
     *
     * @return an un-initialized {@link ComputationGraph}
     */
    public static ComputationGraph build() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                // Use the class-wide SEED for reproducibility, consistent with optimize().
                .seed(SEED)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                .l2(0.01)
                .cacheMode(CacheMode.HOST)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.TANH)
                .updater(new Sgd(LEARNING_RATE))
                .graphBuilder()
                .addInputs("in")
                .addLayer("lstm", new LSTM.Builder()
                        .nIn(NB_INPUTS).nOut(128).build(), "in")
                // Collapse the sequence output to its final time step before classification.
                .addVertex("lastStep", new LastTimeStepVertex("in"), "lstm")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX).nIn(128).nOut(NUM_LABELS).build(), "lastStep")
                .setOutputs("out").build();
        return new ComputationGraph(conf);
    }

    /**
     * Defines the Arbiter search space mirroring {@link #build()}:
     * l2 in [0.0001, 0.09], learning rate in [0.001, 0.09], LSTM width in [8, 256].
     *
     * @return the hyperparameter search space
     */
    public static ComputationGraphSpace optimize() {
        ParameterSpace<Double> l2Param = new ContinuousParameterSpace(0.0001, 0.09);
        ParameterSpace<Double> learnParam = new ContinuousParameterSpace(0.001, 0.09);
        ParameterSpace<Integer> lstmWidthParam = new IntegerParameterSpace(8, 256);
        return new ComputationGraphSpace.Builder()
                .seed(SEED)
                .l2(l2Param)
                .weightInit(WeightInit.XAVIER)
                .activation(Activation.TANH)
                .updater(new SgdSpace(learnParam))
                .numEpochs(NB_EPOCHS)
                .addInputs("in")
                .addLayer("lstm", new LSTMLayerSpace.Builder()
                        .nIn(NB_INPUTS).nOut(lstmWidthParam).build(), "in")
                .addVertex("lastStep", new LastTimeStepVertex("in"), "lstm")
                .addLayer("out", new OutputLayerSpace.Builder()
                        .activation(Activation.SOFTMAX)
                        // Use NUM_LABELS rather than a hard-coded 2, matching build().
                        .nIn(lstmWidthParam).nOut(NUM_LABELS).build(), "lastStep")
                .setOutputs("out").build();
    }

    /**
     * Arbiter {@link DataSource} exposing three iterators over the numbered PhysioNet CSVs:
     * train (files 0 .. NB_TRAIN_EXAMPLES-1) and two identical test iterators
     * (files NB_TRAIN_EXAMPLES .. NB_TRAIN_EXAMPLES+NB_TEST_EXAMPLES-1) — one consumed by
     * evaluation listeners, one by the early-stopping score calculator.
     */
    public static class ICUDataSource implements DataSource {
        private static final long serialVersionUID = 1L;
        private static final String ROOT = "data/physionet";

        private SequenceRecordReaderDataSetIterator training = null;
        private SequenceRecordReaderDataSetIterator testing1 = null;
        private SequenceRecordReaderDataSetIterator testing2 = null;

        public ICUDataSource() throws Exception {
            String featuresPath = Paths.get(ROOT, "sequence", "%d.csv").toString();
            String labelsPath = Paths.get(ROOT, "mortality", "%d.csv").toString();

            training = buildIterator(featuresPath, labelsPath, 0, NB_TRAIN_EXAMPLES - 1);
            // BUG FIX: NumberedFileInputSplit bounds are inclusive; the previous upper bound
            // of NB_TRAIN_EXAMPLES + NB_TEST_EXAMPLES pulled in 801 files instead of 800.
            testing1 = buildIterator(featuresPath, labelsPath,
                    NB_TRAIN_EXAMPLES, NB_TRAIN_EXAMPLES + NB_TEST_EXAMPLES - 1);
            testing2 = buildIterator(featuresPath, labelsPath,
                    NB_TRAIN_EXAMPLES, NB_TRAIN_EXAMPLES + NB_TEST_EXAMPLES - 1);
        }

        /**
         * Builds one aligned features+labels iterator over numbered CSV files
         * [firstIdx, lastIdx] (inclusive), with the label mask collapsed to the
         * last time step by {@link LastStepPreProcessor}.
         */
        private static SequenceRecordReaderDataSetIterator buildIterator(
                String featuresPath, String labelsPath, int firstIdx, int lastIdx) throws Exception {
            // Feature CSVs carry a one-line header; label CSVs do not.
            SequenceRecordReader features = new CSVSequenceRecordReader(1, ",");
            features.initialize(new NumberedFileInputSplit(featuresPath, firstIdx, lastIdx));
            SequenceRecordReader labels = new CSVSequenceRecordReader();
            labels.initialize(new NumberedFileInputSplit(labelsPath, firstIdx, lastIdx));
            SequenceRecordReaderDataSetIterator it = new SequenceRecordReaderDataSetIterator(
                    features, labels, BATCH_SIZE, NUM_LABELS, false,
                    SequenceRecordReaderDataSetIterator.AlignmentMode.ALIGN_END);
            it.setPreProcessor(new LastStepPreProcessor());
            return it;
        }

        @Override
        public void configure(Properties properties) {
            // No runtime configuration needed; paths and sizes are class constants.
        }

        @Override
        public Class<?> getDataType() {
            return DataSetIterator.class;
        }

        @Override
        public Object testData() {
            return testing1;
        }

        /** Second, independently-consumable copy of the test set (for the score calculator). */
        public Object testData2() {
            return testing2;
        }

        @Override
        public Object trainData() {
            return training;
        }
    }

    /**
     * Replaces per-time-step sequence labels with the label at each example's last
     * (unmasked) time step, matching the graph's LastTimeStepVertex output, and clears
     * the label mask.
     */
    public static class LastStepPreProcessor implements DataSetPreProcessor {
        private static final long serialVersionUID = 1L;

        public LastStepPreProcessor() {}

        @Override
        public void preProcess(DataSet toPreProcess) {
            INDArray labels = toPreProcess.getLabels();
            INDArray mask = toPreProcess.getLabelsMaskArray();
            INDArray labels2d = pullLastTimeSteps(labels, mask);
            toPreProcess.setLabels(labels2d);
            // Mask is consumed here; downstream sees plain 2-d labels.
            toPreProcess.setLabelsMaskArray(null);
        }

        /**
         * Extracts, for each example, the slice of {@code pullFrom} at the last valid time
         * step: the final index when {@code mask} is null, otherwise the last non-zero mask
         * position per row.
         */
        private INDArray pullLastTimeSteps(INDArray pullFrom, INDArray mask) {
            INDArray out = null;
            if (mask == null) {
                // No mask array -> extract the same (last) time step for every example.
                long lastTS = pullFrom.size(2) - 1;
                out = pullFrom.get(NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(lastTS));
            } else {
                out = Nd4j.zeros(pullFrom.size(0), pullFrom.size(1));
                // Index of the last non-zero entry in each row of the mask array.
                // NOTE(review): an all-zero mask row would yield index -1 and fail the
                // get() below — assumes every example has at least one valid step.
                INDArray lastStepArr = BooleanIndexing.lastIndex(mask, Conditions.epsNotEquals(0.0), 1);
                int[] fwdPassTimeSteps = lastStepArr.data().asInt();
                for (int i = 0; i < fwdPassTimeSteps.length; i++) {
                    out.putRow(i, pullFrom.get(NDArrayIndex.point(i), NDArrayIndex.all(),
                            NDArrayIndex.point(fwdPassTimeSteps[i])));
                }
            }
            return out;
        }
    }
}
@zollen
Copy link

zollen commented Jun 6, 2019

Thanks for your help. I am testing it as we speak.

@zollen
Copy link

zollen commented Jun 6, 2019

You are awesome!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment