Created
May 20, 2020 15:01
-
-
Save suzuki-navi/5b7c968c1e556481425dad4b98fbad45 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import scala.util.Random; | |
import Ordering.Double.IeeeOrdering; | |
//================================================================================================== | |
object GradientDescent { | |
//-------------------------------------------------------------------------------------------------- | |
import Util._; | |
import Util.CalculationCache.KeyValue; | |
/** Untyped root of every graph node; it is the key type of the calculation cache. */
sealed trait NodeTrait
/**
 * Evidence that the cache entry stored under a `Node[Z]` key is a `NodeCalculation[Z]`.
 *
 * The result type is written out explicitly: an implicit whose type is inferred is only
 * eligible after its definition point in Scala 2 and is rejected by Scala 3.
 */
implicit def nodeCalculationKeyValue[Z]: KeyValue[Node[Z], NodeCalculationTrait, NodeCalculation[Z]] =
  new KeyValue[Node[Z], NodeCalculationTrait, NodeCalculation[Z]] {
    // Unchecked downcast from the untyped cache value to the typed calculation.
    def cast(v: NodeCalculationTrait): NodeCalculation[Z] = v.asInstanceOf[NodeCalculation[Z]]
  }
/**
 * A graph node that produces values of type `Z`.
 *
 * All per-evaluation state lives in a `CalculationCache` that is passed to every call,
 * so the node objects themselves carry no evaluation state.
 */
trait Node[Z] extends NodeTrait {

  /** Evaluates this node against a freshly created, empty cache. */
  def output: Z = output(new CalculationCache[NodeTrait, NodeCalculationTrait]())

  /**
   * Returns the output already recorded for this node in `cache`.
   *
   * Throws a plain `Exception` when the cache holds no entry for this node, or when the
   * entry has not produced an output yet.
   */
  def lastOutput(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
    val recorded = cache.getOrElse(this, throw new Exception())
    recorded.lastOutput
  }

  /** Evaluates this node, recording intermediate state in `cache`. */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z

  /** Hands this node a derivative with respect to its output, using the state in `cache`. */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit

  /** Produces the node to use for the following iteration, based on the state in `cache`. */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Node[Z]
}
/** Untyped root of every per-node calculation record; it is the value type of the calculation cache. */
sealed trait NodeCalculationTrait
/**
 * Per-node bookkeeping for one forward/backward pass.
 *
 * It counts how many times the node's output was requested, memoises that output, and
 * collects incoming derivatives until one has arrived for every output request.
 */
trait NodeCalculation[Z] extends NodeCalculationTrait {
  private[this] var outputRequests: Int = 0
  private[this] var memoisedOutput: Z = _
  private[this] var derivativesReceived: Int = 0
  private[this] var pendingDerivatives: List[Z] = Nil

  /** The memoised output; throws a plain `Exception` if `output` has never been called. */
  def lastOutput: Z = {
    if (outputRequests == 0) throw new Exception()
    memoisedOutput
  }

  /**
   * Returns the node's output, evaluating `p` only on the first request.
   * Every call, including the first, increments the request counter.
   */
  def output(p: => Z): Z = {
    if (outputRequests == 0) memoisedOutput = p
    outputRequests += 1
    memoisedOutput
  }

  /**
   * Records one incoming `derivative`. When the number of recorded derivatives equals the
   * number of output requests, `p` is called with their sum as computed by `shape`.
   */
  def backpropagation(derivative: Z)(p: Z => Unit)(implicit shape: Shape[Z]): Unit = {
    derivativesReceived += 1
    // Newest derivative goes first, so the list is summed in reverse arrival order.
    pendingDerivatives = derivative :: pendingDerivatives
    if (derivativesReceived == outputRequests) {
      val total = shape.sum(pendingDerivatives)
      p(total)
    }
  }
}
/** Type class describing how several values of type `Z` are added together. */
trait Shape[Z] {
  /** Combines all elements of `it` into a single `Z`. */
  def sum(it: Iterable[Z]): Z
}
//-------------------------------------------------------------------------------------------------- | |
/**
 * A node with a fixed value. It never touches the cache, ignores incoming derivatives
 * and is its own successor.
 */
case class Constant[Z](value: Z)(implicit shape: Shape[Z]) extends Node[Z] {

  /** Always `value`; `cache` is not consulted. */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = value

  /** Always `value`, whether or not `output` was called before. */
  override def lastOutput(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = value

  /** A constant has nothing to update, so the derivative is dropped. */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = ()

  /** A constant does not change between iterations. */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Constant[Z] = this
}
//-------------------------------------------------------------------------------------------------- | |
/**
 * Evidence that the cache entry stored under a `Parameter[Z]` key is a
 * `ParameterNodeCalculation[Z]`.
 *
 * The result type is written out explicitly: an implicit whose type is inferred is only
 * eligible after its definition point in Scala 2 and is rejected by Scala 3.
 */
implicit def parameterNodeCalculationKeyValue[Z]: KeyValue[Parameter[Z], NodeCalculationTrait, ParameterNodeCalculation[Z]] =
  new KeyValue[Parameter[Z], NodeCalculationTrait, ParameterNodeCalculation[Z]] {
    // Unchecked downcast from the untyped cache value to the typed calculation.
    def cast(v: NodeCalculationTrait): ParameterNodeCalculation[Z] =
      v.asInstanceOf[ParameterNodeCalculation[Z]]
  }
/**
 * A trainable leaf node.
 *
 * `param` is the triple handed to and returned by `DeltaCalculator.calcDelta`: its first
 * component is the current value, and the second and third are passed on as
 * `prevDerivative` and `prevDelta`.
 */
case class Parameter[Z](param: (Z, Z, Z), deltaCalculator: DeltaCalculator[Z], id: Int)
                       (implicit shape: Shape[Z])
  extends Node[Z] {

  /** The current value of the parameter (first component of `param`). */
  def value = param._1

  /** Registers one output request in `cache` and returns the current value. */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
    val calc = cache.getOrElse(this, new ParameterNodeCalculation[Z])
    calc.output(param._1)
  }

  /** The current value; `cache` is not consulted. */
  override def lastOutput(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = param._1

  /**
   * Records `derivative`; once one derivative per output request has arrived, their sum
   * is stored on the cache entry. Requires that `output` created the entry beforehand.
   */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
    val calc = cache.get(this)
    calc.backpropagation(derivative) { total =>
      calc.derivative = total
    }
  }

  /**
   * Returns the updated parameter for the next iteration. The successor is built once per
   * cache entry and memoised there, so repeated calls return the same instance.
   */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Parameter[Z] = {
    val calc = cache.get(this)
    calc.next.getOrElse {
      val updated = deltaCalculator.calcDelta(param._1, param._2, param._3, calc.derivative)
      val successor = Parameter(updated, deltaCalculator, id)
      calc.next = Some(successor)
      successor
    }
  }
}
/** Factory methods that give each new `Parameter` the next value of a running counter as its `id`. */
object Parameter {
  // Last id handed out; incremented before use, so the first id is 1.
  private[this] var lastIssuedId: Int = 0

  /** Creates a parameter from an explicit triple, assigning it a fresh id. */
  def create[Z](param: (Z, Z, Z), deltaCalculator: DeltaCalculator[Z])
               (implicit shape: Shape[Z]): Parameter[Z] = {
    lastIssuedId += 1
    Parameter(param, deltaCalculator, lastIssuedId)
  }

  /** Creates a parameter whose triple is obtained from `initializer.initialize()`. */
  def create[Z](initializer: Initializer[Z], deltaCalculator: DeltaCalculator[Z])
               (implicit shape: Shape[Z]): Parameter[Z] =
    create(initializer.initialize(), deltaCalculator)
}
/** Cache entry for a `Parameter`: the base bookkeeping plus the two fields below. */
class ParameterNodeCalculation[Z] extends NodeCalculation[Z] {
  /** Summed derivative written by `Parameter.backpropagation`; default-initialised until then. */
  var derivative: Z = _

  /** Successor parameter memoised by `Parameter.next`; `None` until it has been built. */
  var next: Option[Parameter[Z]] = None
}
//-------------------------------------------------------------------------------------------------- | |
/** Type class for adding an `X` and a `Y` into a `Z`, with the matching backward maps. */
trait AddShape[X, Y, Z] {
  /** Forward operation. */
  def add(x: X, y: Y): Z

  /** Maps a derivative with respect to the sum to one with respect to the left operand. */
  def backX(derivative: Z): X

  /** Maps a derivative with respect to the sum to one with respect to the right operand. */
  def backY(derivative: Z): Y
}
/**
 * Evidence that the cache entry stored under an `Add[X, Y, Z]` key is an
 * `AddNodeCalculation[X, Y, Z]`.
 *
 * The result type is written out explicitly: an implicit whose type is inferred is only
 * eligible after its definition point in Scala 2 and is rejected by Scala 3.
 */
implicit def addNodeCalculationKeyValue[X, Y, Z]: KeyValue[Add[X, Y, Z], NodeCalculationTrait, AddNodeCalculation[X, Y, Z]] =
  new KeyValue[Add[X, Y, Z], NodeCalculationTrait, AddNodeCalculation[X, Y, Z]] {
    // Unchecked downcast from the untyped cache value to the typed calculation.
    def cast(v: NodeCalculationTrait): AddNodeCalculation[X, Y, Z] =
      v.asInstanceOf[AddNodeCalculation[X, Y, Z]]
  }
/** Node computing `addShape.add(x, y)` from the outputs of its two operand nodes. */
case class Add[X, Y, Z](x_node: Node[X], y_node: Node[Y])
                       (implicit addShape: AddShape[X, Y, Z], outputShape: Shape[Z])
  extends Node[Z] {

  /** Evaluates both operands (left first) on the first request; later requests reuse the result. */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
    val calc = cache.getOrElse(this, new AddNodeCalculation[X, Y, Z])
    calc.output {
      val left = x_node.output(cache)
      val right = y_node.output(cache)
      addShape.add(left, right)
    }
  }

  /**
   * Records `derivative`; once one derivative per output request has arrived, their sum is
   * pushed to the left operand and then to the right operand.
   */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
    val calc = cache.get(this)
    calc.backpropagation(derivative) { total =>
      x_node.backpropagation(addShape.backX(total), cache)
      y_node.backpropagation(addShape.backY(total), cache)
    }
  }

  /** Rebuilds the node on top of the successors of both operands. */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Add[X, Y, Z] =
    Add[X, Y, Z](x_node.next(cache), y_node.next(cache))
}
/** Cache entry for an `Add` node; it needs nothing beyond the base bookkeeping. */
class AddNodeCalculation[X, Y, Z] extends NodeCalculation[Z]
//-------------------------------------------------------------------------------------------------- | |
/** Type class for multiplying an `X` by a `Y` into a `Z`, with the matching backward maps. */
trait MultiplyShape[X, Y, Z] {
  /** Forward operation. */
  def multiply(x: X, y: Y): Z

  /** Derivative with respect to the left operand, given the right operand's value. */
  def backX(y: Y, derivative: Z): X

  /** Derivative with respect to the right operand, given the left operand's value. */
  def backY(x: X, derivative: Z): Y
}
/**
 * Evidence that the cache entry stored under a `Multiply[X, Y, Z]` key is a
 * `MultiplyNodeCalculation[X, Y, Z]`.
 *
 * The result type is written out explicitly: an implicit whose type is inferred is only
 * eligible after its definition point in Scala 2 and is rejected by Scala 3.
 */
implicit def multiplyNodeCalculationKeyValue[X, Y, Z]: KeyValue[Multiply[X, Y, Z], NodeCalculationTrait, MultiplyNodeCalculation[X, Y, Z]] =
  new KeyValue[Multiply[X, Y, Z], NodeCalculationTrait, MultiplyNodeCalculation[X, Y, Z]] {
    // Unchecked downcast from the untyped cache value to the typed calculation.
    def cast(v: NodeCalculationTrait): MultiplyNodeCalculation[X, Y, Z] =
      v.asInstanceOf[MultiplyNodeCalculation[X, Y, Z]]
  }
