Skip to content

Instantly share code, notes, and snippets.

@eraly
Created June 24, 2016 00:44
Show Gist options
  • Save eraly/2352bfce4fa61be3103c7b4cabb1dc9d to your computer and use it in GitHub Desktop.
Save eraly/2352bfce4fa61be3103c7b4cabb1dc9d to your computer and use it in GitHub Desktop.
@Override
public List<DataSet> asList() {
List<DataSet> list = new ArrayList<>(numExamples());
// Preserving the dimension of the dataset - essentially a minibatch size of 1
int [] featureShape = getFeatures().shape();
featureShape[0] = 1
int [] labelShape = getLabels().shape();
labelShape[0] = 1;
for (int i = 1; i < numExamples(); i++) {
INDArray featuresHere = getFeatures().slice(i).reshape(featureShape);
INDArray labelsHere = getLabels().slice(i).reshape(labelShape);
INDArray featureMaskHere = featuresMask != null ? featuresMask.slice(i) : null;
featureMaskHere.reshape(featureShape);
INDArray labelMaskHere = labelsMask != null ? labelsMask.slice(i) : null;
labelMaskHere.repeat(labelShape);
list.add(new DataSet(featuresHere,labelsHere,featureMaskHere,labelMaskHere);
}
return list;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment