@jumpingfella (created May 28, 2018 16:23)
A two-layer LSTM recurrent network for time series prediction, built with Deeplearning4j.
import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.RmsProp;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class RecurrentNets {

    private static final double learningRate = 0.1;
    private static final int seed = 12345;
    private static final int nHidden = 200;
    private static final int truncatedBPTTLength = 22;

    /**
     * Builds a two-layer LSTM regression network (MSE loss, identity output activation,
     * RMSProp updater, truncated BPTT) with nIn input features and nOut outputs per time step.
     */
    public static MultiLayerNetwork buildLstmNetworks(int nIn, int nOut) {
        // Set up the network configuration
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .l2(0.001)
                .weightInit(WeightInit.XAVIER)
                .updater(new RmsProp(learningRate))
                .list()
                .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(nHidden)
                        .activation(Activation.SOFTSIGN).build())
                .layer(1, new GravesLSTM.Builder().nIn(nHidden).nOut(nHidden)
                        .activation(Activation.SOFTSIGN).build())
                .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                        .activation(Activation.IDENTITY)
                        .nIn(nHidden).nOut(nOut).build())
                .backpropType(BackpropType.TruncatedBPTT)
                .tBPTTForwardLength(truncatedBPTTLength)
                .tBPTTBackwardLength(truncatedBPTTLength)
                .pretrain(false).backprop(true)
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Print the training score every 100 iterations
        net.setListeners(new ScoreIterationListener(100));
        return net;
    }
}
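
For context, here is a minimal usage sketch (not part of the original gist) showing how the returned network might be trained on a time-series DataSetIterator and then used for one-step-ahead prediction with rnnTimeStep(). The iterator, the epoch count, and the single-feature input/output widths are assumptions made for illustration only.

// Usage sketch (assumed): train the network from buildLstmNetworks() on a
// time-series iterator and predict the next step. "trainIterator",
// "testFeatures" and the epoch count are hypothetical placeholders.
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

public class TrainingSketch {

    public static INDArray trainAndPredict(DataSetIterator trainIterator, INDArray testFeatures) {
        // 1 input feature and 1 output value per time step (e.g. a univariate series)
        MultiLayerNetwork net = RecurrentNets.buildLstmNetworks(1, 1);

        int nEpochs = 100; // assumed epoch count
        for (int i = 0; i < nEpochs; i++) {
            net.fit(trainIterator);      // one pass over the training sequences
            trainIterator.reset();
            net.rnnClearPreviousState(); // clear the stored LSTM state between epochs
        }

        // rnnTimeStep() keeps the recurrent state, so successive calls continue the
        // sequence; testFeatures is expected in [miniBatch, nIn, timeSteps] layout
        return net.rnnTimeStep(testFeatures);
    }
}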