Skip to content

Instantly share code, notes, and snippets.

@Dris101
Created July 2, 2020 14:22
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Dris101/c2663600f3d35733cef2b9331d424ec7 to your computer and use it in GitHub Desktop.
Save Dris101/c2663600f3d35733cef2b9331d424ec7 to your computer and use it in GitHub Desktop.
import scala.collection.JavaConverters._
// Seed the global ND4J RNG first so the random data below is reproducible.
Nd4j.getRandom().setSeed(1234L)
// Draw the weights (5 x 7) and then the rank-4 input (7 x 3 x 4 x 5), in that
// order; tuple elements evaluate left to right, so the RNG sequence is unchanged.
// The input's last axis (size 5) matches the weights' first axis (size 5).
val (wInit, testInit) = (Nd4j.rand(5, 7), Nd4j.rand(7, 3, 4, 5))
// Baseline: the input constant is used exactly as created.
// Both the forward pass and the gradient w.r.t. "w" succeed here.
val sd1 = SameDiff.create()
// Trainable variable initialised from wInit; never reassigned, so a val suffices.
val w1 = sd1.`var`("w", wInit)
val testInputNotPermuted = sd1.constant("test_input_not", testInit)
// Contract axis 3 of the input with axis 0 of w (both have size 5).
val result1 = sd1.tensorMmul("result", testInputNotPermuted, w1, Array(3), 0)
sd1.setLossVariables(result1)
// The graph defines no placeholders, so an empty placeholder map is supplied
// (used for both the forward pass and the gradient call instead of null).
val output1 = sd1.output(Map.empty[String, INDArray].asJava, "result").get("result")
val grad1 = sd1.calculateGradients(Map.empty[String, INDArray].asJava, "w").get("w") // Grads fine here
// Permuted input: identical to the baseline graph except that the input passes
// through two permutes that cancel each other out. The forward pass matches
// the baseline, but computing gradients fails inside TensorMmul.doDiff.
val sd2 = SameDiff.create()
// Trainable variable initialised from wInit; never reassigned, so a val suffices.
val w2 = sd2.`var`("w", wInit)
// Permute outside SameDiff, (7, 3, 4, 5) -> (7, 5, 3, 4), then permute back inside the
// graph: permute(0, 2, 3, 1) inverts permute(0, 3, 1, 2), recovering the original layout.
val testInputPermuted = sd2.constant("test_input_permuted", testInit.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
// Same contraction as the baseline: axis 3 of the input with axis 0 of w.
val result2 = sd2.tensorMmul("result", testInputPermuted, w2, Array(3), 0)
sd2.setLossVariables(result2)
val output2 = sd2.output(Map.empty[String, INDArray].asJava, "result").get("result")
// Scala == delegates to INDArray.equals.
assert(output1 == output2, "forward outputs differ between the plain and permuted graphs") // Passes
// Throws:-
// java.lang.NullPointerException:
// at org.nd4j.linalg.api.ops.impl.reduce.TensorMmul.doDiff(TensorMmul.java:146)
// at org.nd4j.autodiff.functions.DifferentialFunction.diff(DifferentialFunction.java:559)
val grad2 = sd2.calculateGradients(Map.empty[String, INDArray].asJava, "w").get("w")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment