Last active
June 30, 2016 15:46
-
-
Save ftarlao/9b15614023c251da5de0f56826b429bf to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
I create the iterator this way:
private static DataSetIterator createDatasetIterator(int batchSize, int numBatches, List<MyData> DB) { | |
List<DataSet> datasetList = new ArrayList<>(); | |
for (int i = 0; i < numBatches; i++) { | |
datasetList.add(createDataset(batchSize, i * batchSize, DB)); //It simply takes different portions of my data | |
} | |
ListDataSetIterator listDataSetIterator = new ListDataSetIterator(datasetList,1); //I have also tried with 2,3,4,5.. nothing changed | |
while (listDataSetIterator.hasNext()) { | |
DataSet next = listDataSetIterator.next(); | |
int i = next.numExamples(); | |
}*/ | |
return listDataSetIterator; | |
} | |
Each individual DataSet is created this way:
private static DataSet createDataset(int batchSize, int startIndex, List<MyData> examples) { | |
DataSet dataset = null; | |
List<INDArray> data = new ArrayList<>(); | |
List<INDArray> label = new ArrayList<>(); | |
for (int i = startIndex; i < Math.min(startIndex + batchSize, examples.size()); i++) { | |
MyData example = examples.get(i); | |
INDArray tempdata = convert(example.getVectorData()); //1row 40000 columns | |
INDArray templabel = example.isTheSameUser() ? Nd4j.create(new double[]{1.0, 0.0}, new int[]{1, 2}) : Nd4j.create(new double[]{0.0, 1.0}, new int[]{1, 2}); | |
//labes:1 row 2 cols | |
data.add(tempdata); | |
label.add(templabel); | |
} | |
INDArray dataFinal = Nd4j.vstack(data); | |
INDArray labelFinal = Nd4j.vstack(label); | |
dataset = new DataSet(dataFinal, labelFinal); | |
return dataset; | |
} | |
This is how my model is trained:
// Training loop: for each epoch, fit the model on a freshly built training iterator, then
// evaluate on the held-out `test` DataSet and log the metrics.
// NOTE(review): model, numEpochs, batchSize, numBatchesPerEpoch, trainUserList, loadTraceDB,
// random, test, outputNum and log are all declared outside this snippet.
// ScoreIterationListener(1): presumably logs the score every iteration — confirm in DL4J docs.
// The histogram listener is left commented out.
model.setListeners(new ScoreIterationListener(1));//, new HistogramIterationListener(1)); | |
for (int i = 0; i < numEpochs; i++) { | |
System.out.println("Epoch " + i); | |
// Build a new training iterator for every epoch.
// NOTE(review): this passes five arguments, but the createDatasetIterator shown earlier takes
// three (batchSize, numBatches, DB) — confirm which overload is actually being called.
DataSetIterator train = createDatasetIterator(batchSize, numBatchesPerEpoch, trainUserList, loadTraceDB, random); | |
model.fit(train); | |
// An epoch spans many batches, so evaluating once per epoch keeps evaluation infrequent.
log.info("Evaluate model...."); | |
Evaluation eval = new Evaluation(outputNum); | |
// Predict on the test features; `false` presumably selects inference (non-training) mode — confirm.
INDArray output = model.output(test.getFeatureMatrix(), false); | |
// Compare predictions against the test labels.
eval.eval(test.getLabels(), output); | |
log.info(eval.stats()); | |
// Log the TP/FP/TN/FN counts, then accuracy, precision, recall and F1.
log.info("TP " + eval.truePositives() | |
+ " FP " + eval.falsePositives() | |
+ " TN " + eval.trueNegatives() | |
+ " FN " + eval.falseNegatives()); | |
log.info("Accuracy " + eval.accuracy() | |
+ " Precision " + eval.precision() | |
+ " Recall " + eval.recall() | |
+ " F1 " + eval.f1()); | |
} | |
log.info("*** Completed training ***"); | |
log.info( | |
"****************Example finished********************"); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment