Skip to content

Instantly share code, notes, and snippets.

@liweigu
Created May 25, 2018 02:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save liweigu/6729562e1a499347e548335b43546cff to your computer and use it in GitHub Desktop.
RnnToFeedForward
// MultiLayerConfiguration: two GravesLSTM layers feeding two DenseLayers and a final OutputLayer.
// Renamed "OutDemension" -> "outDimension": fixes the typo and the lowerCamelCase violation.
int outDimension = 2; // number of regression targets produced by the output layer
ListBuilder listBuilder = builder.list();
int layerIndex = 0;
// Recurrent feature extraction: two stacked LSTM layers.
listBuilder.layer(layerIndex++, new GravesLSTM.Builder().activation(Activation.SOFTSIGN) // SOFTSIGN, TANH, RELU, SIGMOID
        .nIn(inNum).nOut(hiddenCount).build());
listBuilder.layer(layerIndex++, new GravesLSTM.Builder().activation(Activation.SOFTSIGN)
        .nIn(hiddenCount).nOut(hiddenCount).build());
// Feed-forward head: requires the RnnToFeedForwardPreProcessor registered below (index 2).
listBuilder.layer(layerIndex++, new DenseLayer.Builder().activation(Activation.TANH)
        .nIn(hiddenCount).nOut(denseLayerHiddenCount).build());
listBuilder.layer(layerIndex++, new DenseLayer.Builder().activation(Activation.TANH)
        .nIn(denseLayerHiddenCount).nOut(denseLayerHiddenCount / 4).build());
// Option 1: RnnOutputLayer (kept for reference, pairs with rnnTimeStep below).
// listBuilder.layer(layerIndex++,
//         new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY).nIn(hiddenCount).nOut(outDimension).build());
// Option 2: plain OutputLayer, valid because activations were converted to feed-forward shape.
listBuilder.layer(layerIndex++, new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.IDENTITY)
        .nIn(denseLayerHiddenCount / 4) // must match nOut of the preceding DenseLayer
        .nOut(outDimension).build());
// Convert recurrent (time-series) activations to feed-forward shape before layer index 2
// (the first DenseLayer) — index must stay in sync with the layer list above.
listBuilder.inputPreProcessor(2, new RnnToFeedForwardPreProcessor());
// Prediction:
// Option 1 (RnnOutputLayer): step through the sequence.
// INDArray predicted = net.rnnTimeStep(testData.getFeatureMatrix());
// Option 2 (OutputLayer): single forward pass.
INDArray predicted = net.output(testData.getFeatureMatrix());
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment