Skip to content

Instantly share code, notes, and snippets.

@sirolf2009
Created June 15, 2016 22:11
Show Gist options
  • Save sirolf2009/0c4d93e9c22f069740e9eb2db68f6702 to your computer and use it in GitHub Desktop.
Save sirolf2009/0c4d93e9c22f069740e9eb2db68f6702 to your computer and use it in GitHub Desktop.
// Ordered time slices consumed by nd4jDataSet(), which reads predictionRed,
// predictionBlack and predictionDiff from each element.
List<TimeSlice> slices
/**
 * Builds a DataSet of (current slice -> next slice) pairs from {@code slices}.
 *
 * Both arrays have shape (1, 3, slices.length - 2). Along the second axis,
 * index 0 holds predictionRed, index 1 predictionBlack and index 2
 * predictionDiff. For every step i the input column holds the values of
 * slice i and the label column holds the values of slice i + 1. Each value
 * is multiplied by 10000 and then cast to int before being stored.
 *
 * NOTE(review): because the step count is slices.length - 2, the last slice
 * in the list is never read — confirm this is intentional.
 */
def nd4jDataSet() {
	val steps = slices.length - 2
	val scale = 10000
	val features = Nd4j.zeros(1, 3, steps)
	val targets = Nd4j.zeros(1, 3, steps)
	for (i : 0 ..< steps) {
		val current = slices.get(i)
		val next = slices.get(i + 1)

		// Input column i: the scaled values of the current slice.
		features.putScalar(#[0, 0, i], (current.predictionRed * scale) as int)
		features.putScalar(#[0, 1, i], (current.predictionBlack * scale) as int)
		features.putScalar(#[0, 2, i], (current.predictionDiff * scale) as int)

		// Label column i: the scaled values of the following slice.
		targets.putScalar(#[0, 0, i], (next.predictionRed * scale) as int)
		targets.putScalar(#[0, 1, i], (next.predictionBlack * scale) as int)
		targets.putScalar(#[0, 2, i], (next.predictionDiff * scale) as int)
	}
	new DataSet(features, targets)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment