Created
November 28, 2016 20:31
-
-
Save tomthetrainer/f6e073444286e5d97d976bd77292a064 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.dataExamples; | |
import org.datavec.api.records.reader.RecordReader; | |
import org.datavec.api.records.reader.impl.csv.CSVRecordReader; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.api.util.ClassPathResource; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.eval.Evaluation; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
/** | |
* Created by tomhanlon on 11/28/16. | |
*/ | |
public class ImportKerasConfig { | |
private static Logger log = LoggerFactory.getLogger(ImportKerasConfig.class); | |
public static void main(String[] args) throws Exception { | |
MultiLayerNetwork model = | |
org.deeplearning4j.nn.modelimport.keras.Model.importSequentialModel | |
("/Users/tomhanlon/tensorflow/video/iris_model_json","/Users/tomhanlon/tensorflow/video/iris_model_save"); | |
int numLinesToSkip = 0; | |
String delimiter = ","; | |
// Read the iris.txt file as a collection of records | |
RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter); | |
recordReader.initialize(new FileSplit(new ClassPathResource("iris.txt").getFile())); | |
// label index | |
int labelIndex = 4; | |
// num of classes | |
int numClasses = 3; | |
// batchsize all | |
int batchSize = 150; | |
DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses); | |
DataSet allData = iterator.next(); | |
allData.shuffle(); | |
// Have our model | |
//we have our Record Reader to read data | |
// Evaluate the model | |
Evaluation eval = new Evaluation(3); | |
INDArray output = model.output(allData.getFeatureMatrix()); | |
eval.eval(allData.getLabels(),output); | |
log.info(eval.stats()); | |
} | |
} |
I use this code, but the accuracy is only 33.33%. It can only identify class 1. Can anyone help me? Thanks.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This example uses a deprecated method (`importSequentialModel`). If anyone arrives here, follow this for a more recent example:
https://gist.github.com/turambar/be0a96a02fd1ba3f6010bc0d17fc90dc