Last active
January 23, 2018 18:19
-
-
Save Broele/22b5f7e9bde28a8ca4b58c41ddd343e3 to your computer and use it in GitHub Desktop.
Why does this produce NaNs?
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; | |
import org.deeplearning4j.nn.conf.NeuralNetConfiguration; | |
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex; | |
import org.deeplearning4j.nn.conf.layers.ActivationLayer; | |
import org.deeplearning4j.nn.conf.layers.DenseLayer; | |
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer; | |
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor; | |
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor; | |
import org.deeplearning4j.nn.gradient.Gradient; | |
import org.deeplearning4j.nn.graph.ComputationGraph; | |
import org.nd4j.linalg.activations.Activation; | |
import org.nd4j.linalg.api.ndarray.INDArray; | |
import org.nd4j.linalg.factory.Nd4j; | |
import org.nd4j.linalg.indexing.conditions.IsNaN; | |
public class Example { | |
private static final int nIn = 4; | |
private static final int nOut = 3; | |
public static ComputationGraph getGraph1() { | |
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() | |
.learningRate(0.01) | |
.graphBuilder() | |
.addInputs("features") | |
.addVertex("rnn2ffn", | |
new PreprocessorVertex(new RnnToFeedForwardPreProcessor()), | |
"features") | |
.addLayer("predict", | |
new DenseLayer.Builder() | |
.nIn(nIn) | |
.nOut(nOut) | |
.activation(Activation.RELU) | |
.build(), | |
"rnn2ffn" | |
) | |
.addVertex("ffn2rnn", | |
new PreprocessorVertex(new FeedForwardToRnnPreProcessor()), | |
"predict") | |
.addLayer("output", | |
new ActivationLayer.Builder() | |
.activation(Activation.IDENTITY) | |
.build(), | |
"ffn2rnn" | |
) | |
.setOutputs("output") | |
.backprop(true) | |
.build(); | |
ComputationGraph graph = new ComputationGraph(conf); | |
graph.init(); | |
return graph; | |
} | |
public static void main(String[] args) { | |
final int minibatch = 5; | |
final int seqLen = 6; | |
ComputationGraph model = getGraph1(); | |
double[] param = new double[]{0.54, 0.31, 0.98, -0.30, -0.66, -0.19, -0.29, -0.62, 0.13, -0.32, 0.01, -0.03, 0.00, 0.00, 0.00}; | |
for (int i = 0; i < param.length; i++) | |
model.params().putScalar(i, param[i]); | |
INDArray input = Nd4j.rand(new int[]{minibatch,nIn,seqLen}, 12); | |
INDArray expected = Nd4j.ones(new int[]{minibatch,nOut,seqLen}); | |
for (int i = 0; i < 1000; i++) { | |
System.out.println("Params:"); | |
System.out.println(model.params()); | |
System.out.println(); | |
INDArray output = model.outputSingle(input); | |
INDArray error = output.sub(expected); | |
// Compute Gradient | |
Gradient gradient = model.backpropGradient(error); | |
model.getUpdater().update(gradient, 0, minibatch); | |
// Update parameters | |
INDArray updateVector = gradient.gradient(); | |
if (updateVector.cond(new IsNaN()).sumNumber().doubleValue() > 0) { | |
System.out.println("NaN-values spotted"); | |
System.out.println("Input"); | |
System.out.println(input); | |
System.out.println(); | |
System.out.println("Output"); | |
System.out.println(output); | |
System.out.println(); | |
System.out.println("Error"); | |
System.out.println(error); | |
System.out.println(); | |
System.out.println("Gradient:"); | |
System.out.println(updateVector); | |
System.out.println(); | |
System.out.println("Params:"); | |
System.out.println(model.params()); | |
System.out.println(); | |
return; | |
} | |
model.params().subi(updateVector); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Output: