Skip to content

Instantly share code, notes, and snippets.

@bikashg
Created March 28, 2017 08:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bikashg/c76436121e719269174f47cf20ffea89 to your computer and use it in GitHub Desktop.
/**
 * Loads the seq2seq dataset from {@code x_DataSet} and prints every feature array of every
 * example in every minibatch, transposed.
 *
 * <p>Each {@link MultiDataSet} produced by the iterator contains a whole minibatch, so each
 * feature array is indexed by example along dimension 0. Every example index is printed, not
 * only index 0, so the output is the same regardless of {@code miniBatchSize}.
 */
public static void main(String[] args) throws IOException, InterruptedException {
    int vocab_size = 14; // I set it manually here
    String datasetBaseDir = "x_DataSet";
    int numLinesToSkip = 1; // The first line of every input/output file holds a comment, so skip it.
    String fileDelimiter = ","; // All input/output timesteps are single values for now.
    // NumberedFileInputSplit index bounds are inclusive, so 0..1 selects 2 example files.
    int minCountExamples = 0;
    int maxCountExamples = 1;
    int miniBatchSize = 2; // Both examples go into one single minibatch.
    boolean useOneHot_encoderInput = true;
    boolean useOneHot_decoderInput = true;
    boolean useOneHot_decoderOutput = true;
    MultiDataSetIterator myIterator = seq2seq_DataLoader.getDatSetIterator(miniBatchSize, useOneHot_encoderInput, useOneHot_decoderInput, useOneHot_decoderOutput, datasetBaseDir, vocab_size, numLinesToSkip, fileDelimiter, minCountExamples, maxCountExamples);
    while (myIterator.hasNext()) {
        MultiDataSet multi_dataSet = myIterator.next();
        // One feature array per network input, in the order the inputs were added:
        // encoder input first, then decoder input.
        INDArray[] features = multi_dataSet.getFeatures();
        if (features.length == 0) {
            continue;
        }
        // Dimension 0 is the example index within this minibatch. Read the actual count,
        // because the last minibatch may hold fewer than miniBatchSize examples.
        int examplesInBatch = (int) features[0].size(0);
        for (int example = 0; example < examplesInBatch; example++) {
            for (INDArray feat : features) {
                System.out.println("Feature Transposed = \n" + feat.get(NDArrayIndex.point(example)).transpose());
            }
            System.out.println("\nFinished displaying input for one example.\n");
        }
    }
}
/**
 * Builds a {@link MultiDataSetIterator} over a seq2seq dataset stored as numbered CSV files.
 *
 * <p>Three sequence readers are created, one for each of these file patterns under
 * {@code datasetBaseDir}:
 * <ul>
 *   <li>{@code encoder_input/encoder_input_%d.csv}</li>
 *   <li>{@code decoder_input/decoder_input_%d.csv}</li>
 *   <li>{@code decoder_output/decoder_output_%d.csv}</li>
 * </ul>
 * Each pattern covers file indices {@code minCountExamples..maxCountExamples}. The encoder
 * input and decoder input are registered as network inputs, in that order. The decoder output
 * is registered as the network output. Sequences use
 * {@code AlignmentMode.ALIGN_END} so that examples with fewer timesteps get padding and masks.
 *
 * @param miniBatchSize number of examples per produced MultiDataSet
 * @param useOneHot_encoderInput if true, column 0 of the encoder input is one-hot encoded into
 *     {@code vocab_size} classes; otherwise the raw values are used
 * @param useOneHot_decoderInput same choice for the decoder input
 * @param useOneHot_decoderOutput same choice for the decoder output (the labels)
 * @param datasetBaseDir directory containing the three sub-directories
 * @param vocab_size number of classes used for one-hot encoding
 * @param numLinesToSkip number of leading lines skipped in every CSV file
 * @param fileDelimiter CSV column delimiter
 * @param minCountExamples first file index of the numbered split
 * @param maxCountExamples last file index of the numbered split
 * @return the configured iterator
 * @throws IOException propagated from reader initialization
 * @throws InterruptedException propagated from reader initialization
 */
public MultiDataSetIterator getDatSetIterator(int miniBatchSize, boolean useOneHot_encoderInput, boolean useOneHot_decoderInput, boolean useOneHot_decoderOutput, String datasetBaseDir, int vocab_size, int numLinesToSkip, String fileDelimiter, int minCountExamples, int maxCountExamples) throws IOException, InterruptedException {
    // Readers are created and initialized in the same order as before: encoder input,
    // decoder input, decoder output.
    SequenceRecordReader encoder_input_reader = createCsvSequenceReader(datasetBaseDir + "/encoder_input/encoder_input_%d.csv", numLinesToSkip, fileDelimiter, minCountExamples, maxCountExamples);
    SequenceRecordReader decoder_input_Reader = createCsvSequenceReader(datasetBaseDir + "/decoder_input/decoder_input_%d.csv", numLinesToSkip, fileDelimiter, minCountExamples, maxCountExamples);
    SequenceRecordReader decoder_output_Reader = createCsvSequenceReader(datasetBaseDir + "/decoder_output/decoder_output_%d.csv", numLinesToSkip, fileDelimiter, minCountExamples, maxCountExamples);

    RecordReaderMultiDataSetIterator.Builder datasetBuilder = new RecordReaderMultiDataSetIterator.Builder(miniBatchSize)
            .addSequenceReader("encoderInput", encoder_input_reader)
            .addSequenceReader("decoderInput", decoder_input_Reader)
            .addSequenceReader("decoderOutput", decoder_output_Reader);

    // Inputs are added in this order so that getFeatures()[0] is the encoder input
    // and getFeatures()[1] is the decoder input.
    if (useOneHot_encoderInput) {
        datasetBuilder.addInputOneHot("encoderInput", 0, vocab_size);
    } else {
        datasetBuilder.addInput("encoderInput"); // Use the actual real value.
    }
    if (useOneHot_decoderInput) {
        datasetBuilder.addInputOneHot("decoderInput", 0, vocab_size);
    } else {
        datasetBuilder.addInput("decoderInput"); // Use the actual real value.
    }
    if (useOneHot_decoderOutput) {
        datasetBuilder.addOutputOneHot("decoderOutput", 0, vocab_size);
    } else {
        datasetBuilder.addOutput("decoderOutput"); // Use the actual real value.
    }

    // Needed to obtain the desired padding and masking for examples with fewer timesteps.
    datasetBuilder.sequenceAlignmentMode(RecordReaderMultiDataSetIterator.AlignmentMode.ALIGN_END);
    return datasetBuilder.build();
}

/**
 * Creates a CSV sequence reader and initializes it over a numbered file split.
 *
 * @param pathPattern file path containing a {@code %d} placeholder for the file index
 * @param numLinesToSkip number of leading lines skipped in every file
 * @param fileDelimiter CSV column delimiter
 * @param minIdx first file index of the split
 * @param maxIdx last file index of the split
 * @return the initialized reader
 * @throws IOException propagated from reader initialization
 * @throws InterruptedException propagated from reader initialization
 */
private static SequenceRecordReader createCsvSequenceReader(String pathPattern, int numLinesToSkip, String fileDelimiter, int minIdx, int maxIdx) throws IOException, InterruptedException {
    SequenceRecordReader reader = new CSVSequenceRecordReader(numLinesToSkip, fileDelimiter);
    reader.initialize(new NumberedFileInputSplit(pathPattern, minIdx, maxIdx));
    return reader;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment