Last active
January 18, 2018 07:58
-
-
Save koen-dejonghe/bdc38130eadf3f050fad436e4b63db3e to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package scorch
import scorch.autograd._
import botkop.numsca.Tensor
import botkop.{numsca => ns}
import org.scalactic.{Equality, TolerantNumerics}
import org.scalatest.{FlatSpec, Matchers}
/** Checks that autograd can differentiate a mean-squared-error style expression.
  *
  * The test builds `mean(mean((input - label) * (input - label)))` from two
  * randomly initialised variables, runs the backward pass on the result and
  * verifies that the gradient recorded for `input` has the same shape as `input`.
  */
class AutoGradSpec extends FlatSpec with Matchers {

  "Autograd" should "derive mse" in {
    val nOut = 4
    val minibatch = 3

    // Random (minibatch x nOut) operands wrapped as differentiable variables.
    val input = Variable(ns.randn(minibatch, nOut))
    val label = Variable(ns.randn(minibatch, nOut))

    // Element-wise squared difference.
    val diff = input - label
    val sqDiff = diff * diff

    // NOTE(review): `mean` is called without an axis both times, so whether the
    // first call really reduces "per example" depends on scorch's `mean` -- confirm.
    val msePerEx = mean(sqDiff)
    val avgMSE = mean(msePerEx)

    // The fully reduced loss is expected to be a 1 x 1 value.
    avgMSE.shape shouldBe List(1, 1)

    avgMSE.backward()

    // `.get` presumably unwraps an Option and throws (failing the test) when no
    // gradient was recorded for `input` -- verify against Variable.grad.
    input.grad.get.shape shouldBe input.shape
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
@Test | |
public void testMseBackwards() { | |
SameDiff sd = SameDiff.create(); | |
int nOut = 4; | |
int minibatch = 3; | |
SDVariable input = sd.var("in", new int[]{-1,nOut}); | |
SDVariable label = sd.var("label", new int[]{-1, nOut}); | |
SDVariable diff = input.sub(label); | |
SDVariable sqDiff = diff.mul(diff); | |
SDVariable msePerEx = sd.mean("msePerEx", sqDiff, 1); | |
SDVariable avgMSE = sd.mean("loss", msePerEx, 0); | |
INDArray inputArr = Nd4j.rand(minibatch, nOut); | |
INDArray labelArr = Nd4j.rand(minibatch, nOut); | |
sd.associateArrayWithVariable(inputArr, input); | |
sd.associateArrayWithVariable(labelArr, label); | |
INDArray result = sd.execAndEndResult(); | |
assertEquals(1, result.length()); | |
Pair<Map<SDVariable, DifferentialFunction>, List<DifferentialFunction>> p = sd.execBackwards(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment