@koen-dejonghe
Created October 19, 2018 11:50
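// Behavior at the end of the pipeline: waits for the prediction for target y.
def endPoint(y: Variable): Receive = {
  case Forward(yHat) =>
    val l = loss(yHat, y)
    l.backward()                      // populates yHat.grad
    wire.prev ! Backward(yHat.grad)   // send the gradient back to the previous stage
    trainingLoss += l.data.squeeze()  // accumulate the scalar loss for this epoch
    if (trainingDataIterator.hasNext) {
      // More batches: switch back to the start-point behavior with the next batch.
      context become startPoint(trainingDataIterator.next())
    } else {
      endOfEpoch()
    }
}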