@ftarlao
Last active June 30, 2016 15:46
I create the iterator in this way:
private static DataSetIterator createDatasetIterator(int batchSize, int numBatches, List<MyData> DB) {
    List<DataSet> datasetList = new ArrayList<>();
    for (int i = 0; i < numBatches; i++) {
        datasetList.add(createDataset(batchSize, i * batchSize, DB)); // it simply takes different portions of my data
    }
    ListDataSetIterator listDataSetIterator = new ListDataSetIterator(datasetList, 1); // I have also tried with 2, 3, 4, 5... nothing changed
    /* debug check, left commented out
    while (listDataSetIterator.hasNext()) {
        DataSet next = listDataSetIterator.next();
        int i = next.numExamples();
    }
    */
    return listDataSetIterator;
}
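As a sanity check, the batches produced by this iterator can be inspected like this (a minimal sketch, reusing the classes already in scope above; given the data described below, the printed shapes should be [batchSize, 40000] for the features and [batchSize, 2] for the labels):

    // Sketch: walk the iterator once and print the shape of every batch.
    DataSetIterator it = createDatasetIterator(batchSize, numBatches, DB);
    while (it.hasNext()) {
        DataSet batch = it.next();
        System.out.println("examples: " + batch.numExamples()
                + " features: " + Arrays.toString(batch.getFeatureMatrix().shape())
                + " labels: " + Arrays.toString(batch.getLabels().shape()));
    }
    it.reset(); // rewind before handing the iterator to model.fit(...)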
Each single DataSet is created in this way:
private static DataSet createDataset(int batchSize, int startIndex, List<MyData> examples) {
    List<INDArray> data = new ArrayList<>();
    List<INDArray> label = new ArrayList<>();
    for (int i = startIndex; i < Math.min(startIndex + batchSize, examples.size()); i++) {
        MyData example = examples.get(i);
        INDArray tempdata = convert(example.getVectorData()); // 1 row, 40000 columns
        INDArray templabel = example.isTheSameUser()
                ? Nd4j.create(new double[]{1.0, 0.0}, new int[]{1, 2})
                : Nd4j.create(new double[]{0.0, 1.0}, new int[]{1, 2}); // labels: 1 row, 2 cols
        data.add(tempdata);
        label.add(templabel);
    }
    INDArray dataFinal = Nd4j.vstack(data);
    INDArray labelFinal = Nd4j.vstack(label);
    return new DataSet(dataFinal, labelFinal);
}
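The convert(...) helper is not shown here; something along these lines is presumably what it does (a sketch, assuming getVectorData() returns a plain double[] of length 40000):

    // Sketch of the unshown convert(...) helper: wrap a raw double[] into a 1-row INDArray.
    // Assumption: MyData.getVectorData() returns a double[] of length 40000.
    private static INDArray convert(double[] vectorData) {
        return Nd4j.create(vectorData, new int[]{1, vectorData.length}); // shape: 1 row, 40000 columns
    }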
This is how my model is trained:
model.setListeners(new ScoreIterationListener(1)); //, new HistogramIterationListener(1));
for (int i = 0; i < numEpochs; i++) {
    System.out.println("Epoch " + i);
    // note: this call passes extra arguments not shown in the simplified signature above
    DataSetIterator train = createDatasetIterator(batchSize, numBatchesPerEpoch, trainUserList, loadTraceDB, random);
    model.fit(train);
    // one epoch has many batches, so this is computed rarely
    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    INDArray output = model.output(test.getFeatureMatrix(), false); // 'test' is a DataSet built elsewhere
    eval.eval(test.getLabels(), output);
    log.info(eval.stats());
    log.info("TP " + eval.truePositives()
            + " FP " + eval.falsePositives()
            + " TN " + eval.trueNegatives()
            + " FN " + eval.falseNegatives());
    log.info("Accuracy " + eval.accuracy()
            + " Precision " + eval.precision()
            + " Recall " + eval.recall()
            + " F1 " + eval.f1());
}
log.info("*** Completed training ***");
log.info("****************Example finished********************");
}
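The network configuration itself is not shown above. A minimal configuration compatible with the shapes used in the data (40000 input features, 2 output classes) would look roughly like this, in the style of the DL4J builder API current at the time; the hidden-layer size, learning rate and other hyperparameters are placeholders, not the values actually used:

    // Sketch only: a minimal MultiLayerNetwork matching the data shapes above
    // (nIn = 40000 feature columns, nOut = 2 label columns).
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .iterations(1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .learningRate(0.01)                    // placeholder value
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(40000).nOut(256)   // 256 is a placeholder
                    .activation("relu").build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(256).nOut(2).activation("softmax").build())
            .pretrain(false).backprop(true)
            .build();
    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();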