/** Node computing `multiplyShape.multiply(x, y)` from the outputs of its two operand nodes. */
case class Multiply[X, Y, Z](x_node: Node[X], y_node: Node[Y])
                            (implicit multiplyShape: MultiplyShape[X, Y, Z], outputShape: Shape[Z])
  extends Node[Z] {

  /**
   * Evaluates both operands (left first) on the first request and stores their values on
   * the cache entry for use by the backward pass.
   */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
    val calc = cache.getOrElse(this, new MultiplyNodeCalculation[X, Y, Z])
    calc.output {
      val left = x_node.output(cache)
      calc.x_value = left
      val right = y_node.output(cache)
      calc.y_value = right
      multiplyShape.multiply(left, right)
    }
  }

  /**
   * Records `derivative`; once one derivative per output request has arrived, their sum is
   * combined with the stored operand values and pushed to the left, then the right operand.
   */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
    val calc = cache.get(this)
    calc.backpropagation(derivative) { total =>
      val towardsX = multiplyShape.backX(calc.y_value, total)
      x_node.backpropagation(towardsX, cache)
      val towardsY = multiplyShape.backY(calc.x_value, total)
      y_node.backpropagation(towardsY, cache)
    }
  }

  /** Rebuilds the node on top of the successors of both operands. */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Multiply[X, Y, Z] =
    Multiply[X, Y, Z](x_node.next(cache), y_node.next(cache))
}
/** Cache entry for a `Multiply` node; it keeps both operand values for the backward pass. */
class MultiplyNodeCalculation[X, Y, Z] extends NodeCalculation[Z] {
  /** Left operand value written by `Multiply.output`; default-initialised until then. */
  var x_value: X = _

  /** Right operand value written by `Multiply.output`; default-initialised until then. */
  var y_value: Y = _
}
//-------------------------------------------------------------------------------------------------- | |
/** Type class for dividing an `X` by a `Y` into a `Z`, with the matching backward maps. */
trait DivideShape[X, Y, Z] {
  /** Forward operation. */
  def divide(x: X, y: Y): Z

  /** Derivative with respect to the numerator, given the denominator's value. */
  def backX(y: Y, derivative: Z): X

  /** Derivative with respect to the denominator, given both operand values. */
  def backY(x: X, y: Y, derivative: Z): Y
}
/**
 * Evidence that the cache entry stored under a `Divide[X, Y, Z]` key is a
 * `DivideNodeCalculation[X, Y, Z]`.
 *
 * The result type is written out explicitly: an implicit whose type is inferred is only
 * eligible after its definition point in Scala 2 and is rejected by Scala 3.
 */
implicit def divideNodeCalculationKeyValue[X, Y, Z]: KeyValue[Divide[X, Y, Z], NodeCalculationTrait, DivideNodeCalculation[X, Y, Z]] =
  new KeyValue[Divide[X, Y, Z], NodeCalculationTrait, DivideNodeCalculation[X, Y, Z]] {
    // Unchecked downcast from the untyped cache value to the typed calculation.
    def cast(v: NodeCalculationTrait): DivideNodeCalculation[X, Y, Z] =
      v.asInstanceOf[DivideNodeCalculation[X, Y, Z]]
  }
/** Node computing `divideShape.divide(x, y)` from the outputs of its two operand nodes. */
case class Divide[X, Y, Z](x_node: Node[X], y_node: Node[Y])
                          (implicit divideShape: DivideShape[X, Y, Z], outputShape: Shape[Z])
  extends Node[Z] {

  /**
   * Evaluates numerator and denominator (in that order) on the first request and stores
   * their values on the cache entry for use by the backward pass.
   */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
    val calc = cache.getOrElse(this, new DivideNodeCalculation[X, Y, Z])
    calc.output {
      val numerator = x_node.output(cache)
      calc.x_value = numerator
      val denominator = y_node.output(cache)
      calc.y_value = denominator
      divideShape.divide(numerator, denominator)
    }
  }

  /**
   * Records `derivative`; once one derivative per output request has arrived, their sum is
   * combined with the stored operand values and pushed to the numerator, then the denominator.
   */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
    val calc = cache.get(this)
    calc.backpropagation(derivative) { total =>
      val towardsX = divideShape.backX(calc.y_value, total)
      x_node.backpropagation(towardsX, cache)
      val towardsY = divideShape.backY(calc.x_value, calc.y_value, total)
      y_node.backpropagation(towardsY, cache)
    }
  }

  /** Rebuilds the node on top of the successors of both operands. */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Divide[X, Y, Z] =
    Divide[X, Y, Z](x_node.next(cache), y_node.next(cache))
}
/** Cache entry for a `Divide` node; it keeps both operand values for the backward pass. */
class DivideNodeCalculation[X, Y, Z] extends NodeCalculation[Z] {
  /** Numerator value written by `Divide.output`; default-initialised until then. */
  var x_value: X = _

  /** Denominator value written by `Divide.output`; default-initialised until then. */
  var y_value: Y = _
}
//-------------------------------------------------------------------------------------------------- | |
/** Type class for squaring an `X` into a `Z`, with the matching backward map. */
trait SquareShape[X, Z] {
  /** Forward operation. */
  def square(x: X): Z

  /** Derivative with respect to the input, given the input's value. */
  def backX(x: X, derivative: Z): X
}
/**
 * Evidence that the cache entry stored under a `Square[X, Z]` key is a
 * `SquareNodeCalculation[X, Z]`.
 *
 * The result type is written out explicitly: an implicit whose type is inferred is only
 * eligible after its definition point in Scala 2 and is rejected by Scala 3.
 */
implicit def squareNodeCalculationKeyValue[X, Z]: KeyValue[Square[X, Z], NodeCalculationTrait, SquareNodeCalculation[X, Z]] =
  new KeyValue[Square[X, Z], NodeCalculationTrait, SquareNodeCalculation[X, Z]] {
    // Unchecked downcast from the untyped cache value to the typed calculation.
    def cast(v: NodeCalculationTrait): SquareNodeCalculation[X, Z] =
      v.asInstanceOf[SquareNodeCalculation[X, Z]]
  }
/** Node computing `squareShape.square(x)` from the output of its single operand node. */
case class Square[X, Z](x_node: Node[X])
                       (implicit squareShape: SquareShape[X, Z], outputShape: Shape[Z])
  extends Node[Z] {

  /** Evaluates the operand on the first request and stores its value for the backward pass. */
  def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
    val calc = cache.getOrElse(this, new SquareNodeCalculation[X, Z])
    calc.output {
      val input = x_node.output(cache)
      calc.x_value = input
      squareShape.square(input)
    }
  }

  /**
   * Records `derivative`; once one derivative per output request has arrived, their sum is
   * combined with the stored input value and pushed to the operand.
   */
  def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
    val calc = cache.get(this)
    calc.backpropagation(derivative) { total =>
      val towardsX = squareShape.backX(calc.x_value, total)
      x_node.backpropagation(towardsX, cache)
    }
  }

  /** Rebuilds the node on top of the operand's successor. */
  def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Square[X, Z] =
    Square[X, Z](x_node.next(cache))
}
/** Cache entry for a `Square` node; it keeps the input value for the backward pass. */
class SquareNodeCalculation[X, Z] extends NodeCalculation[Z] {
  /** Input value written by `Square.output`; default-initialised until then. */
  var x_value: X = _
}
//-------------------------------------------------------------------------------------------------- | |
/** Produces the initial triple for a `Parameter`. */
trait Initializer[Z] {
  /** Returns a new triple; implementations may return a different one on each call. */
  def initialize(): (Z, Z, Z)
}
/**
 * Draws an initial `Double` triple using `scala.util.Random`.
 *
 * The first component lies in `[center - width, center + width)`, the second in
 * `[-derivativeWidth, derivativeWidth)` and the third in `[-deltaWidth, deltaWidth)`,
 * assuming non-negative widths.
 */
case class SingleRandomInitializer(center: Double, width: Double,
                                   derivativeWidth: Double, deltaWidth: Double)
  extends Initializer[Double] {
  type A = Double

  // One draw: a fresh random number scaled to twice the half-width, then shifted to the lower bound.
  private[this] def draw(halfWidth: Double, lowerBound: Double): Double =
    Random.nextDouble() * 2.0 * halfWidth + lowerBound

  /** Draws the three components in order; each consumes one random number. */
  def initialize(): (A, A, A) = {
    val initialValue = draw(width, center - width)
    val initialDerivative = draw(derivativeWidth, -derivativeWidth)
    val initialDelta = draw(deltaWidth, -deltaWidth)
    (initialValue, initialDerivative, initialDelta)
  }
}
//-------------------------------------------------------------------------------------------------- | |
/** Strategy that turns a parameter's current triple and a new derivative into the next triple. */
trait DeltaCalculator[Z] {
  /**
   * @param value          current parameter value
   * @param prevDerivative second component of the parameter's current triple
   * @param prevDelta      third component of the parameter's current triple
   * @param derivative     derivative accumulated in the current iteration
   * @return the triple for the successor parameter
   */
  def calcDelta(value: Z, prevDerivative: Z, prevDelta: Z, derivative: Z): (Z, Z, Z)
}
/** Fixed-rate update: the step is `derivative * eta` and is added to the current value. */
case class SingleSimpleDeltaCalculator(eta: Double) extends DeltaCalculator[Double] {
  type A = Double

  /** Returns `(value + step, derivative, step)`; `prevDerivative` and `prevDelta` are unused. */
  def calcDelta(value: A, prevDerivative: A, prevDelta: A, derivative: A): (A, A, A) = {
    val step = derivative * eta
    (value + step, derivative, step)
  }
}
/**
 * Adaptive update that rescales the previous step by a sigmoid-shaped gain of the ratio
 * between the new and the previous derivative.
 */
case class SingleBetterDeltaCalculator() extends DeltaCalculator[Double] {
  type A = Double

  /** Returns `(value + step, derivative, step)` with `step` chosen as described inline. */
  def calcDelta(value: A, prevDerivative: A, prevDelta: A, derivative: A): (A, A, A) = {
    val step =
      if (prevDelta * prevDerivative <= 0.0) {
        // Previous step and previous derivative disagree in sign (or one is zero):
        // fall back to a small fixed-rate step.
        derivative * 0.01
      } else {
        val ratio = derivative / prevDerivative
        // 1.09861229 is approximately ln 3, so the gain is about 0 at ratio 0 and about 1.3 at ratio 1.
        // https://www.wolframalpha.com/input/?i=plot+2+%2F+%281+%2B+exp%281.1+*+%281+-+3x%29%29%29+-+1%2F2+from+x%3D-5+to+%2B5
        val gain = 2.0 / (1.0 + Math.exp(1.09861229 * (1 - 3 * ratio))) - 0.5
        gain * prevDelta
      }
    (value + step, derivative, step)
  }
}
//-------------------------------------------------------------------------------------------------- | |
/** `Shape` instance for scalar `Double` values. */
implicit object SingleShape extends Shape[Double] {
  type Z = Double

  /** Adds the elements left to right starting from `0.0`; an empty input yields `0.0`. */
  def sum(it: Iterable[Z]): Z = it.foldLeft(0.0)(_ + _)
}
/** `Shape` instance for `Array[Double]`, summed element by element. */
implicit object ArrayShape extends Shape[Array[Double]] {
  type Z = Array[Double]

  /**
   * Element-wise sum, reduced left to right. The result has the length of the first
   * array. An empty input makes `reduceLeft` throw, and a single-element input is
   * returned as is, without copying.
   */
  def sum(it: Iterable[Z]): Z =
    it.reduceLeft { (acc, next) =>
      Array.tabulate[Double](acc.length)(i => acc(i) + next(i))
    }
}
//-------------------------------------------------------------------------------------------------- | |
/** Scalar addition; both backward maps pass the derivative through unchanged. */
implicit object SingleAddShape extends AddShape[Double, Double, Double] {
  type X = Double
  type Y = Double
  type Z = Double

  def add(x: X, y: Y): Z = x + y

  def backX(derivative: Z): X = derivative

  def backY(derivative: Z): Y = derivative
}
/** Adds one scalar to every element of an array. */
implicit object ArrayBroadcastAddShape extends AddShape[Array[Double], Double, Array[Double]] {
  type X = Array[Double]
  type Y = Double
  type Z = Array[Double]

  /** Returns a new array holding `element + y` for each element of `x`. */
  def add(x: X, y: Y): Z = x.map(element => element + y)

  /** The array operand receives the derivative array itself (no copy is made). */
  def backX(derivative: Z): X = derivative

  /** The scalar fed every element, so its derivative is the sum over all elements. */
  def backY(derivative: Z): Y = derivative.sum
}
/** Element-wise addition of two arrays. */
implicit object ArrayAddShape extends AddShape[Array[Double], Array[Double], Array[Double]] {
  type X = Array[Double]
  type Y = Array[Double]
  type Z = Array[Double]

  /** The result has the length of `x`; `y` is indexed over that same range. */
  def add(x: X, y: Y): Z = Array.tabulate[Double](x.length)(i => x(i) + y(i))

  /** Passes the derivative array through as is (no copy is made). */
  def backX(derivative: Z): X = derivative

  /** Passes the derivative array through as is (no copy is made). */
  def backY(derivative: Z): Y = derivative
}
//-------------------------------------------------------------------------------------------------- | |
/** Scalar multiplication; each operand's derivative is the other operand times the incoming derivative. */
implicit object SingleMultiplyShape extends MultiplyShape[Double, Double, Double] {
  type X = Double
  type Y = Double
  type Z = Double

  def multiply(x: X, y: Y): Z = x * y

  def backX(y: Y, derivative: Z): X = y * derivative

  def backY(x: X, derivative: Z): Y = x * derivative
}
/** Multiplies every element of an array by one scalar. */
implicit object ArrayBroadcastMultiplyShape extends MultiplyShape[Array[Double], Double, Array[Double]] {
  type X = Array[Double]
  type Y = Double
  type Z = Array[Double]

  /** Returns a new array holding `element * y` for each element of `x`. */
  def multiply(x: X, y: Y): Z = x.map(element => element * y)

  /** Scales each derivative element by the scalar. */
  def backX(y: Y, derivative: Z): X = derivative.map(d => y * d)

  /** Sums `x(i) * derivative(i)` over the indices of `x`. */
  def backY(x: X, derivative: Z): Y =
    x.indices.map(i => x(i) * derivative(i)).sum
}
//-------------------------------------------------------------------------------------------------- | |
/** Scalar division `x / y` with its two partial derivatives. */
implicit object SingleDivideShape extends DivideShape[Double, Double, Double] {
  type X = Double
  type Y = Double
  type Z = Double

  def divide(x: X, y: Y): Z = x / y

  /** d(x/y)/dx = 1/y, applied to the incoming derivative. */
  def backX(y: Y, derivative: Z): X = derivative / y

  /** d(x/y)/dy = -x/y^2, applied to the incoming derivative. */
  def backY(x: X, y: Y, derivative: Z): Y = - x * derivative / (y * y)
}
/** Divides every element of an array by one scalar. */
implicit object ArrayBroadcastDivideShape extends DivideShape[Array[Double], Double, Array[Double]] {
  type X = Array[Double]
  type Y = Double
  type Z = Array[Double]

  /** Returns a new array holding `element / y` for each element of `x`. */
  def divide(x: X, y: Y): Z = x.map(element => element / y)

  /** Divides each derivative element by the scalar. */
  def backX(y: Y, derivative: Z): X = derivative.map(d => d / y)

  /** Sums `-x(i) * derivative(i) / y^2` over the indices of `x`. */
  def backY(x: X, y: Y, derivative: Z): Y = {
    val ySquared = y * y
    x.indices.map(i => - x(i) * derivative(i) / ySquared).sum
  }
}
//-------------------------------------------------------------------------------------------------- | |
/** Scalar square `x * x` with derivative `2 * x`. */
implicit object SingleSquareShape extends SquareShape[Double, Double] {
  type X = Double
  type Z = Double

  def square(x: X): Z = x * x

  def backX(x: X, derivative: Z): X = 2.0 * x * derivative
}
/** Element-wise square of an array. */
implicit object ArraySquareShape extends SquareShape[Array[Double], Array[Double]] {
  type X = Array[Double]
  type Z = Array[Double]

  /** Returns a new array holding the square of each element of `x`. */
  def square(x: X): Z = x.map(element => element * element)

  /** Element-wise `2 * x(i) * derivative(i)` over the indices of `x`. */
  def backX(x: X, derivative: Z): X =
    Array.tabulate[Double](x.length)(i => 2.0 * x(i) * derivative(i))
}
//-------------------------------------------------------------------------------------------------- | |
} | |
//================================================================================================== | |
/** Small helpers shared by the gradient-descent code. */
object Util {

  /**
   * A map from keys of type `A` to values of type `B` that hands values back under a more
   * specific type, selected by an implicit [[CalculationCache.KeyValue]] for the key's type.
   *
   * The backing store is an immutable `Map` held in a `var`; no synchronization is performed.
   */
  class CalculationCache[A, B] {
    import CalculationCache._

    private[this] var cache: Map[A, B] = Map.empty

    /**
     * Returns the value stored under `key`, cast by `kv`.
     *
     * @throws NoSuchElementException if nothing is stored under `key`
     */
    def get[AE <: A, BE <: B](key: AE)(implicit kv: KeyValue[AE, B, BE]): BE =
      kv.cast(getBase(key))

    /**
     * Returns the value stored under `key` without any cast.
     *
     * @throws NoSuchElementException if nothing is stored under `key`
     */
    def getBase(key: A): B =
      // Single lookup; the default is only evaluated (and thrown) on a miss.
      cache.getOrElse(key, throw new NoSuchElementException("no cache entry for key: " + key))

    /**
     * Returns the value stored under `key`, cast by `kv`. On a miss, `p` is evaluated once,
     * its result is stored under `key` and returned.
     */
    def getOrElse[AE <: A, BE <: B](key: AE, p: => BE)(implicit kv: KeyValue[AE, B, BE]): BE =
      cache.get(key) match {
        case Some(existing) =>
          kv.cast(existing)
        case None =>
          val created = p
          cache = cache + (key -> created)
          created
      }
  }

  object CalculationCache {
    /** Witness that values stored under keys of type `AE` can be viewed as `BE`. */
    trait KeyValue[AE, B, BE] {
      /** Converts the stored value to the type associated with the key type. */
      def cast(v: B): BE
    }
  }
}
//================================================================================================== | |
import GradientDescent._; | |
import Util._; | |
/**
 * Runs `loopCount` update iterations and prints one line per iteration containing the
 * iteration index, the objective's output and the parameter's value before the update.
 *
 * @param initialNodes the objective node and the parameter it depends on
 * @param loopCount    number of iterations to run
 */
def main1(initialNodes: (Node[Double], Parameter[Double]), loopCount: Int): Unit = {
  // The fold threads the (objective, parameter) pair through the iterations. Its final
  // value is not needed, so it is discarded instead of being bound to an unused local.
  (0 until loopCount).foldLeft(initialNodes) { case ((objective, parameter), i) =>
    // Each iteration gets its own cache, so no state carries over between iterations.
    val cache = new CalculationCache[NodeTrait, NodeCalculationTrait]()
    val paramValue = parameter.value
    val result = objective.output(cache)
    println("%d %+7.4f %+7.4f".format(i, result, paramValue))
    // Seed derivative of -1.0: the delta calculators add `derivative * rate` to the value,
    // so a negative seed moves the parameter towards a smaller output.
    objective.backpropagation(-1.0, cache)
    (objective.next(cache), parameter.next(cache))
  }
  ()
}
/** Runs 10 iterations on the objective `Square(p)` with a randomly initialised parameter `p`. */
def main1_1(deltaCalculator: DeltaCalculator[Double]): Unit = {
  val parameter = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), deltaCalculator)
  val objective = Square(parameter)
  main1((objective, parameter), 10)
}
/** Objective 1 with the fixed-rate calculator at rate 0.2. */
def main1_1_1(): Unit =
  main1_1(SingleSimpleDeltaCalculator(0.2))
/** Objective 1 with the adaptive calculator. */
def main1_1_3(): Unit =
  main1_1(SingleBetterDeltaCalculator())
/** Runs 10 iterations on the objective `Square(p) * 1000.0` with a randomly initialised parameter `p`. */
def main1_2(deltaCalculator: DeltaCalculator[Double]): Unit = {
  val parameter = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), deltaCalculator)
  val objective = Multiply(Square(parameter), Constant(1000.0))
  main1((objective, parameter), 10)
}
/** Objective 2 with the fixed-rate calculator at rate 0.2. */
def main1_2_1(): Unit =
  main1_2(SingleSimpleDeltaCalculator(0.2))
/** Objective 2 with the fixed-rate calculator at rate 0.0002. */
def main1_2_2(): Unit =
  main1_2(SingleSimpleDeltaCalculator(0.0002))
/** Objective 2 with the adaptive calculator. */
def main1_2_3(): Unit =
  main1_2(SingleBetterDeltaCalculator())
/**
 * Runs 20 iterations on the objective `-1.0 / (Square(p) + 0.001)` with a randomly
 * initialised parameter `p`.
 */
def main1_3(deltaCalculator: DeltaCalculator[Double]): Unit = {
  val parameter = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), deltaCalculator)
  val objective = Divide(Constant(-1.0), Add(Square(parameter), Constant(0.001)))
  main1((objective, parameter), 20)
}
/** Objective 3 with the fixed-rate calculator at rate 0.2. */
def main1_3_1(): Unit =
  main1_3(SingleSimpleDeltaCalculator(0.2))
/** Objective 3 with the adaptive calculator. */
def main1_3_3(): Unit =
  main1_3(SingleBetterDeltaCalculator())
// Script entry: run every experiment, in this order.
// Fixed-rate runs on objectives 1 and 2.
main1_1_1()
main1_2_1()
main1_2_2()
// Adaptive runs on objectives 1 and 2.
main1_1_3()
main1_2_3()
// Objective 3: fixed-rate run, then adaptive run.
main1_3_1()
main1_3_3()
//================================================================================================== | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment