Created
March 28, 2017 08:35
-
-
Save bikashg/c76436121e719269174f47cf20ffea89 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public static void main(String[] args) throws IOException, InterruptedException { | |
int vocab_size = 14; // I set it manually here | |
String datasetBaseDir = "x_DataSet"; | |
int numLinesToSkip = 1; // I used the first line in input/output files to place my own comment; so don't want to read it. | |
String fileDelimiter = ","; // All my input/output timesteps are single values for now. | |
int minCountExamples = 0; | |
int maxCountExamples = 1; // For now, saying that we have only 2 examples. | |
int miniBatchSize = 2; // For now, I have just 2 examples and I want to have one single batch (containing all those examples). | |
boolean useOneHot_encoderInput = true; | |
boolean useOneHot_decoderInput = true; | |
boolean useOneHot_decoderOutput = true; | |
MultiDataSetIterator myIterator = seq2seq_DataLoader.getDatSetIterator(miniBatchSize, useOneHot_encoderInput, useOneHot_decoderInput, useOneHot_decoderOutput, datasetBaseDir, vocab_size, numLinesToSkip, fileDelimiter, minCountExamples, maxCountExamples); | |
while(myIterator.hasNext()) { | |
MultiDataSet multi_dataSet = myIterator.next(); | |
INDArray [] features = multi_dataSet.getFeatures(); // Since there are 2 inputs (encoder input and decoder input) , there will be 2 features in each multi_dataSet. | |
INDArray [] featuresMasks = multi_dataSet.getFeaturesMaskArrays(); // There will be a mask for each of those features. | |
for (int i=0; i<features.length; i++) { // The feature(=vocab_size times timesteps) matrix for encoder_input will be displayed first because that was added as first input when creating the dataset. | |
INDArray feat = features[i]; | |
// INDArray mask = featuresMasks[i]; | |
/********************** I have 2 examples, so this should have been printed 4 times. It does that when miniBatchSize is set to 1 but not when miniBatchSize = 2 *************/ | |
System.out.println("Feature Transposed = \n" + feat.get(NDArrayIndex.point(0)).transpose() ); | |
} | |
System.out.println("\nFinished displaying input for one example.\n"); | |
} | |
} | |
public MultiDataSetIterator getDatSetIterator(int miniBatchSize, boolean useOneHot_encoderInput, boolean useOneHot_decoderInput, boolean useOneHot_decoderOutput, String datasetBaseDir, int vocab_size, int numLinesToSkip, String fileDelimiter, int minCountExamples, int maxCountExamples) throws IOException, InterruptedException { | |
SequenceRecordReader encoder_input_reader = new CSVSequenceRecordReader(numLinesToSkip,fileDelimiter); | |
encoder_input_reader.initialize(new NumberedFileInputSplit(datasetBaseDir+"/encoder_input/encoder_input_%d.csv",minCountExamples,maxCountExamples)); | |
SequenceRecordReader decoder_input_Reader = new CSVSequenceRecordReader(numLinesToSkip,fileDelimiter); | |
decoder_input_Reader.initialize(new NumberedFileInputSplit(datasetBaseDir+"/decoder_input/decoder_input_%d.csv",minCountExamples,maxCountExamples)); | |
SequenceRecordReader decoder_output_Reader = new CSVSequenceRecordReader(numLinesToSkip,fileDelimiter); | |
decoder_output_Reader.initialize(new NumberedFileInputSplit(datasetBaseDir+"/decoder_output/decoder_output_%d.csv",minCountExamples,maxCountExamples)); | |
MultiDataSetIterator iterator; | |
RecordReaderMultiDataSetIterator.Builder datasetBuilder = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize) | |
.addSequenceReader("encoderInput", encoder_input_reader) | |
.addSequenceReader("decoderInput", decoder_input_Reader) | |
.addSequenceReader("decoderOutput", decoder_output_Reader); | |
if (useOneHot_encoderInput) { | |
datasetBuilder.addInputOneHot("encoderInput", 0, vocab_size); | |
} | |
else { | |
datasetBuilder.addInput("encoderInput"); // Use the actual real value. | |
} | |
if (useOneHot_decoderInput) { | |
datasetBuilder.addInputOneHot("decoderInput", 0, vocab_size); | |
} | |
else { | |
datasetBuilder.addInput("decoderInput"); // Use the actual real value. | |
} | |
if (useOneHot_decoderOutput) { | |
datasetBuilder.addOutputOneHot("decoderOutput", 0, vocab_size); | |
} | |
else { | |
datasetBuilder.addOutput("decoderOutput"); // Use the actual real value. | |
} | |
datasetBuilder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END); // this is needed to obtain the desired padding and masking for shorter timesteps example. | |
iterator = datasetBuilder.build(); | |
return iterator; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment