Skip to content

Instantly share code, notes, and snippets.

@deerishi
Created November 2, 2015 21:01
Show Gist options
  • Save deerishi/64c4cbc13724114b0d34 to your computer and use it in GitHub Desktop.
Save deerishi/64c4cbc13724114b0d34 to your computer and use it in GitHub Desktop.
Image loading in DL4J (Deeplearning4j)
package com.chil;
import org.canova.api.records.reader.RecordReader;
import org.canova.api.split.FileSplit;
import org.canova.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.canova.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.override.ConfOverride;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.plot.PlotFilters;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.plot.PlotFilters.*;
import org.slf4j.LoggerFactory;
import scala.Int;
/**
* Hello world!
*
*/
/**
 * Minimal DL4J/Canova demo: loads a labelled image folder and prints every batch the iterator yields.
 *
 * <p>Expected layout: one sub-directory per class under the root folder; the sub-directory name
 * is used as that class's label.
 */
public class App
{
    /** Width, in pixels, passed to the image reader. */
    private static final int WIDTH = 28;
    /** Height, in pixels, passed to the image reader. */
    private static final int HEIGHT = 28;
    /** Number of channels passed to the image reader. */
    private static final int CHANNELS = 3;

    /**
     * Reads the labelled image tree and prints each {@link DataSet} produced by the iterator.
     *
     * @param args optional; {@code args[0]} overrides the image root (default {@code ~/test})
     * @throws IOException if the image root is not a readable directory
     * @throws IllegalStateException if the image root contains no label sub-directories
     * @throws Exception anything propagated from record-reader initialisation or iteration
     */
    public static void main( String[] args ) throws Exception
    {
        String labeledPath = args.length > 0 ? args[0] : System.getProperty("user.home") + "/test";
        List<String> labels = collectLabels(new File(labeledPath));

        // appendLabel=true: each record carries its label in addition to the pixel values.
        RecordReader recordReader = new ImageRecordReader(WIDTH, HEIGHT, CHANNELS, true, labels);
        recordReader.initialize(new FileSplit(new File(labeledPath)));

        // The label column follows the WIDTH*HEIGHT*CHANNELS pixel columns (originally written as 784*3).
        int labelIndex = WIDTH * HEIGHT * CHANNELS;
        DataSetIterator iter = new RecordReaderDataSetIterator(recordReader, labelIndex, labels.size());

        int batch = 1;
        while (iter.hasNext()) {
            DataSet next = iter.next();
            System.out.println("Reading " + batch + " = " + next.asList() + " added");
            batch++;
        }
    }

    /**
     * Collects the names of the immediate sub-directories of {@code root}, sorted by pathname,
     * for use as class labels.
     *
     * @param root directory whose sub-directories name the classes
     * @return non-empty list of label names in a deterministic order
     * @throws IOException if {@code root} is not a directory or cannot be listed
     * @throws IllegalStateException if {@code root} has no sub-directories
     */
    private static List<String> collectLabels(File root) throws IOException
    {
        File[] entries = root.listFiles();
        if (entries == null) {
            // listFiles() signals "not a directory" or an I/O error by returning null.
            throw new IOException("Not a readable directory: " + root.getAbsolutePath());
        }
        // listFiles() order is unspecified; sort so each label keeps the same index across runs.
        Arrays.sort(entries);
        List<String> labels = new ArrayList<String>();
        for (File entry : entries) {
            // Plain files sitting in the root are not classes.
            if (entry.isDirectory()) {
                labels.add(entry.getName());
            }
        }
        if (labels.isEmpty()) {
            throw new IllegalStateException("No label sub-directories found under " + root.getAbsolutePath());
        }
        return labels;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment