@tomthetrainer
Created November 28, 2016 20:31
package org.deeplearning4j.examples.dataExamples;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.util.ClassPathResource;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Created by tomhanlon on 11/28/16.
 *
 * Imports a Keras Sequential model trained on iris into DL4J and evaluates it on iris.txt.
 */
public class ImportKerasConfig {

    private static final Logger log = LoggerFactory.getLogger(ImportKerasConfig.class);

    public static void main(String[] args) throws Exception {
        // Import the Keras model from its JSON config and saved weights (author's local paths)
        MultiLayerNetwork model = org.deeplearning4j.nn.modelimport.keras.Model.importSequentialModel(
                "/Users/tomhanlon/tensorflow/video/iris_model_json",
                "/Users/tomhanlon/tensorflow/video/iris_model_save");

        // Read iris.txt from the classpath as CSV records (no header lines, comma-separated)
        int numLinesToSkip = 0;
        String delimiter = ",";
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip, delimiter);
        recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile()));

        int labelIndex = 4;  // the fifth column holds the class label
        int numClasses = 3;  // iris has three classes
        int batchSize = 150; // load all 150 rows as a single batch
        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, batchSize, labelIndex, numClasses);
        DataSet allData = iterator.next();
        allData.shuffle();

        // Evaluate the imported model on the full dataset
        Evaluation eval = new Evaluation(numClasses);
        INDArray output = model.output(allData.getFeatureMatrix());
        eval.eval(allData.getLabels(), output);
        log.info(eval.stats());
    }
}
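For readers without DL4J on hand, here is a minimal plain-Java sketch of what two steps in the program above do conceptually: splitting a CSV row into features plus a one-hot label (the role of `RecordReaderDataSetIterator` with `labelIndex = 4`, `numClasses = 3`), and scoring accuracy by comparing the argmax of each prediction with the argmax of its label. It assumes `iris.txt` stores four measurements followed by an integer class label 0, 1, or 2 in the fifth column, which is what those two parameters imply. The class, the method names, and the sample row are hypothetical; this is not DL4J's implementation.

```java
import java.util.Arrays;

// Hypothetical helper class: a dependency-free illustration, not DL4J code.
class IrisEvalSketch {

    /** Splits one CSV row; every column except labelIndex becomes a feature. */
    static double[] features(String row, int labelIndex) {
        String[] cols = row.split(",");
        double[] out = new double[cols.length - 1];
        int j = 0;
        for (int i = 0; i < cols.length; i++) {
            if (i != labelIndex) {
                out[j++] = Double.parseDouble(cols[i].trim());
            }
        }
        return out;
    }

    /** One-hot encodes the integer class stored in column labelIndex (assumed 0-based). */
    static double[] oneHotLabel(String row, int labelIndex, int numClasses) {
        int cls = Integer.parseInt(row.split(",")[labelIndex].trim());
        double[] label = new double[numClasses];
        label[cls] = 1.0;
        return label;
    }

    /** Index of the largest entry; the first index wins on ties. */
    static int argMax(double[] v) {
        int best = 0;
        for (int i = 1; i < v.length; i++) {
            if (v[i] > v[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Fraction of rows whose predicted class (argmax) matches the labeled class. */
    static double accuracy(double[][] labels, double[][] predictions) {
        int correct = 0;
        for (int r = 0; r < labels.length; r++) {
            if (argMax(labels[r]) == argMax(predictions[r])) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }

    public static void main(String[] args) {
        String row = "5.1,3.5,1.4,0.2,0"; // illustrative row in the assumed iris.txt format
        System.out.println(Arrays.toString(features(row, 4)));       // [5.1, 3.5, 1.4, 0.2]
        System.out.println(Arrays.toString(oneHotLabel(row, 4, 3))); // [1.0, 0.0, 0.0]
    }
}
```

Accuracy is only one of the figures `eval.stats()` reports; the sketch reproduces just that number.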
@danimateos

This example uses a deprecated method (importSequentialModel).

If you land here, see this gist for a more recent example:

https://gist.github.com/turambar/be0a96a02fd1ba3f6010bc0d17fc90dc

@silencemao

I used this code, but the accuracy is only 33.33%. The model only ever identifies class 1. Can anyone help? Thanks.
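An accuracy of exactly 33.33% is what a classifier scores when it predicts the same class for every row of iris, which has 150 rows split evenly (50 per class) across three classes. The minimal Java sketch here just reproduces that arithmetic; it assumes the balanced 50/50/50 split and does not diagnose why the imported model behaves this way. The class and method names are hypothetical.

```java
import java.util.Locale;

// Hypothetical helper: accuracy of a "classifier" that always outputs one class.
class ConstantPredictionBaseline {

    /** Fraction of rows whose true class equals the single predicted class. */
    static double constantAccuracy(int[] trueClasses, int predictedClass) {
        int correct = 0;
        for (int c : trueClasses) {
            if (c == predictedClass) {
                correct++;
            }
        }
        return (double) correct / trueClasses.length;
    }

    public static void main(String[] args) {
        // Assumed balanced iris labels: rows 0-49 -> 0, 50-99 -> 1, 100-149 -> 2
        int[] classes = new int[150];
        for (int i = 0; i < classes.length; i++) {
            classes[i] = i / 50;
        }
        double acc = constantAccuracy(classes, 1);
        System.out.println(String.format(Locale.ROOT, "%.2f%%", 100 * acc)); // 33.33%
    }
}
```

In other words, 50 correct out of 150 means the score reflects a model stuck on one class, not one that is partly right across all three.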
