Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tomthetrainer/5559183206338c67592ff55fd48b1c5d to your computer and use it in GitHub Desktop.
Save tomthetrainer/5559183206338c67592ff55fd48b1c5d to your computer and use it in GitHub Desktop.
package org.deeplearning4j.examples.dataExamples;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.swing.*;
import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
/**
 * Runs a previously trained MNIST model against one image file that the user
 * picks from a Swing file-chooser dialog, and logs the network's output next to
 * the label order used during training.
 *
 * <p>The model is restored from {@code trained_mnist_model.zip} in the current
 * working directory. Input images should match what the model was trained on:
 * 28x28, single channel (grayscale), white digit on a black background.
 *
 * <p>Created by tomhanlon on 11/7/16.
 */
public class MnistImagePipelineExampleLoadAndChooseCleanUp {

    private static final Logger log =
            LoggerFactory.getLogger(MnistImagePipelineExampleLoadAndChooseCleanUp.class);

    /** Image height in pixels the model was trained on. */
    private static final int HEIGHT = 28;
    /** Image width in pixels the model was trained on. */
    private static final int WIDTH = 28;
    /** Grayscale input, hence a single channel. */
    private static final int CHANNELS = 1;
    /** Mini-batch size for the test iterator (only one file is read). */
    private static final int BATCH_SIZE = 128;
    /** Number of output classes of the trained network (digits 0-9). */
    private static final int OUTPUT_NUM = 10;

    /**
     * Opens a modal file-chooser dialog and returns the path the user selected.
     *
     * @return the absolute path of the selected file, or {@code null} if the
     *     dialog did not end with the user approving a selection (cancel, close,
     *     or error)
     */
    public static String fileChose() {
        JFileChooser fc = new JFileChooser();
        int ret = fc.showOpenDialog(null);
        if (ret != JFileChooser.APPROVE_OPTION) {
            return null;
        }
        return fc.getSelectedFile().getAbsolutePath();
    }

    /**
     * Lets the user pick an image, restores the trained model and logs the
     * network output for that image together with the label order.
     *
     * @param args unused
     * @throws Exception if the model file cannot be restored or the image cannot be read
     */
    public static void main(String[] args) throws Exception {
        // fileChose() returns null when the dialog is cancelled; dereferencing it
        // unconditionally used to crash with a NullPointerException.
        String filechose = fileChose();
        if (filechose == null) {
            log.warn("No file selected; nothing to classify.");
            return;
        }
        File testData = new File(filechose);

        // Digit labels in the order reported by RecordReaderDataSetIterator.getLabels()
        // when the model was trained. The previous list omitted 9 and so had only
        // 9 entries for a 10-class output.
        List<Integer> labelList = Arrays.asList(2, 3, 7, 1, 6, 4, 0, 5, 8, 9);
        if (labelList.size() != OUTPUT_NUM) {
            throw new IllegalStateException("labelList has " + labelList.size()
                    + " entries but the model has " + OUTPUT_NUM + " outputs");
        }

        // Restrict the split to the image formats the native image loader accepts.
        FileSplit test = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS);
        // The record reader needs a label source; the parent directory name is used.
        // Only the features are fed to the model below, so that label is not inspected.
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        ImageRecordReader recordReader = new ImageRecordReader(HEIGHT, WIDTH, CHANNELS, labelMaker);
        // Scale pixel values to 0-1
        DataNormalization scaler = new ImagePreProcessingScaler(0, 1);

        log.info("******LOAD TRAINED MODEL******");
        File locationToSave = new File("trained_mnist_model.zip");
        MultiLayerNetwork model = ModelSerializer.restoreMultiLayerNetwork(locationToSave);

        log.info("******Test your input image against model******");
        recordReader.initialize(test);
        // labelIndex = 1: presumably the label follows the image in each record
        // produced by ImageRecordReader - TODO confirm for the DataVec version in use.
        DataSetIterator testIter =
                new RecordReaderDataSetIterator(recordReader, BATCH_SIZE, 1, OUTPUT_NUM);
        // NOTE(review): fitting a normalizer on test data is normally a leak; this is
        // presumably harmless for ImagePreProcessingScaler - confirm for the DL4J version.
        scaler.fit(testIter);
        testIter.setPreProcessor(scaler);

        while (testIter.hasNext()) {
            DataSet next = testIter.next();
            INDArray output = model.output(next.getFeatureMatrix());
            log.info("## THE FILE YOU CHOSE ## {}", filechose);
            log.info("## Neural Net's Prediction ##");
            // Print each header directly before the values it describes.
            log.info("## list of probabilities per label ##");
            log.info(output.toString());
            log.info("## list of Labels in order ##");
            log.info(labelList.toString());
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment