Skip to content

Instantly share code, notes, and snippets.

@crockpotveggies
Last active October 27, 2017 00:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save crockpotveggies/a061311b88cf21d3e662f7834f3a9b03 to your computer and use it in GitHub Desktop.
Save crockpotveggies/a061311b88cf21d3e662f7834f3a9b03 to your computer and use it in GitHub Desktop.
Sequence to Sequence Autoencoder Preprocessor
/**
 * Namespace object holding the sequence-to-sequence autoencoder pre-processor.
 */
object Preprocessor extends Serializable {

  /**
   * Rewrites a single-input `MultiDataSet` into the two-input / one-output layout
   * used to train a sequence-to-sequence autoencoder.
   *
   * The first feature array is read as a rank-3 array whose dimensions are treated as
   * (minibatch, classes, time steps). After pre-processing the data set holds:
   *
   *  - features(0): the original input, untouched;
   *  - features(1): the input shifted one step later in time, with a GO token at time 0;
   *  - labels(0):   the input at its original time positions, followed by a STOP token.
   *
   * GO and STOP share one extra class appended after the original classes, and both new
   * arrays are one time step longer than the input. When a features mask is present, the
   * new arrays get that mask with a column of ones prepended.
   */
  class Seq2SeqAutoencoderPreProcessor extends MultiDataSetPreProcessor {

    /**
     * Mutates `mds` in place, replacing its features, labels and mask arrays.
     *
     * @param mds data set whose first feature array (and optional first features mask)
     *            is the source sequence
     */
    override def preProcess(mds: MultiDataSet): Unit = {
      val input: INDArray = mds.getFeatures(0)
      val mb: Int = input.size(0)
      val nClasses: Int = input.size(1)
      val origMaxTsLength: Int = input.size(2)
      // The extra class (index == original class count) doubles as GO and STOP marker.
      val goStopTokenPos: Int = nClasses

      // One new class for GO/STOP, and one new time step to hold it.
      val newShape: Array[Int] = Array(mb, nClasses + 1, origMaxTsLength + 1)

      // Null when the data set carries no features mask, i.e. all series are full length.
      val featuresMask: Option[INDArray] = Option(mds.getFeaturesMaskArray(0))

      // Decoder input: original data occupies time steps 1..end, GO token sits at time 0 only.
      val decoderInput: INDArray = Nd4j.create(newShape: _*)
      decoderInput.put(
        Array[INDArrayIndex](all(), interval(0, nClasses), interval(1, newShape(2))),
        input)
      for (example <- 0 until mb) {
        decoderInput.putScalar(example, goStopTokenPos, 0, 1.0)
      }

      // Labels: original data occupies time steps 0..end-1 of the longer series.
      val decoderLabels: INDArray = Nd4j.create(newShape: _*)
      decoderLabels.put(
        Array[INDArrayIndex](all(), interval(0, nClasses), interval(0, newShape(2) - 1)),
        input)

      // STOP goes on the step immediately after the last real step of each example, so it
      // never overlaps real data: the appended step when unmasked, otherwise one past the
      // last non-zero entry of that example's mask.
      val stopTimeStep: Array[Int] = featuresMask match {
        case None =>
          Array.fill(mb)(origMaxTsLength)
        case Some(fm) =>
          BooleanIndexing.lastIndex(fm, Conditions.notEquals(0), 1).data().asInt().map(_ + 1)
      }
      for (example <- stopTimeStep.indices) {
        decoderLabels.putScalar(example, goStopTokenPos, stopTimeStep(example), 1.0)
      }

      // Every series is now one step longer, so the new mask is the old one with a leading
      // column of ones. Without an input mask the mask arrays stay null.
      val (featureMasks, labelsMasks): (Array[INDArray], Array[INDArray]) = featuresMask match {
        case Some(fm) =>
          val newMask: INDArray = Nd4j.hstack(Nd4j.ones(mb, 1), fm)
          (Array[INDArray](fm, newMask), Array[INDArray](newMask))
        case None =>
          (null, null)
      }

      mds.setFeatures(Array[INDArray](input, decoderInput))
      mds.setLabels(Array[INDArray](decoderLabels))
      mds.setFeaturesMaskArrays(featureMasks)
      mds.setLabelsMaskArray(labelsMasks)
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment