@koen-dejonghe
Created March 23, 2020 17:01
Deep learning from scratch with Scala
{
"metadata" : {
"config" : {
"dependencies" : {
"scala" : [
"be.botkop:scorch_2.12:0.1.1"
]
},
"exclusions" : [
],
"repositories" : [
],
"sparkConfig" : {
},
"env" : {
"OMP_NUM_THREADS" : "1"
}
}
},
"nbformat" : 4,
"nbformat_minor" : 0,
"cells" : [
{
"cell_type" : "markdown",
"execution_count" : 0,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"# Deep learning from scratch in Scala!\n",
"\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 1,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980960454,
"endTs" : 1584980960989
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"import botkop.{numsca => ns}\r\n",
"import ns.Tensor\r\n",
"import scala.language.implicitConversions"
],
"outputs" : [
]
},
{
"cell_type" : "markdown",
"execution_count" : 2,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Variables & Functions\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 3,
"metadata" : {
"jupyter.outputs_hidden" : true,
"cell.metadata.exec_info" : {
"startTs" : 1584980961006,
"endTs" : 1584980962545
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"trait Function {\r\n",
" def forward(): Variable\r\n",
" def backward(g: Tensor): Unit\r\n",
"}\r\n",
"\r\n",
"case class Variable(data: Tensor, f: Option[Function] = None) {\r\n",
"\r\n",
" lazy val g: Tensor = ns.zerosLike(data)\r\n",
"\r\n",
" def backward(gradOutput: Tensor = ns.ones(data.shape)): Unit = {\r\n",
" val ug = unbroadcast(gradOutput)\r\n",
" g += ug\r\n",
" for (gf <- f) gf.backward(ug)\r\n",
" }\r\n",
"\r\n",
" def relu(): Variable = Threshold(this, 0.0).forward()\r\n",
" def t(): Variable = Transpose(this).forward()\r\n",
"\r\n",
" def +(other: Variable): Variable = Add(this, other).forward()\r\n",
" def dot(other: Variable): Variable = Dot(this, other).forward()\r\n",
"\r\n",
" def ~>(f: Variable => Variable): Variable = f(this)\r\n",
" def shape: List[Int] = data.shape.toList\r\n",
"\r\n",
" def unbroadcast(t: Tensor): Tensor =\r\n",
" if (t.shape.sameElements(data.shape)) t\r\n",
" else data.shape.zip(t.shape).zipWithIndex.foldLeft(t) {\r\n",
" case (d: Tensor, ((oi, ni), i)) if oi == ni => \r\n",
" d\r\n",
" case (d: Tensor, ((oi, ni), i)) if oi == 1 => \r\n",
" ns.sum(d, axis = i)\r\n",
" case _ =>\r\n",
" throw new Exception(\r\n",
" s\"unable to reduce broadcasted shape ${t.shape.toList} as ${data.shape.toList}\")\r\n",
" }\r\n",
"}\r\n",
"\r\n",
"case class Transpose(v: Variable) extends Function {\r\n",
" override def forward(): Variable = Variable(v.data.transpose, Some(this))\r\n",
" override def backward(gradOutput: Tensor): Unit =\r\n",
" v.backward(gradOutput.transpose)\r\n",
"}\r\n",
"\r\n",
"case class Threshold(x: Variable, d: Double) extends Function {\r\n",
" override def forward(): Variable = Variable(ns.maximum(x.data, d), Some(this))\r\n",
" override def backward(gradOutput: Tensor): Unit = {\r\n",
" x.backward(gradOutput * (x.data > d))\r\n",
" }\r\n",
"}\r\n",
"\r\n",
"case class Add(v1: Variable, v2: Variable) extends Function {\r\n",
" def forward(): Variable = Variable(v1.data + v2.data, f = Some(this))\r\n",
" def backward(g: Tensor): Unit = {\r\n",
" v1.backward(g)\r\n",
" v2.backward(g)\r\n",
" }\r\n",
"}\r\n",
"\r\n",
"case class Dot(v1: Variable, v2: Variable) extends Function {\r\n",
" val w: Tensor = v1.data\r\n",
" val x: Tensor = v2.data\r\n",
" override def forward(): Variable = Variable(w dot x, f = Some(this))\r\n",
" override def backward(g: Tensor): Unit = {\r\n",
" val dw = g dot x.T\r\n",
" val dx = w.T dot g\r\n",
" v1.backward(dw)\r\n",
" v2.backward(dx)\r\n",
" }\r\n",
"}\r\n",
"\r\n",
"object Variable {\r\n",
" def apply(ds: Double*): Variable = Variable(Tensor(ds:_*))\r\n",
"}"
],
"outputs" : [
]
},
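{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"A minimal sanity check of the autograd wiring above (a sketch, using only the definitions from the previous cell): adding two variables records `Add` as the creator of the result, and `backward()` seeds the output gradient with ones, so each input should accumulate a gradient of all ones.\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"val a = Variable(1, 2, 3)\r\n",
"val b = Variable(10, 20, 30)\r\n",
"val c = a + b // forward pass: Add is recorded as the creator of c\r\n",
"c.backward()  // seeds with ones and walks the graph backwards\r\n",
"println(a.g)  // all ones: d(a + b) / da = 1\r\n",
"println(b.g)  // all ones: d(a + b) / db = 1"
],
"outputs" : [
]
},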
{
"cell_type" : "markdown",
"execution_count" : 4,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Module\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 5,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980962547,
"endTs" : 1584980963005
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"abstract class Module(localParameters: Seq[Variable] = Nil) {\r\n",
" // by default, obtain submodules through introspection\r\n",
" lazy val subModules: Seq[Module] =\r\n",
" this.getClass.getDeclaredFields.flatMap { f =>\r\n",
" f setAccessible true\r\n",
" f.get(this) match {\r\n",
" case module: Module => Some(module)\r\n",
" case _ => None\r\n",
" }\r\n",
" }\r\n",
"\r\n",
" def parameters: Seq[Variable] = localParameters ++ subModules.flatMap(_.parameters)\r\n",
" def gradients: Seq[Tensor] = parameters.map(_.g)\r\n",
" def zeroGrad(): Unit = parameters.foreach(p => p.g := 0)\r\n",
"\r\n",
" def forward(x: Variable): Variable\r\n",
" def apply(x: Variable): Variable = forward(x)\r\n",
"}\r\n"
],
"outputs" : [
]
},
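{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"The reflective `subModules` lookup means a container never registers its children by hand: any field that happens to be a `Module` contributes its parameters automatically. A hypothetical sketch (`Leaf` and `Pair` are made-up names for illustration):\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"// Leaf owns a single 2x2 parameter via the localParameters argument\r\n",
"class Leaf extends Module(Seq(Variable(ns.randn(2, 2)))) {\r\n",
"  def forward(x: Variable): Variable = x\r\n",
"}\r\n",
"\r\n",
"// Pair never mentions its parameters: subModules finds a and b reflectively\r\n",
"class Pair extends Module() {\r\n",
"  val a = new Leaf\r\n",
"  val b = new Leaf\r\n",
"  def forward(x: Variable): Variable = b(a(x))\r\n",
"}\r\n",
"\r\n",
"println(new Pair().parameters.size) // 2: one parameter per leaf"
],
"outputs" : [
]
},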
{
"cell_type" : "markdown",
"execution_count" : 6,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"#### Linear Module\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 7,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980963008,
"endTs" : 1584980963236
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"case class Linear(weights: Variable, bias: Variable)\r\n",
" extends Module(Seq(weights, bias)) {\r\n",
" override def forward(x: Variable): Variable = {\r\n",
" (x dot weights.t()) + bias\r\n",
" }\r\n",
"}\r\n",
"\r\n",
"object Linear {\r\n",
" def apply(inFeatures: Int, outFeatures: Int): Linear = {\r\n",
" val w = ns.randn(outFeatures, inFeatures) * math.sqrt(2.0 / outFeatures)\r\n",
" val weights = Variable(w)\r\n",
" val b = ns.zeros(1, outFeatures)\r\n",
" val bias = Variable(b)\r\n",
" Linear(weights, bias)\r\n",
" }\r\n",
"}"
],
"outputs" : [
]
},
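{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"A quick shape check on `Linear`, as a sketch: the factory draws an `(outFeatures, inFeatures)` weight matrix, so a batch of two 4-feature rows maps to two 3-unit rows, and `parameters` reports the weights followed by the bias.\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"val fc = Linear(4, 3) // weights: 3x4, bias: 1x3\r\n",
"println(fc.parameters.map(_.shape)) // shapes 3x4 and 1x3\r\n",
"\r\n",
"val out = fc(Variable(ns.randn(2, 4))) // (2x4 dot 4x3) + broadcast bias\r\n",
"println(out.shape) // List(2, 3)"
],
"outputs" : [
]
},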
{
"cell_type" : "markdown",
"execution_count" : 8,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"#### Extend Variable to act like a Module\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 9,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980963239,
"endTs" : 1584980963372
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"implicit def moduleApply[T <: Module](m: T): (Variable) => Variable = m.forward"
],
"outputs" : [
]
},
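{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"With this conversion in scope, a `Module` can be used anywhere a `Variable => Variable` is expected, which is what lets the `~>` pipelines below read as plain function composition. A small sketch:\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"val h = Variable(ns.randn(2, 4)) ~> Linear(4, 3) ~> ((v: Variable) => v.relu())\r\n",
"println(h.shape) // List(2, 3)"
],
"outputs" : [
]
},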
{
"cell_type" : "markdown",
"execution_count" : 10,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Optimizer\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 11,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980963374,
"endTs" : 1584980963492
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"abstract class Optimizer(parameters: Seq[Variable]) {\r\n",
" def step(epoch: Int)\r\n",
" def zeroGrad(): Unit = parameters.foreach(p => p.g := 0)\r\n",
"}"
],
"outputs" : [
]
},
{
"cell_type" : "markdown",
"execution_count" : 12,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"#### Stochastic Gradient Descent \n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 13,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980963495,
"endTs" : 1584980963928
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"case class SGD(parameters: Seq[Variable], learningRate: Double) extends Optimizer(parameters) {\r\n",
" override def step(epoch: Int): Unit =\r\n",
" parameters.foreach { p =>\r\n",
" p.data -= learningRate * p.g\r\n",
" } \r\n",
"}"
],
"outputs" : [
]
},
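{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"One hand-driven `SGD` step as a sketch (assuming `ns.ones` accepts a shape argument list the same way `ns.zeros` does above): a fake gradient of ones is written into a parameter, `step` subtracts `learningRate * g`, and `zeroGrad` clears the gradient for the next mini-batch.\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"val p = Variable(ns.ones(2, 2))\r\n",
"val opt = SGD(Seq(p), learningRate = 0.5)\r\n",
"\r\n",
"p.g += ns.ones(2, 2) // pretend a backward pass produced gradient 1\r\n",
"opt.step(0)          // p.data -= 0.5 * p.g\r\n",
"println(p.data)      // all 0.5\r\n",
"opt.zeroGrad()       // p.g is all zeros again"
],
"outputs" : [
]
},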
{
"cell_type" : "markdown",
"execution_count" : 14,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Softmax Loss function\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 15,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980963930,
"endTs" : 1584980964184
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"case class SoftmaxLoss(actual: Variable, target: Variable) extends Function {\r\n",
" val x: Tensor = actual.data\r\n",
" val y: Tensor = target.data.T\r\n",
"\r\n",
" val shiftedLogits: Tensor = x - ns.max(x, axis = 1)\r\n",
" val z: Tensor = ns.sum(ns.exp(shiftedLogits), axis = 1)\r\n",
" val logProbs: Tensor = shiftedLogits - ns.log(z)\r\n",
" val n: Int = x.shape.head\r\n",
" val loss: Double = -ns.sum(logProbs(ns.arange(n), y)) / n\r\n",
"\r\n",
" override def forward(): Variable = Variable(Tensor(loss), Some(this))\r\n",
"\r\n",
" override def backward(gradOutput: Tensor /* not used */ ): Unit = {\r\n",
" val dx = ns.exp(logProbs)\r\n",
" dx(ns.arange(n), y) -= 1\r\n",
" dx /= n\r\n",
"\r\n",
" actual.backward(dx)\r\n",
" }\r\n",
"}"
],
"outputs" : [
]
},
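{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"A tiny end-to-end check of `SoftmaxLoss`, as a sketch: four rows of random logits over three classes, labels shaped as a 4x1 column (matching what the data loader below produces), and `backward()` filling in the gradient of the logits. The `gradOutput` argument is ignored because the loss is the root of the graph.\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"val logits = Variable(ns.randn(4, 3)) // 4 samples, 3 classes\r\n",
"val labels = Variable(Tensor(0, 1, 2, 1).reshape(4, 1))\r\n",
"\r\n",
"val l = SoftmaxLoss(logits, labels).forward()\r\n",
"println(l.data) // scalar cross-entropy loss\r\n",
"\r\n",
"l.backward() // writes into logits.g\r\n",
"println(logits.g.shape.toList) // List(4, 3)"
],
"outputs" : [
]
},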
{
"cell_type" : "markdown",
"execution_count" : 16,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Data Loader\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 17,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980964186,
"endTs" : 1584980964312
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"trait DataLoader extends Iterable[(Variable, Variable)] {\r\n",
" def numSamples: Int\r\n",
" def numBatches: Int\r\n",
" def mode: String\r\n",
" def iterator: Iterator[(Variable, Variable)]\r\n",
"}"
],
"outputs" : [
]
},
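{
"cell_type" : "markdown",
"execution_count" : null,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"The trait is deliberately minimal: anything that can replay `(input, label)` mini-batches qualifies. A hypothetical in-memory loader (the name `StaticDataLoader` is made up), handy for smoke-testing the `Learner` without touching disk:\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : null,
"metadata" : {
"language" : "scala"
},
"language" : "scala",
"source" : [
"// serves pre-built mini-batches straight from memory\r\n",
"class StaticDataLoader(batches: Seq[(Variable, Variable)]) extends DataLoader {\r\n",
"  val mode = \"static\"\r\n",
"  val numBatches: Int = batches.length\r\n",
"  val numSamples: Int = batches.map(_._1.shape.head).sum\r\n",
"  def iterator: Iterator[(Variable, Variable)] = batches.iterator\r\n",
"}"
],
"outputs" : [
]
},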
{
"cell_type" : "markdown",
"execution_count" : 18,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"#### Fashion MNIST data loader\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 19,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980964315,
"endTs" : 1584980965083
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"import scala.io.Source\r\n",
"import scala.language.postfixOps\r\n",
"import scala.util.Random\r\n",
"\r\n",
"class FashionMnistDataLoader(val mode: String,\r\n",
" miniBatchSize: Int,\r\n",
" take: Option[Int] = None,\r\n",
" seed: Long = 231)\r\n",
" extends DataLoader {\r\n",
"\r\n",
" Random.setSeed(seed)\r\n",
"\r\n",
" val file: String = mode match {\r\n",
" case \"train\" => \"notebooks/data/fashionmnist/fashion-mnist_train.csv\"\r\n",
" case \"valid\" => \"notebooks/data/fashionmnist/fashion-mnist_test.csv\"\r\n",
" }\r\n",
"\r\n",
" val lines: List[String] = {\r\n",
" val src = Source.fromFile(file)\r\n",
" val lines = src.getLines().toList\r\n",
" src.close()\r\n",
" Random.shuffle(lines.tail) // skip header\r\n",
" }\r\n",
"\r\n",
" val numFeatures = 784\r\n",
" val numEntries: Int = lines.length\r\n",
"\r\n",
" val numSamples: Int = take match {\r\n",
" case Some(n) => math.min(n, numEntries)\r\n",
" case None => numEntries\r\n",
" }\r\n",
"\r\n",
" val numBatches: Int =\r\n",
" (numSamples / miniBatchSize) +\r\n",
" (if (numSamples % miniBatchSize == 0) 0 else 1)\r\n",
"\r\n",
" val data: Seq[(Variable, Variable)] = lines\r\n",
" .take(take.getOrElse(numSamples))\r\n",
" .sliding(miniBatchSize, miniBatchSize)\r\n",
" .map { lines =>\r\n",
" val batchSize = lines.length\r\n",
"\r\n",
" val xs = Array.fill[Float](numFeatures * batchSize)(elem = 0)\r\n",
" val ys = Array.fill[Float](batchSize)(elem = 0)\r\n",
"\r\n",
" lines.zipWithIndex.foreach {\r\n",
" case (line, lnr) =>\r\n",
" val tokens = line.split(\",\")\r\n",
" ys(lnr) = tokens.head.toFloat\r\n",
" tokens.tail.zipWithIndex.foreach {\r\n",
" case (sx, i) =>\r\n",
" xs(lnr * numFeatures + i) = sx.toFloat / 255.0f\r\n",
" }\r\n",
" }\r\n",
"\r\n",
" val x = Variable(Tensor(xs).reshape(batchSize, numFeatures))\r\n",
" val y = Variable(Tensor(ys).reshape(batchSize, 1))\r\n",
" (x, y)\r\n",
" }\r\n",
" .toSeq\r\n",
"\r\n",
" lazy val meanImage: Tensor = {\r\n",
" val m = ns.zeros(1, numFeatures)\r\n",
" data.foreach {\r\n",
" case (x, _) =>\r\n",
" m += ns.sum(x.data, axis = 0)\r\n",
" }\r\n",
" m /= numSamples\r\n",
" m\r\n",
" }\r\n",
"\r\n",
" if (mode == \"train\") zeroCenter(meanImage)\r\n",
"\r\n",
" def zeroCenter(meanImage: Tensor): Unit = {\r\n",
" val bcm = broadcastTo(meanImage, miniBatchSize, numFeatures)\r\n",
"\r\n",
" data.foreach {\r\n",
" case (x, _) =>\r\n",
" if (x.data.shape.head == miniBatchSize)\r\n",
" x.data -= bcm\r\n",
" else\r\n",
" x.data -= meanImage\r\n",
" }\r\n",
" }\r\n",
"\r\n",
" def iterator: Iterator[(Variable, Variable)] =\r\n",
" Random.shuffle(data.toIterator)\r\n",
"\r\n",
" def broadcastTo(t: Tensor, shape: Int*): Tensor =\r\n",
" new Tensor(t.array.broadcast(shape: _*))\r\n",
"\r\n",
"}"
],
"outputs" : [
]
},
{
"cell_type" : "markdown",
"execution_count" : 20,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Learner\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 21,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980965085,
"endTs" : 1584980965492
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"import java.util.Locale\r\n",
"\r\n",
"case class Learner(trainingDataLoader: DataLoader,\r\n",
" testDataLoader: DataLoader,\r\n",
" net: Module,\r\n",
" optimizer: Optimizer,\r\n",
" loss: (Variable, Variable) => Variable) {\r\n",
"\r\n",
" Locale.setDefault(Locale.US)\r\n",
"\r\n",
" def fit(numEpochs: Int): Unit = {\r\n",
" (0 until numEpochs) foreach { epoch =>\r\n",
" val t0 = System.nanoTime()\r\n",
" trainingDataLoader.foreach { // for each mini batch\r\n",
" case (x, y) =>\r\n",
" optimizer.zeroGrad()\r\n",
" val yh = net(x)\r\n",
" val l = loss(yh, y)\r\n",
" l.backward()\r\n",
" optimizer.step(epoch) // update parameters using their gradient\r\n",
" }\r\n",
" val t1 = System.nanoTime()\r\n",
" val dur = (t1 - t0) / 1000000\r\n",
" val (ltrn, atrn) = evaluate(trainingDataLoader, net)\r\n",
" val (ltst, atst) = evaluate(testDataLoader, net)\r\n",
"\r\n",
" println(\r\n",
" f\"epoch: $epoch%2d duration: $dur%4dms loss: $ltst%1.4f / $ltrn%1.4f\\taccuracy: $atst%1.4f / $atrn%1.4f\")\r\n",
" }\r\n",
" }\r\n",
"\r\n",
" def evaluate(dl: DataLoader,\r\n",
" net: Module): (Double, Double) = {\r\n",
" val (l, a) =\r\n",
" dl.foldLeft(0.0, 0.0) {\r\n",
" case ((lossAcc, accuracyAcc), (x, y)) =>\r\n",
" val output = net(x)\r\n",
" val guessed = ns.argmax(output.data, axis = 1)\r\n",
" val accuracy = ns.sum(guessed == y.data)\r\n",
" val cost = loss(output, y).data.squeeze()\r\n",
" (lossAcc + cost, accuracyAcc + accuracy)\r\n",
" }\r\n",
" (l / dl.numBatches, a / dl.numSamples)\r\n",
" }\r\n",
"}\r\n"
],
"outputs" : [
]
},
{
"cell_type" : "markdown",
"execution_count" : 22,
"metadata" : {
"language" : "text"
},
"language" : "text",
"source" : [
"### Train a network\n",
"\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 23,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584980965495,
"endTs" : 1584981051302
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"val batchSize = 256\r\n",
"\r\n",
"val trainDl = new FashionMnistDataLoader(\"train\", batchSize)\r\n",
"val validDl = new FashionMnistDataLoader(\"valid\", batchSize)\r\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 24,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584982453578,
"endTs" : 1584982453747
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"def relu(x: Variable): Variable = x.relu()\r\n",
"\r\n",
"val nn: Module = new Module() {\r\n",
" val fc1 = Linear(784, 100)\r\n",
" val fc2 = Linear(100, 10)\r\n",
" override def forward(x: Variable): Variable = \r\n",
" x ~> fc1 ~> relu ~> fc2\r\n",
"}\r\n",
"\r\n",
"def loss(yHat: Variable, y: Variable): Variable = SoftmaxLoss(yHat, y).forward()\r\n"
],
"outputs" : [
]
},
{
"cell_type" : "code",
"execution_count" : 25,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584982454469,
"endTs" : 1584982529232
},
"language" : "scala"
},
"language" : "scala",
"source" : [
"val numEpochs = 10\r\n",
"val sgd = SGD(nn.parameters, .1)\r\n",
"\r\n",
"Learner(trainDl, validDl, nn, sgd, loss).fit(numEpochs)"
],
"outputs" : [
{
"name" : "stdout",
"text" : [
"epoch: 0 duration: 3955ms loss: 1.2004 / 0.5082\taccuracy: 0.6042 / 0.8229\n",
"epoch: 1 duration: 4134ms loss: 1.0216 / 0.4373\taccuracy: 0.6471 / 0.8463\n",
"epoch: 2 duration: 4060ms loss: 0.9192 / 0.4031\taccuracy: 0.6796 / 0.8587\n",
"epoch: 3 duration: 4340ms loss: 0.8451 / 0.3781\taccuracy: 0.7032 / 0.8667\n",
"epoch: 4 duration: 4017ms loss: 0.8047 / 0.3585\taccuracy: 0.7190 / 0.8728\n",
"epoch: 5 duration: 4016ms loss: 0.7962 / 0.3478\taccuracy: 0.7176 / 0.8769\n",
"epoch: 6 duration: 3941ms loss: 0.8497 / 0.3383\taccuracy: 0.6967 / 0.8804\n",
"epoch: 7 duration: 3961ms loss: 0.7662 / 0.3256\taccuracy: 0.7289 / 0.8848\n",
"epoch: 8 duration: 3968ms loss: 0.7495 / 0.3150\taccuracy: 0.7377 / 0.8886\n",
"epoch: 9 duration: 3937ms loss: 0.7482 / 0.3100\taccuracy: 0.7358 / 0.8901\n"
],
"output_type" : "stream"
}
]
},
{
"cell_type" : "code",
"execution_count" : 26,
"metadata" : {
"cell.metadata.exec_info" : {
"startTs" : 1584979939961,
"endTs" : 1584979940011
},
"language" : "scala"
},
"language" : "scala",
"source" : [
],
"outputs" : [
]
}
]
}