@koen-dejonghe
Created October 19, 2018 11:46
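// Actor state: a FIFO list of (input, output) pairs from forward passes
// whose gradient has not come back yet.
def messageHandler(activations: List[(Variable, Variable)]): Receive = {

  // Validation: no gradient will return, so nothing is recorded.
  case Validate(x, y) =>
    wire.next ! Validate(module(x), y)

  // Training forward pass: send the result on and remember the activation.
  case Forward(x, y) =>
    val result = module(x)
    wire.next ! Forward(result, y)
    // Explicit tuple avoids Scala's auto-tupling of `:+ (x, result)`.
    context become messageHandler(activations :+ ((x, result)))

  // Backward pass: gradients arrive in forward order, so use the oldest activation.
  case msg @ Backward(g) =>
    activations match {
      case (input, output) :: tail =>
        optimizer.zeroGrad()
        output.backward(g)
        wire.prev ! Backward(input.grad)
        optimizer.step()
        context become messageHandler(tail)
      case Nil =>
        // No pending activation: report the message as unhandled
        // instead of crashing the actor with a MatchError.
        unhandled(msg)
    }
}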
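The handler appends on `Forward` but takes the head on `Backward`. That pairing only works if gradients come back in the same order as the forward passes that produced them. The dependency-free sketch below models this bookkeeping. It is not the author's code. Akka, `Variable`, `wire`, and `optimizer` are replaced by a hypothetical scalar stage `Stage` that computes y = w * x and applies plain SGD.

```scala
import scala.collection.immutable.Queue

// Hypothetical stand-in for one pipeline stage: y = w * x with a scalar weight.
// It mirrors only the activation bookkeeping of messageHandler, not Akka or scorch.
final case class Stage(w: Double, lr: Double, pending: Queue[(Double, Double)]) {

  // Forward: compute the output and remember (input, output) at the back,
  // like `activations :+ ((x, result))` in the actor.
  def forward(x: Double): (Double, Stage) = {
    val y = w * x
    (y, copy(pending = pending.enqueue((x, y))))
  }

  // Backward: consume the OLDEST pending activation (FIFO), return the gradient
  // w.r.t. the input for the previous stage, and take one SGD step on w.
  // Returns None when no activation is pending (the actor's `case Nil`).
  def backward(g: Double): Option[(Double, Stage)] =
    pending.dequeueOption.map { case ((x, _), rest) =>
      val gradInput = g * w // dy/dx = w
      val gradW = g * x     // dy/dw = x
      (gradInput, copy(w = w - lr * gradW, pending = rest))
    }
}

object PipelineDemo {
  def main(args: Array[String]): Unit = {
    val s0 = Stage(w = 2.0, lr = 0.1, pending = Queue.empty)
    val (y1, s1) = s0.forward(1.0) // y1 = 2.0
    val (y2, s2) = s1.forward(3.0) // y2 = 6.0
    // The first gradient is paired with the first forward pass (x = 1.0).
    val (gIn, s3) = s2.backward(1.0).get
    println(s"$y1 $y2 $gIn ${s3.w} ${s3.pending.size}")
  }
}
```

The sketch uses an immutable `Queue` instead of a `List`. This makes each enqueue O(1) rather than the O(n) cost of `List.:+`. With only a few micro-batches in flight, the difference is small.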