Skip to content

Instantly share code, notes, and snippets.

@Broele
Last active January 23, 2018 18:19
Show Gist options
  • Save Broele/22b5f7e9bde28a8ca4b58c41ddd343e3 to your computer and use it in GitHub Desktop.
Save Broele/22b5f7e9bde28a8ca4b58c41ddd343e3 to your computer and use it in GitHub Desktop.
Why does this produce NaNs?
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.graph.PreprocessorVertex;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToRnnPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.RnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.IsNaN;
/**
 * Minimal reproduction for a NaN appearing in the gradient of a small
 * DL4J {@link ComputationGraph}: RNN-shaped input is reshaped to
 * feed-forward format, passed through a ReLU dense layer, reshaped back
 * to RNN format, and emitted through an identity activation layer.
 * Training is done manually with an externally supplied error signal.
 */
public class Example {

    /** Number of input features per time step. */
    private static final int nIn = 4;
    /** Number of output units per time step. */
    private static final int nOut = 3;

    /**
     * Builds and initializes the example graph:
     * features (RNN format) -> RnnToFeedForward -> Dense(ReLU)
     * -> FeedForwardToRnn -> identity activation output.
     *
     * @return an initialized {@link ComputationGraph}
     */
    public static ComputationGraph getGraph1() {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .learningRate(0.01)
                .graphBuilder()
                .addInputs("features")
                // Flatten [minibatch, nIn, seqLen] to [minibatch*seqLen, nIn]
                .addVertex("rnn2ffn",
                        new PreprocessorVertex(new RnnToFeedForwardPreProcessor()),
                        "features")
                .addLayer("predict",
                        new DenseLayer.Builder()
                                .nIn(nIn)
                                .nOut(nOut)
                                .activation(Activation.RELU)
                                .build(),
                        "rnn2ffn")
                // Reshape back to [minibatch, nOut, seqLen]
                .addVertex("ffn2rnn",
                        new PreprocessorVertex(new FeedForwardToRnnPreProcessor()),
                        "predict")
                .addLayer("output",
                        new ActivationLayer.Builder()
                                .activation(Activation.IDENTITY)
                                .build(),
                        "ffn2rnn")
                .setOutputs("output")
                .backprop(true)
                .build();
        ComputationGraph graph = new ComputationGraph(conf);
        graph.init();
        return graph;
    }

    /**
     * Runs a manual training loop with fixed initial parameters and a
     * constant target of all ones. On the first iteration in which the
     * gradient contains NaN values, dumps input/output/error/gradient/params
     * and stops.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        final int minibatch = 5;
        final int seqLen = 6;
        ComputationGraph model = getGraph1();

        // Fixed initial parameters so the failure is reproducible.
        double[] param = new double[]{0.54, 0.31, 0.98, -0.30, -0.66, -0.19,
                -0.29, -0.62, 0.13, -0.32, 0.01, -0.03, 0.00, 0.00, 0.00};
        for (int i = 0; i < param.length; i++) {
            model.params().putScalar(i, param[i]);
        }

        // Seeded input (seed 12) and an all-ones target.
        INDArray input = Nd4j.rand(new int[]{minibatch, nIn, seqLen}, 12);
        INDArray expected = Nd4j.ones(new int[]{minibatch, nOut, seqLen});

        for (int i = 0; i < 1000; i++) {
            System.out.println("Params:");
            System.out.println(model.params());
            System.out.println();

            INDArray output = model.outputSingle(input);
            // External error signal: simple (output - target) residual.
            INDArray error = output.sub(expected);

            // Compute gradient w.r.t. the external error.
            Gradient gradient = model.backpropGradient(error);
            // Pass the current iteration index (was hard-coded to 0, which
            // would freeze any iteration-dependent updater state/schedule).
            model.getUpdater().update(gradient, i, minibatch);

            // Check the post-updater update vector for NaNs before applying it.
            INDArray updateVector = gradient.gradient();
            if (updateVector.cond(new IsNaN()).sumNumber().doubleValue() > 0) {
                System.out.println("NaN-values spotted");
                System.out.println("Input");
                System.out.println(input);
                System.out.println();
                System.out.println("Output");
                System.out.println(output);
                System.out.println();
                System.out.println("Error");
                System.out.println(error);
                System.out.println();
                System.out.println("Gradient:");
                System.out.println(updateVector);
                System.out.println();
                System.out.println("Params:");
                System.out.println(model.params());
                System.out.println();
                return;
            }

            // Manual SGD step: params() is a view, so subi updates the model.
            model.params().subi(updateVector);
        }
    }
}
@Broele
Copy link
Author

Broele commented Jan 23, 2018

Output:

Params:
[0.54,  0.31,  0.98,  -0.30,  -0.66,  -0.19,  -0.29,  -0.62,  0.13,  -0.32,  0.01,  -0.03,  0.00,  0.00,  0.00]

Params:
[0.54,  0.32,  0.98,  -0.29,  -0.66,  -0.19,  -0.29,  -0.62,  0.13,  -0.32,  0.01,  -0.03,  0.02,  0.00,  0.00]

NaN-values spotted
Input
[[[0.67,  0.61,  0.67,  0.36,  0.94,  0.05],  
  [0.92,  0.57,  0.65,  0.51,  0.46,  0.31],  
  [0.10,  0.37,  0.03,  0.49,  0.01,  0.11],  
  [0.85,  0.68,  0.90,  0.17,  0.92,  0.45]],  

 [[0.45,  0.68,  0.91,  0.16,  0.28,  0.69],  
  [0.23,  0.31,  0.58,  0.85,  0.68,  0.57],  
  [0.32,  0.28,  0.45,  0.50,  0.42,  0.85],  
  [0.53,  0.09,  0.43,  0.83,  0.74,  0.82]],  

 [[0.52,  0.30,  0.55,  0.32,  0.43,  0.97],  
  [0.62,  0.71,  0.39,  0.09,  0.84,  0.85],  
  [0.17,  0.26,  0.23,  0.15,  0.05,  0.35],  
  [0.79,  0.44,  0.80,  0.21,  0.47,  0.05]],  

 [[0.01,  0.05,  0.33,  0.89,  0.44,  0.30],  
  [0.44,  0.47,  0.46,  0.53,  0.90,  0.98],  
  [0.00,  0.80,  0.45,  0.17,  0.43,  0.39],  
  [1.00,  0.15,  0.73,  0.85,  0.73,  0.48]],  

 [[0.60,  0.27,  0.31,  0.86,  0.50,  0.41],  
  [0.96,  0.93,  0.48,  0.14,  0.29,  0.22],  
  [0.94,  0.69,  0.36,  0.81,  0.97,  0.65],  
  [0.00,  0.70,  0.14,  0.59,  0.75,  0.13]]]

Output
[[[0.52,  0.69,  0.35,  0.81,  0.41,  0.12],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00]],  

 [[0.49,  0.72,  1.01,  0.62,  0.58,  1.17],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00]],  

 [[0.44,  0.54,  0.43,  0.30,  0.43,  1.14],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00],  
  [0.00,  0.00,  0.00,  0.01,  0.00,  0.00]],  

 [[0.00,  0.93,  0.57,  0.59,  0.75,  0.73],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00]],  

 [[1.56,  0.93,  0.65,  1.16,  1.12,  0.91],  
  [0.00,  0.00,  0.00,  0.00,  0.00,  0.00],  
  [0.00,  0.00,  0.00,  0.06,  0.00,  0.00]]]

Error
[[[-0.48,  -0.31,  -0.65,  -0.19,  -0.59,  -0.88],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00]],  

 [[-0.51,  -0.28,  0.01,  -0.38,  -0.42,  0.17],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00]],  

 [[-0.56,  -0.46,  -0.57,  -0.70,  -0.57,  0.14],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00],  
  [-1.00,  -1.00,  -1.00,  -0.99,  -1.00,  -1.00]],  

 [[-1.00,  -0.07,  -0.43,  -0.41,  -0.25,  -0.27],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00]],  

 [[0.56,  -0.07,  -0.35,  0.16,  0.12,  -0.09],  
  [-1.00,  -1.00,  -1.00,  -1.00,  -1.00,  -1.00],  
  [-1.00,  -1.00,  -1.00,  -0.94,  -1.00,  -1.00]]]

Gradient:
[   NaN,     NaN,  -0.00,     NaN,     NaN,     NaN,  0.00,     NaN,     NaN,     NaN,  -0.00,     NaN,  -0.01,  0.00,  -0.01]

Params:
[0.54,  0.32,  0.98,  -0.29,  -0.66,  -0.19,  -0.29,  -0.62,  0.13,  -0.32,  0.01,  -0.03,  0.02,  0.00,  0.00]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment