Created May 25, 2018 02:42
-
-
Save liweigu/6729562e1a499347e548335b43546cff to your computer and use it in GitHub Desktop.
RnnToFeedForward
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// MultiLayerConfiguration:
// two GravesLSTM layers -> two DenseLayers -> OutputLayer (MSE loss, identity activation).
// NOTE(review): builder, inNum, hiddenCount and denseLayerHiddenCount are defined outside
// this snippet.
int outDimension = 2;
// Width of the last dense layer. It is shared by that layer's nOut and the output layer's
// nIn so the two cannot drift apart. Integer division: for non-negative values this is 0
// whenever denseLayerHiddenCount < 4.
int lastDenseOut = denseLayerHiddenCount / 4;
ListBuilder listBuilder = builder.list();
int layerIndex = 0;
listBuilder.layer(layerIndex++, new GravesLSTM.Builder().activation(Activation.SOFTSIGN) // SOFTSIGN, TANH, RELU, SIGMOID
        .nIn(inNum).nOut(hiddenCount).build());
listBuilder.layer(layerIndex++, new GravesLSTM.Builder().activation(Activation.SOFTSIGN)
        .nIn(hiddenCount).nOut(hiddenCount).build());
// Attach the RNN -> feed-forward preprocessor to the first dense layer so the dense stack
// receives 2D activations. Using the current layerIndex instead of a hard-coded 2 keeps the
// preprocessor on the correct layer if LSTM layers are added or removed above.
listBuilder.inputPreProcessor(layerIndex, new RnnToFeedForwardPreProcessor());
listBuilder.layer(layerIndex++, new DenseLayer.Builder().activation(Activation.TANH)
        .nIn(hiddenCount).nOut(denseLayerHiddenCount).build());
listBuilder.layer(layerIndex++, new DenseLayer.Builder().activation(Activation.TANH)
        .nIn(denseLayerHiddenCount).nOut(lastDenseOut).build());
// Option 1, RnnOutputLayer (not used):
// NOTE(review): nIn must match the previous layer's nOut (lastDenseOut, not hiddenCount).
// The dense stack's 2D output would also need to be converted back to 3D first (e.g. a
// FeedForwardToRnnPreProcessor on this layer). Confirm before enabling.
// listBuilder.layer(layerIndex++,
//         new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(lastDenseOut).nOut(outDimension).build());
// Option 2, OutputLayer (used):
listBuilder.layer(layerIndex++, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
        .nIn(lastDenseOut)
        .nOut(outDimension).build());

// Predicting
// NOTE(review): building, initialising and training `net` from listBuilder.build(), and
// loading testData, happen outside this snippet.
// Option 1, with RnnOutputLayer:
// INDArray predicted = net.rnnTimeStep(testData.getFeatureMatrix());
// Option 2, with OutputLayer. Per the RnnToFeedForwardPreProcessor docs, time steps are
// flattened into rows, so the result is 2D [miniBatch * timeSeriesLength, outDimension].
INDArray predicted = net.output(testData.getFeatureMatrix());
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.