@koen-dejonghe
Created October 19, 2018 11:51
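```scala
// Message handler of the actor that computes the loss at the end of the
// network. `tdl` and `vdl` presumably refer to the training and validation
// data-loader actors; `wire.prev` is the preceding layer.
def messageHandler: Receive = {

  // Prime the pipeline: put `trainingConcurrency` training batches and
  // `validationConcurrency` validation batches in flight at once.
  case Start =>
    1 to trainingConcurrency foreach (_ => tdl ! nextTrainingBatch)
    1 to validationConcurrency foreach (_ => vdl ! nextValidationBatch)

  // Forward pass done: compute the loss, backpropagate through it to get the
  // gradient w.r.t. `yHat`, and send that gradient to the previous layer.
  case Forward(yHat, y) =>
    val l = loss(yHat, y)
    l.backward()
    wire.prev ! Backward(yHat.grad)
    trainingLoss += l.data.squeeze()
    numTrainingBatches += 1

  // A Backward arriving here presumably marks the end of a batch's backward
  // pass, which frees a slot: request the next training batch.
  case Backward(_) =>
    tdl ! nextTrainingBatch

  // Validation batch (`x` is presumably the prediction): accumulate loss and
  // score, and keep requesting until the validation set is exhausted.
  case Validate(x, y) =>
    numValidationBatches += 1
    validationLoss += loss(x, y).data.squeeze()
    validationScore += evaluator(x, y)
    if (numValidationBatches < validationDataLoader.numBatches)
      vdl ! nextValidationBatch

  // End of a training epoch: report, reset, and start a new validation round.
  case Epoch("training", epoch, duration) =>
    // report scores and losses, and reset counters
    ...
    1 to validationConcurrency foreach (_ => vdl ! nextValidationBatch)
}
```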
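The handler keeps a fixed number of batches in flight: `Start` issues `trainingConcurrency` requests, and each `Backward` that comes back triggers exactly one new request. The following is a hypothetical, dependency-free simulation of that pull-based flow control for the training path only. A plain queue stands in for the Akka mailbox, and `Msg`, `Forward`, `Backward` and `PullLoopSketch` are simplified stand-ins rather than the gist's actual types. It assumes, as the handler suggests, that a `Backward` reaching this actor means the batch's backward pass is complete.

```scala
import scala.collection.mutable

// Hypothetical, simplified stand-ins for the gist's messages (training path only).
sealed trait Msg
case object Start extends Msg
final case class Forward(batchId: Int) extends Msg
final case class Backward(batchId: Int) extends Msg

object PullLoopSketch {
  // Single-threaded simulation; `mailbox` plays the role of the actor's mailbox.
  // Returns (batches processed, maximum number of batches in flight at once).
  def run(numBatches: Int, concurrency: Int): (Int, Int) = {
    val mailbox = mutable.Queue[Msg](Start)
    var nextBatch = 0   // next batch index the "loader" will hand out
    var inFlight = 0    // requested batches not yet acknowledged by Backward
    var maxInFlight = 0
    var processed = 0

    // Stand-in for `tdl ! nextTrainingBatch`: emits a Forward while data remains.
    def requestBatch(): Unit =
      if (nextBatch < numBatches) {
        mailbox.enqueue(Forward(nextBatch))
        nextBatch += 1
        inFlight += 1
        maxInFlight = math.max(maxInFlight, inFlight)
      }

    while (mailbox.nonEmpty) {
      mailbox.dequeue() match {
        case Start =>
          (1 to concurrency).foreach(_ => requestBatch())
        case Forward(id) =>
          processed += 1                // the real handler computes the loss here
          mailbox.enqueue(Backward(id)) // simplification: the ack returns at once
        case Backward(_) =>
          inFlight -= 1
          requestBatch()                // one finished batch frees one slot
      }
    }
    (processed, maxInFlight)
  }

  def main(args: Array[String]): Unit = {
    val (processed, maxInFlight) = run(numBatches = 10, concurrency = 3)
    println(s"processed=$processed maxInFlight=$maxInFlight")
  }
}
```

With 10 batches and a concurrency of 3, `run` reports 10 processed batches and never more than 3 in flight; when the concurrency exceeds the number of batches, the in-flight count is simply capped by the data available.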