Skip to content

Instantly share code, notes, and snippets.

@crockpotveggies
Created April 13, 2017 02:30
Show Gist options
  • Save crockpotveggies/74f7f62d1427c47841e9b1b9f906dc00 to your computer and use it in GitHub Desktop.
Save crockpotveggies/74f7f62d1427c47841e9b1b9f906dc00 to your computer and use it in GitHub Desktop.
PreSaving (preprocessing) a dataset for faster training in Deeplearning4j and DataVec.
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.berkeley.Pair;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.File;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;
/**
 * PreSave a dataset of compressed images (GIF, JPG, PNG) for faster training.
 *
 * <p>The tool reads every image under {@code --dataPath} (labels are derived from the parent
 * directory name via {@link ParentPathLabelGenerator}), splits the files 99.8% / 0.2% into a
 * train and a test set, runs them through a random-crop + flip transform pipeline, and writes
 * each resulting minibatch to {@code presave-train-N.bin} / {@code presave-test-N.bin} in the
 * configured output folders. {@code N} is a running minibatch index that keeps increasing
 * across epochs, so no file is overwritten.
 */
@Slf4j
public class PreSave {

    // Values to pass in from the command line when compiled, esp. when running remotely.
    // The fields stay public and non-final because args4j assigns them reflectively.

    @Option(name = "--dataPath", usage = "Path to dataset", aliases = "-data")
    public String dataPath = "";

    @Option(name = "--batchSize", usage = "Batch size", aliases = "-batch")
    public int batchSize = 72;

    // Height and width handed to the ImageRecordReader (the channel count is fixed to 3 below).
    @Option(name = "--inputDimension", usage = "Dimension of input image", aliases = "-dim")
    public int inputDimension = 118;

    // Height and width handed to RandomCropTransform.
    @Option(name = "--resizeDimension", usage = "Dimension of resize for random crop", aliases = "-resize")
    public int resizeDimension = 96;

    // Number of full passes over each iterator; every pass is saved as additional minibatch files.
    @Option(name = "--epochs", usage = "Number of epochs to process dataset", aliases = "-epochs")
    public int epochs = 10;

    @Option(name = "--trainDataPath", usage = "Folder to save training data", aliases = "-trainOutput")
    public String trainDataPath = "";

    @Option(name = "--testDataPath", usage = "Folder to save test data", aliases = "-testOutput")
    public String testDataPath = "";

    protected long seed = 42;
    protected Random rng = new Random(seed);

    /**
     * Parses the command line, builds the train/test iterators and saves every minibatch to disk.
     *
     * @param args command line arguments understood by the {@code @Option} fields of this class
     * @throws CmdLineException if the arguments cannot be parsed (usage is printed to stderr first)
     * @throws IllegalArgumentException if one of the three path options is empty
     * @throws Exception if an output folder cannot be created or reading/saving the data fails
     */
    public void run(String[] args) throws Exception {
        // Parse command line arguments if they exist
        CmdLineParser parser = new CmdLineParser(this);
        try {
            parser.parseArgument(args);
        } catch (CmdLineException e) {
            // Report the problem and stop: continuing with default or partially populated
            // options would process and write data the caller never asked for.
            System.err.println(e.getMessage());
            parser.printUsage(System.err);
            throw e;
        }

        // An empty path is never valid here; for the output folders new File(new File(""), child)
        // would resolve against a system-dependent default directory instead of failing.
        requireNonEmpty(dataPath, "--dataPath");
        requireNonEmpty(trainDataPath, "--trainDataPath");
        requireNonEmpty(testDataPath, "--testDataPath");

        log.info("Load data....");

        /*
         * Data Setup -> organize and limit data file paths:
         *  - mainPath = path to image files
         *  - fileSplit = define basic dataset split with limits on format
         *  - pathFilter = define additional file load filter to limit size and balance batch content
         */
        log.info("Loading paths....");
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        File mainPath = new File(dataPath);
        FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng);
        RandomPathFilter randomFilter = new RandomPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS);

        /*
         * Data Setup -> train test split
         *  - inputSplit = define train and test split
         */
        log.info("Splitting data for debugging....");
        InputSplit[] split = fileSplit.sample(randomFilter, 0.998, 0.002);
        InputSplit trainData = split[0];
        InputSplit testData = split[1];
        log.info("Total training images is {}", trainData.length());
        log.info("Total test images is {}", testData.length());

        // step 1
        log.info("Initializing RecordReader and pipelines....");
        // Each transform is paired with a Double; presumably the probability with which
        // ProbabilisticPipelineTransform applies it (that class is not part of this file).
        List<Pair<ImageTransform, Double>> pipeline = new LinkedList<>();
        pipeline.add(new Pair<>(new RandomCropTransform(resizeDimension, resizeDimension), 1.0));
        pipeline.add(new Pair<>(new FlipImageTransform(1), 0.5));
        ImageTransform combinedTransform = new ProbabilisticPipelineTransform(pipeline, false);

        ImageRecordReader trainRR =
                new ImageRecordReader(inputDimension, inputDimension, 3, labelMaker, combinedTransform);
        trainRR.initialize(trainData);

        // The test reader reuses the training label list so both sets share one label indexing,
        // even if the small test split does not contain every class.
        ImageRecordReader testRR =
                new ImageRecordReader(inputDimension, inputDimension, 3, labelMaker, combinedTransform);
        testRR.setLabels(trainRR.getLabels());
        testRR.initialize(testData);

        int numClasses = trainRR.getLabels().size();
        int testClasses = testRR.getLabels().size();
        log.info("Total training labels: {}", numClasses);
        log.info("Total test labels: {}", testClasses);

        log.info("Creating RecordReader iterator....");
        // Both iterators are configured identically: label at record index 1, one-hot encoded
        // over numClasses, so saved train and test minibatches have the same layout.
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize, 1, numClasses);
        DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, numClasses);

        log.info("Load data....");
        File trainFolder = ensureDirectory(new File(trainDataPath));
        File testFolder = ensureDirectory(new File(testDataPath));
        log.info("Saving train data to {} and test data to {}",
                trainFolder.getAbsolutePath(), testFolder.getAbsolutePath());

        int trainDataSaved = saveMinibatches(trainIter, trainFolder, "presave-train-", epochs);
        int testDataSaved = saveMinibatches(testIter, testFolder, "presave-test-", epochs);
        log.info("Saved {} training and {} test minibatches", trainDataSaved, testDataSaved);

        log.info("Finished pre saving test and train data");
    }

    /** Rejects a null or blank option value with a message naming the offending option. */
    private static void requireNonEmpty(String value, String optionName) {
        if (value == null || value.trim().isEmpty()) {
            throw new IllegalArgumentException("Missing required option " + optionName);
        }
    }

    /** Creates {@code folder} (including parents) if needed and fails loudly when that is impossible. */
    private static File ensureDirectory(File folder) throws java.io.IOException {
        // mkdirs() returns false both on failure and when the directory already exists,
        // so the result is only an error if the directory is still missing afterwards.
        if (!folder.mkdirs() && !folder.isDirectory()) {
            throw new java.io.IOException("Could not create output folder " + folder.getAbsolutePath());
        }
        return folder;
    }

    /**
     * Drains {@code iter} {@code epochs} times and saves every minibatch into {@code folder}.
     *
     * <p>Files are named {@code prefix + index + ".bin"}, where {@code index} is the running
     * minibatch counter. The counter is NOT reset between epochs, which is what keeps later
     * epochs from overwriting earlier files; consumers need to know this index scheme to load
     * the files again.
     *
     * @return the number of minibatch files written
     */
    private static int saveMinibatches(DataSetIterator iter, File folder, String prefix, int epochs)
            throws Exception {
        int saved = 0;
        for (int epoch = 0; epoch < epochs; epoch++) {
            while (iter.hasNext()) {
                iter.next().save(new File(folder, prefix + saved + ".bin"));
                saved++;
            }
            // Rewind so the next epoch re-reads the data through the transform pipeline.
            iter.reset();
        }
        return saved;
    }

    public static void main(String[] args) throws Exception {
        new PreSave().run(args);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment