Created
April 13, 2017 02:30
-
-
Save crockpotveggies/74f7f62d1427c47841e9b1b9f906dc00 to your computer and use it in GitHub Desktop.
PreSaving (preprocessing) a dataset for faster training in Deeplearning4j and DataVec.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.io.File;
import java.util.LinkedList;
import java.util.List;
import java.util.Random;

import lombok.extern.slf4j.Slf4j;

import org.datavec.api.berkeley.Pair;
import org.datavec.api.io.filters.RandomPathFilter;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.api.split.InputSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.image.transform.FlipImageTransform;
import org.datavec.image.transform.ImageTransform;
import org.datavec.image.transform.RandomCropTransform;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
/** | |
* PreSave a dataset of compressed images (GIF, JPG, PNG) for faster training. | |
*/ | |
@Slf4j | |
public class PreSave { | |
// values to pass in from command line when compiled, esp running remotely | |
@Option(name = "--dataPath", usage = "Path to dataset", aliases = "-data") | |
public String dataPath = ""; | |
@Option(name = "--batchSize", usage = "Batch size", aliases = "-batch") | |
public int batchSize = 72; | |
@Option(name = "--inputDimension", usage = "Dimension of input image", aliases = "-dim") | |
public int inputDimension = 118; | |
@Option(name = "--resizeDimension", usage = "Dimension of resize for random crop", aliases = "-resize") | |
public int resizeDimension = 96; | |
@Option(name = "--epochs", usage = "Number of epochs to process dataset", aliases = "-epochs") | |
public int epochs = 10; | |
@Option(name = "--trainDataPath", usage = "Folder to save training data", aliases = "-trainOutput") | |
public String trainDataPath = ""; | |
@Option(name = "--testDataPath", usage = "Folder to save test data", aliases = "-testOutput") | |
public String testDataPath = ""; | |
protected long seed = 42; | |
protected Random rng = new Random(seed); | |
public void run(String[] args) throws Exception { | |
// Parse command line arguments if they exist | |
CmdLineParser parser = new CmdLineParser(this); | |
try { | |
parser.parseArgument(args); | |
} catch (CmdLineException e) { | |
// handling of wrong arguments | |
System.err.println(e.getMessage()); | |
parser.printUsage(System.err); | |
} | |
log.info("Load data...."); | |
/**cd | |
* Data Setup -> organize and limit data file paths: | |
* - mainPath = path to image files | |
* - fileSplit = define basic dataset split with limits on format | |
* - pathFilter = define additional file load filter to limit size and balance batch content | |
**/ | |
log.info("Loading paths...."); | |
ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator(); | |
File mainPath = new File(dataPath); | |
FileSplit fileSplit = new FileSplit(mainPath, NativeImageLoader.ALLOWED_FORMATS, rng); | |
RandomPathFilter randomFilter = new RandomPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS); | |
// BalancedPathFilter balancedFilter = new BalancedPathFilter(rng, NativeImageLoader.ALLOWED_FORMATS, labelMaker, 330000, 8100, 1,30); | |
/** | |
* Data Setup -> train test split | |
* - inputSplit = define train and test split | |
**/ | |
InputSplit[] split; | |
log.info("Splitting data for debugging...."); | |
split = fileSplit.sample(randomFilter, 0.998, 0.002); | |
InputSplit trainData = split[0]; | |
InputSplit testData = split[1]; | |
log.info("Total training images is "+trainData.length()); | |
log.info("Total test images is "+testData.length()); | |
// step 1 | |
log.info("Initializing RecordReader and pipelines...."); | |
List<Pair<ImageTransform, Double>> pipeline = new LinkedList<>(); | |
pipeline.add(new Pair<>(new RandomCropTransform(resizeDimension,resizeDimension), 1.0)); | |
pipeline.add(new Pair<>(new FlipImageTransform(1), 0.5)); | |
ImageTransform combinedTransform = new ProbabilisticPipelineTransform(pipeline, false); | |
ImageRecordReader trainRR = new ImageRecordReader(inputDimension, inputDimension, 3, labelMaker, combinedTransform); | |
trainRR.initialize(trainData); | |
ImageRecordReader testRR = new ImageRecordReader(inputDimension, inputDimension, 3, labelMaker, combinedTransform); | |
testRR.setLabels(trainRR.getLabels()); | |
testRR.initialize(testData); | |
int numClasses = trainRR.getLabels().size(); | |
int testClasses = testRR.getLabels().size(); | |
log.info("Total training labels: "+numClasses); | |
log.info("Total test labels: "+testClasses); | |
log.info("Creating RecordReader iterator...."); | |
DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRR, batchSize); | |
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, batchSize, 1, numClasses); | |
log.info("Load data...."); | |
File trainFolder = new File(trainDataPath); | |
trainFolder.mkdirs(); | |
File testFolder = new File(testDataPath); | |
testFolder.mkdirs(); | |
log.info("Saving train data to " + trainFolder.getAbsolutePath() + " and test data to " + testFolder.getAbsolutePath()); | |
//Track the indexes of the files being saved. | |
//These batch indexes are used for indexing which minibatch is being saved by the iterator. | |
int trainDataSaved = 0; | |
int testDataSaved = 0; | |
for(int i = 0; i < epochs; i++) { | |
while (trainIter.hasNext()) { | |
//note that we use testDataSaved as an index in to which batch this is for the file | |
trainIter.next().save(new File(trainFolder, "presave-train-" + trainDataSaved + ".bin")); | |
//^^^^^^^ | |
//****************** | |
//YOU NEED TO KNOW WHAT THIS IS. | |
//This is the index for the file saved. | |
//****************************************** | |
trainDataSaved++; | |
} | |
trainIter.reset(); | |
} | |
for(int i = 0; i < epochs; i++) { | |
while(testIter.hasNext()) { | |
//note that we use testDataSaved as an index in to which batch this is for the file | |
testIter.next().save(new File(testFolder,"presave-test-" + testDataSaved + ".bin")); | |
//^^^^^^^ | |
//****************** | |
//YOU NEED TO KNOW WHAT THIS IS. | |
//This is the index for the file saved. | |
//****************************************** | |
testDataSaved++; | |
} | |
testIter.reset(); | |
} | |
log.info("Finished pre saving test and train data"); | |
} | |
public static void main(String[] args) throws Exception { | |
new PreSave().run(args); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment