Koen Dejonghe (koen-dejonghe)

koen-dejonghe / dl-from-scratch.ipynb
Created March 23, 2020 17:01
Deep learning from scratch with Scala
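The notebook itself could not be captured here. Purely as an illustration of the "from scratch" theme, and not taken from the notebook, a dense layer's forward pass on plain Scala arrays might look like the following sketch (all names and numbers are made up):

object FromScratchSketch extends App {

  // A dense layer y = relu(W x + b) on plain Scala arrays.
  final case class Dense(w: Array[Array[Double]], b: Array[Double]) {
    def forward(x: Array[Double]): Array[Double] =
      w.zip(b).map { case (row, bias) =>
        val z = row.zip(x).map { case (wi, xi) => wi * xi }.sum + bias
        math.max(0.0, z) // ReLU
      }
  }

  val layer = Dense(
    w = Array(Array(0.5, -0.2), Array(0.1, 0.3)),
    b = Array(0.0, 0.1)
  )
  println(layer.forward(Array(1.0, 2.0)).toVector) // Vector(0.1, 0.8)
}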
package akka.stream.alpakka.elasticsearch.scaladsl
import java.time.ZonedDateTime
import java.time.format.DateTimeFormatter
import akka.actor.Cancellable
import akka.stream.alpakka.elasticsearch.{ElasticsearchSourceSettings, ReadResult}
import akka.stream.scaladsl.{Flow, Source}
import com.typesafe.scalalogging.LazyLogging
import org.elasticsearch.client.RestClient
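Only the imports of this gist are shown above. As a rough sketch of how these pieces typically fit together, assuming the Alpakka Elasticsearch 1.x API (ElasticsearchSource.create taking an index name, type name, query string and settings, with an implicit RestClient), a time-windowed source might look like this; the index and field names are illustrative:

import akka.NotUsed
import akka.stream.alpakka.elasticsearch.scaladsl.ElasticsearchSource
import spray.json.JsObject

// Stream documents indexed since `from` (illustrative index and timestamp field).
def changesSince(index: String, from: ZonedDateTime)(
    implicit client: RestClient): Source[ReadResult[JsObject], NotUsed] = {
  val ts    = from.format(DateTimeFormatter.ISO_OFFSET_DATE_TIME)
  val query = s"""{"range": {"timestamp": {"gte": "$ts"}}}"""
  ElasticsearchSource.create(index, "_doc", query, ElasticsearchSourceSettings())
}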
@koen-dejonghe
koen-dejonghe / NumscaLantern.scala
Last active April 3, 2019 10:12
Neural nets with numsca and continuations
package botkop.dp.lantern
import botkop.{numsca => ns}
import botkop.numsca.Tensor
import XTensor._
import scala.language.implicitConversions
import scala.util.continuations.{cps, reset, shift}
object NumscaLantern extends App {
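The preview stops at the object header. For readers new to the technique the gist builds on, here is a minimal, self-contained sketch of reverse-mode automatic differentiation with delimited continuations in the Lantern style; it is not the gist's code, uses plain Doubles instead of numsca tensors, and needs the Scala continuations compiler plugin (-P:continuations:enable):

import scala.util.continuations.{cps, reset, shift}

object ContinuationAutodiffSketch extends App {

  // Reverse-mode AD value: x is the primal, d accumulates the adjoint.
  class NumR(val x: Double, var d: Double = 0.0) {
    def +(that: NumR): NumR @cps[Unit] = shift { (k: NumR => Unit) =>
      val y = new NumR(x + that.x)
      k(y)          // run the rest of the computation first (forward pass)
      this.d += y.d // then accumulate adjoints on the way back
      that.d += y.d
    }
    def *(that: NumR): NumR @cps[Unit] = shift { (k: NumR => Unit) =>
      val y = new NumR(x * that.x)
      k(y)
      this.d += that.x * y.d
      that.d += this.x * y.d
    }
  }

  // Gradient of a scalar function: seed the output adjoint with 1.0.
  def grad(f: NumR => NumR @cps[Unit])(x: Double): Double = {
    val in = new NumR(x)
    reset { f(in).d = 1.0 }
    in.d
  }

  // d/dx (x*x + x) = 2x + 1, so the gradient at x = 3 is 7
  println(grad(x => x * x + x)(3.0))
}

Each operator uses shift to capture "the rest of the computation" as the continuation k: calling k first completes the forward pass, and the statements after it run in reverse order, accumulating adjoints; grad seeds the output adjoint with 1.0 inside reset.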
epoch: 1 trn_loss: 0.601810 val_loss: 1.850695 val_score: 0.478634 duration: 8581ms
epoch: 2 trn_loss: 0.215092 val_loss: 0.242073 val_score: 0.931310 duration: 7829ms
epoch: 3 trn_loss: 0.164513 val_loss: 0.186792 val_score: 0.947484 duration: 7817ms
epoch: 4 trn_loss: 0.136912 val_loss: 0.164707 val_score: 0.954273 duration: 7898ms
epoch: 5 trn_loss: 0.118550 val_loss: 0.152412 val_score: 0.956270 duration: 7696ms
epoch: 6 trn_loss: 0.105179 val_loss: 0.144234 val_score: 0.958367 duration: 7817ms
epoch: 7 trn_loss: 0.094957 val_loss: 0.139751 val_score: 0.960164 duration: 7889ms
epoch: 8 trn_loss: 0.086851 val_loss: 0.132734 val_score: 0.962660 duration: 7811ms
epoch: 9 trn_loss: 0.079254 val_loss: 0.132945 val_score: 0.962660 duration: 7777ms
epoch: 10 trn_loss: 0.073101 val_loss: 0.130502 val_score: 0.964257 duration: 7792ms
epoch: 1 trn_loss: 0.373341 val_loss: 0.231301 score: 0.934600 duration: 19s
epoch: 2 trn_loss: 0.179583 val_loss: 0.188588 score: 0.946000 duration: 18s
epoch: 3 trn_loss: 0.138713 val_loss: 0.159787 score: 0.953000 duration: 18s
epoch: 4 trn_loss: 0.115682 val_loss: 0.148316 score: 0.955800 duration: 18s
epoch: 5 trn_loss: 0.097952 val_loss: 0.138213 score: 0.959300 duration: 18s
epoch: 6 trn_loss: 0.085970 val_loss: 0.128380 score: 0.963900 duration: 18s
epoch: 7 trn_loss: 0.077760 val_loss: 0.129295 score: 0.963100 duration: 18s
epoch: 8 trn_loss: 0.070618 val_loss: 0.131919 score: 0.964000 duration: 18s
epoch: 9 trn_loss: 0.063497 val_loss: 0.130191 score: 0.964600 duration: 18s
epoch: 10 trn_loss: 0.059064 val_loss: 0.138148 score: 0.963600 duration: 18s
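For reference, the val_score / score column in the runs above is presumably classification accuracy on the validation set, i.e. the fraction of samples whose arg-max prediction matches the label; a plain-Scala sketch of that metric (names are illustrative, not the gist's code):

// scores: one row of class scores per sample; labels: the true class indices.
def accuracy(scores: Array[Array[Double]], labels: Array[Int]): Double = {
  val correct = scores.zip(labels).count { case (row, label) =>
    row.indices.maxBy(i => row(i)) == label
  }
  correct.toDouble / labels.length
}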
def messageHandler: Receive = {
  case Start => // kick off: pre-fill the pipeline with batches
    1 to trainingConcurrency foreach (_ => tdl ! nextTrainingBatch)
    1 to validationConcurrency foreach (_ => vdl ! nextValidationBatch)
  case Forward(yHat, y) => // predictions arriving back at the end of the pipeline
    val l = loss(yHat, y)            // compute the loss for this batch
    l.backward()                     // seed the backward pass
    wire.prev ! Backward(yHat.grad)  // send the gradient back upstream
    trainingLoss += l.data.squeeze() // accumulate the epoch's training loss
def endPoint(y: Variable): Receive = {
  case Forward(yHat) => // predictions for the batch whose labels are y
    val l = loss(yHat, y)
    l.backward()
    wire.prev ! Backward(yHat.grad)
    trainingLoss += l.data.squeeze()
    if (trainingDataIterator.hasNext) {
      context become startPoint(trainingDataIterator.next())
    } else {
def startPoint(batch: (Variable, Variable)): Receive = {
  case msg @ (Start | _: Backward) =>
    if (msg == Start) { // new epoch
      ...
    }
    val (x, y) = batch
    wire.next ! Forward(x)     // push the batch into the pipeline
    context become endPoint(y) // wait for the predictions to come back
}
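The handlers above and below exchange Start, Forward, Backward and Validate messages over a wire that links each actor to its neighbours. The message definitions are not part of these fragments; a plausible minimal protocol they could sit on looks like this (illustrative only, with a stub for the gist's Variable type):

import akka.actor.ActorRef

object Protocol {
  trait Variable // stub for the gist's autograd tensor wrapper (defined elsewhere)

  case object Start                               // kick off training / a new epoch
  case class Forward(xs: Variable*)               // activations flowing downstream (one or two values above)
  case class Backward(gradient: Variable)         // gradients flowing back upstream
  case class Validate(x: Variable, y: Variable)   // a validation batch passed along the pipeline
  case class Wire(prev: ActorRef, next: ActorRef) // an actor's upstream and downstream neighbours
}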
def messageHandler(activations: List[(Variable, Variable)]): Receive = {
  case Validate(x, y) => // validation batches pass straight through
    wire.next ! Validate(module(x), y)
  case Forward(x, y) => // training batches: forward, then remember (input, activation)
    val result = module(x)
    wire.next ! Forward(result, y)
    context become messageHandler(activations :+ (x, result))
def forwardHandle: Receive = {
  case Forward(x) =>                         // accept forward messages
    val result = module(x)                   // calculate the result
    wire.next ! Forward(result)              // forward the result
    context become backwardHandle(x, result) // change state to backward
}
def backwardHandle(input: Variable, output: Variable): Receive = {
  case Backward(g) =>      // accept backward messages
    optimizer.zeroGrad()   // reset gradients to 0
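The last fragment is cut off after zeroGrad. A plausible completion of the backward state, mirroring forwardHandle above, is sketched below; output.backward(g), optimizer.step() and the gradient hand-off are assumptions based on the forward handler and a typical autograd/optimizer API (zeroGrad / backward / step), not the gist's actual code:

def backwardHandle(input: Variable, output: Variable): Receive = {
  case Backward(g) =>                 // accept backward messages
    optimizer.zeroGrad()              // reset gradients to 0
    output.backward(g)                // back-propagate through this module
    wire.prev ! Backward(input.grad)  // pass the input gradient upstream
    optimizer.step()                  // update this module's parameters
    context become forwardHandle      // ready for the next forward pass
}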