@koen-dejonghe
Created October 19, 2018 11:44
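```scala
// Fragment of an Akka actor: `Receive` and `context` come from akka.actor.Actor.
// `module`, `wire`, `optimizer`, `Variable`, `Forward` and `Backward` are defined
// elsewhere and are not shown in this gist.
def forwardHandle: Receive = {
  case Forward(x) =>                         // accept forward messages
    val result = module(x)                   // compute the module's output
    wire.next ! Forward(result)              // send the result to the next stage
    context become backwardHandle(x, result) // switch to the backward state, caching input and output
}

def backwardHandle(input: Variable, output: Variable): Receive = {
  case Backward(g) =>                        // accept backward messages
    optimizer.zeroGrad()                     // reset gradients to 0
    output.backward(g)                       // backpropagate g to compute local gradients
    wire.prev ! Backward(input.grad)         // send the input gradient to the previous stage
    optimizer.step()                         // update the weights
    context become forwardHandle             // switch back to the forward state
}
```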
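The same two-state protocol can be sketched without Akka to show what each state caches and in what order things happen. This is a minimal, hypothetical sketch, not the gist's actual code: a private `Option` variable stands in for `context become`, a scalar linear layer `y = w * x` stands in for `module`, and plain SGD stands in for `optimizer`. All names (`Msg`, `LinearStage`, `Pipeline.trainStep`) are illustrative. Unlike an Akka actor, which would treat an out-of-order message as unhandled, this sketch throws.

```scala
// Hypothetical messages mirroring the gist's Forward/Backward, carrying plain Doubles.
sealed trait Msg
final case class Forward(x: Double) extends Msg
final case class Backward(g: Double) extends Msg

// A scalar linear stage y = w * x trained with plain SGD (assumed stand-ins for module/optimizer).
// Like the actor, it accepts only Forward until a forward pass has run, then only Backward.
final class LinearStage(var w: Double, lr: Double) {
  // None = forward state; Some((input, output)) = backward state with cached values
  private var cached: Option[(Double, Double)] = None

  // Handles one message and returns the message for the neighbouring stage.
  def receive(msg: Msg): Msg = (cached, msg) match {
    case (None, Forward(x)) =>
      val y = w * x              // forward pass
      cached = Some((x, y))      // like `context become backwardHandle(x, result)`
      Forward(y)                 // like `wire.next ! Forward(result)`
    case (Some((x, _)), Backward(g)) =>
      val gradW = g * x          // local gradient w.r.t. the weight (dy/dw = x)
      val gradX = g * w          // input gradient (dy/dx = w), computed before the update
      w -= lr * gradW            // optimizer step (plain SGD)
      cached = None              // like `context become forwardHandle`
      Backward(gradX)            // like `wire.prev ! Backward(input.grad)`
    case (state, other) =>
      throw new IllegalStateException(s"unexpected $other in state $state")
  }
}

object Pipeline {
  // One training step through a chain of stages: Forward left-to-right (like wire.next),
  // then the loss gradient right-to-left (like wire.prev). Returns the forward output.
  def trainStep(stages: Seq[LinearStage], x: Double, lossGrad: Double => Double): Double = {
    val y = stages.foldLeft(x) { (in, s) =>
      s.receive(Forward(in)) match {
        case Forward(out) => out
        case m            => sys.error(s"unexpected $m")
      }
    }
    stages.foldRight(lossGrad(y)) { (s, g) =>
      s.receive(Backward(g)) match {
        case Backward(gx) => gx
        case m            => sys.error(s"unexpected $m")
      }
    }
    y
  }
}
```

As in the gist, the gradient sent to the previous stage is computed before the weight update, so it reflects the weights that produced the forward output.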