Created
November 11, 2016 01:14
-
-
Save tomthetrainer/5559183206338c67592ff55fd48b1c5d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package org.deeplearning4j.examples.dataExamples; | |
import org.datavec.api.io.labels.ParentPathLabelGenerator; | |
import org.datavec.api.split.FileSplit; | |
import org.datavec.image.loader.NativeImageLoader; | |
import org.datavec.image.recordreader.ImageRecordReader; | |
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator; | |
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; | |
import org.deeplearning4j.util.ModelSerializer; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.dataset.api.DataSet; | |
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; | |
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization; | |
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler; | |
import org.slf4j.Logger; | |
import org.slf4j.LoggerFactory; | |
import javax.swing.*; | |
import java.io.File; | |
import java.util.Arrays; | |
import java.util.HashMap; | |
import java.util.List; | |
import java.util.Random; | |
/** | |
* Created by tomhanlon on 11/7/16. | |
*/ | |
public class MnistImagePipelineExampleLoadAndChooseCleanUp { | |
private static Logger log = LoggerFactory.getLogger(MnistImagePipelineExampleLoadAndChooseCleanUp.class); | |
public static String fileChose() | |
{ | |
JFileChooser fc= new JFileChooser(); | |
int ret = fc.showOpenDialog(null); | |
if (ret== JFileChooser.APPROVE_OPTION) | |
{ | |
File file = fc.getSelectedFile(); | |
String filename= file.getAbsolutePath(); | |
return filename; | |
} | |
else | |
return null; | |
} | |
public static void main(String[] args) throws Exception { | |
// image information | |
// 28 * 28 grayscale | |
// grayscale implies single channel | |
// Black background white digit | |
// is what was trained on | |
int height = 28; | |
int width = 28; | |
int channels = 1; | |
int rngseed = 123; | |
Random randNumGen = new Random(rngseed); | |
int batchSize = 128; | |
int outputNum = 10; | |
// This is the file that | |
// the user choses from | |
// the popup chooser window | |
String filechose = fileChose().toString(); | |
// Define the File Paths | |
File testData = new File(filechose); | |
// Define the FileSplit(PATH, ALLOWED FORMATS,random) | |
// Build a list from the labels list | |
// [2, 3, 7, 1, 6, 4, 0, 5, 8] | |
// these are the digit labels as a list, | |
// retrieved from RecordReaderDataSetIterator.GetLabels | |
// when the model was trained | |
List<Integer> labelList = Arrays.asList(2, 3, 7, 1, 6, 4, 0, 5, 8); | |
FileSplit test = new FileSplit(testData,NativeImageLoader.ALLOWED_FORMATS); | |
// Extract the parent path as the image label | |
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); | |
ImageRecordReader recordReader = new ImageRecordReader(height,width,channels,labelMaker); | |
// Scale pixel values to 0-1 | |
DataNormalization scaler = new ImagePreProcessingScaler(0,1); | |
// Load Saved Neural Network | |
log.info("******LOAD TRAINED MODEL******"); | |
// Details | |
// Where Model was saved | |
File locationToSave = new File("trained_mnist_model.zip"); | |
// Restore the Model | |
MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToSave); | |
log.info("******Test your input image against model******"); | |
//recordReader.reset(); | |
recordReader.initialize(test); | |
DataSetIterator testIter = new RecordReaderDataSetIterator(recordReader,batchSize,1,outputNum); | |
scaler.fit(testIter); | |
testIter.setPreProcessor(scaler); | |
// Create Eval object with 10 possible classes | |
//Evaluation eval = new Evaluation(outputNum); | |
/* | |
If I wanted to zip up the two lists | |
for (int i = 0; i < strings.size(); i++) { | |
map.put(strings.get(i), integers.get(i)); | |
} | |
*/ | |
while(testIter.hasNext()){ | |
DataSet next = testIter.next(); | |
INDArray output = model.output(next.getFeatureMatrix()); | |
//log.info(next.toString()); | |
log.info("## THE FILE YOU CHOSE ## " + filechose); | |
log.info("## Neural Net's Prediction ##"); | |
log.info("## list of probabilities per label ##"); | |
log.info("## list of Labels in order ##"); | |
log.info(output.toString()); | |
log.info(labelList.toString()); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment