// Gist by @suzuki-navi, created May 20, 2020 15:01
import scala.util.Random;
import Ordering.Double.IeeeOrdering;
//==================================================================================================
object GradientDescent {
//--------------------------------------------------------------------------------------------------
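// Core abstraction: a Node[Z] is one vertex of an immutable computation graph.
//   output           evaluates the node, memoizing per-pass results in a CalculationCache;
//   backpropagation  pushes a derivative from downstream back toward the Parameter leaves;
//   next             rebuilds the graph with updated Parameter values for the next training step.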
import Util._;
import Util.CalculationCache.KeyValue;
sealed trait NodeTrait;
implicit def nodeCalculationKeyValue[Z]: KeyValue[Node[Z], NodeCalculationTrait, NodeCalculation[Z]] =
new KeyValue[Node[Z], NodeCalculationTrait, NodeCalculation[Z]]() {
def cast(v: NodeCalculationTrait) = v.asInstanceOf[NodeCalculation[Z]];
}
trait Node[Z] extends NodeTrait {
def output: Z = {
val cache = new CalculationCache[NodeTrait, NodeCalculationTrait]();
output(cache);
}
def lastOutput(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
val cacheElem = cache.getOrElse(this, {throw new Exception()});
cacheElem.lastOutput;
}
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z;
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit;
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Node[Z];
}
sealed trait NodeCalculationTrait;
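// Per-pass mutable state attached to a node through the CalculationCache.
// output(p) memoizes the forward value and counts how many consumers requested it (the fan-out);
// backpropagation collects one derivative per consumer and only forwards the result, summed with
// Shape.sum, once every consumer has delivered its contribution.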
trait NodeCalculation[Z] extends NodeCalculationTrait {
private[this] var _outputCounter: Int = 0;
private[this] var _output: Z = _;
private[this] var _backpropagationCounter: Int = 0;
private[this] var _backpropagationValues: List[Z] = Nil;
def lastOutput: Z = {
if (_outputCounter == 0) {
throw new Exception();
}
_output;
}
def output(p: => Z): Z = {
if (_outputCounter == 0) {
_output = p;
}
_outputCounter += 1;
_output;
}
def backpropagation(derivative: Z)(p: Z => Unit)(implicit shape: Shape[Z]): Unit = {
_backpropagationCounter += 1;
_backpropagationValues = derivative :: _backpropagationValues;
if (_backpropagationCounter == _outputCounter) {
p(shape.sum(_backpropagationValues));
}
}
}
trait Shape[Z] {
def sum(it: Iterable[Z]): Z;
}
//--------------------------------------------------------------------------------------------------
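// Constant: a leaf with a fixed value; incoming derivatives are discarded and next() returns the node itself.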
case class Constant[Z](value: Z)
(implicit shape: Shape[Z])
extends Node[Z] {
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
value;
}
override def lastOutput(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
value;
}
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
// nothing
}
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Constant[Z] = this;
}
//--------------------------------------------------------------------------------------------------
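// Parameter: a trainable leaf. `param` is the triple (current value, previous derivative, previous delta).
// backpropagation stores the accumulated derivative in the per-pass cache, and next() asks the
// DeltaCalculator for an updated triple, yielding the Parameter used in the next iteration
// (cached, so every reference to this Parameter resolves to the same new node).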
implicit def parameterNodeCalculationKeyValue[Z]: KeyValue[Parameter[Z], NodeCalculationTrait, ParameterNodeCalculation[Z]] =
new KeyValue[Parameter[Z], NodeCalculationTrait, ParameterNodeCalculation[Z]]() {
def cast(v: NodeCalculationTrait) = v.asInstanceOf[ParameterNodeCalculation[Z]];
}
case class Parameter[Z](param: (Z, Z, Z), deltaCalculator: DeltaCalculator[Z], id: Int)
(implicit shape: Shape[Z])
extends Node[Z] {
def value = param._1;
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
val cacheElem = cache.getOrElse(this, new ParameterNodeCalculation[Z]);
cacheElem.output {
param._1;
}
}
override def lastOutput(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
param._1;
}
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
val cacheElem = cache.get(this);
cacheElem.backpropagation(derivative) { derivative =>
cacheElem.derivative = derivative;
}
}
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Parameter[Z] = {
val cacheElem = cache.get(this);
cacheElem.next match {
case None =>
val n = Parameter(deltaCalculator.calcDelta(param._1, param._2, param._3, cacheElem.derivative),
deltaCalculator, id);
cacheElem.next = Some(n);
n;
case Some(n) =>
n;
}
}
}
object Parameter {
private[this] var _counter: Int = 0;
def create[Z](param: (Z, Z, Z), deltaCalculator: DeltaCalculator[Z])
(implicit shape: Shape[Z]): Parameter[Z] = {
_counter += 1;
val id = _counter;
Parameter(param, deltaCalculator, id);
}
def create[Z](initializer: Initializer[Z], deltaCalculator: DeltaCalculator[Z])
(implicit shape: Shape[Z]): Parameter[Z] = {
create(initializer.initialize(), deltaCalculator);
}
}
class ParameterNodeCalculation[Z] extends NodeCalculation[Z] {
var derivative: Z = _;
var next: Option[Parameter[Z]] = None;
}
//--------------------------------------------------------------------------------------------------
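// Operation nodes. Each case class (the graph vertex) is paired with a *Shape type class that
// supplies the forward operation and the pullback (partial derivative) for each operand.
// Add is defined first; Multiply, Divide and Square below follow the same pattern.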
trait AddShape[X, Y, Z] {
def add(x: X, y: Y): Z;
def backX(derivative: Z): X;
def backY(derivative: Z): Y;
}
implicit def addNodeCalculationKeyValue[X, Y, Z]: KeyValue[Add[X, Y, Z], NodeCalculationTrait, AddNodeCalculation[X, Y, Z]] =
new KeyValue[Add[X, Y, Z], NodeCalculationTrait, AddNodeCalculation[X, Y, Z]]() {
def cast(v: NodeCalculationTrait) = v.asInstanceOf[AddNodeCalculation[X, Y, Z]];
}
case class Add[X, Y, Z](x_node: Node[X], y_node: Node[Y])
(implicit addShape: AddShape[X, Y, Z], outputShape: Shape[Z])
extends Node[Z] {
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
val cacheElem = cache.getOrElse(this, new AddNodeCalculation[X, Y, Z]);
cacheElem.output {
val x_value = x_node.output(cache);
val y_value = y_node.output(cache);
addShape.add(x_value, y_value);
}
}
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
val cacheElem = cache.get(this);
cacheElem.backpropagation(derivative) { derivative =>
x_node.backpropagation(addShape.backX(derivative), cache);
y_node.backpropagation(addShape.backY(derivative), cache);
}
}
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Add[X, Y, Z] = {
Add(x_node.next(cache), y_node.next(cache));
}
}
class AddNodeCalculation[X, Y, Z] extends NodeCalculation[Z] {
}
//--------------------------------------------------------------------------------------------------
trait MultiplyShape[X, Y, Z] {
def multiply(x: X, y: Y): Z;
def backX(y: Y, derivative: Z): X;
def backY(x: X, derivative: Z): Y;
}
implicit def multiplyNodeCalculationKeyValue[X, Y, Z]: KeyValue[Multiply[X, Y, Z], NodeCalculationTrait, MultiplyNodeCalculation[X, Y, Z]] =
new KeyValue[Multiply[X, Y, Z], NodeCalculationTrait, MultiplyNodeCalculation[X, Y, Z]]() {
def cast(v: NodeCalculationTrait) = v.asInstanceOf[MultiplyNodeCalculation[X, Y, Z]];
}
case class Multiply[X, Y, Z](x_node: Node[X], y_node: Node[Y])
(implicit multiplyShape: MultiplyShape[X, Y, Z], outputShape: Shape[Z])
extends Node[Z] {
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
val cacheElem = cache.getOrElse(this, new MultiplyNodeCalculation[X, Y, Z]);
cacheElem.output {
cacheElem.x_value = x_node.output(cache);
cacheElem.y_value = y_node.output(cache);
multiplyShape.multiply(cacheElem.x_value, cacheElem.y_value);
}
}
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
val cacheElem = cache.get(this);
cacheElem.backpropagation(derivative) { derivative =>
x_node.backpropagation(multiplyShape.backX(cacheElem.y_value, derivative), cache);
y_node.backpropagation(multiplyShape.backY(cacheElem.x_value, derivative), cache);
}
}
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Multiply[X, Y, Z] = {
Multiply(x_node.next(cache), y_node.next(cache));
}
}
class MultiplyNodeCalculation[X, Y, Z] extends NodeCalculation[Z] {
var x_value: X = _;
var y_value: Y = _;
}
//--------------------------------------------------------------------------------------------------
trait DivideShape[X, Y, Z] {
def divide(x: X, y: Y): Z;
def backX(y: Y, derivative: Z): X;
def backY(x: X, y: Y, derivative: Z): Y;
}
implicit def divideNodeCalculationKeyValue[X, Y, Z]: KeyValue[Divide[X, Y, Z], NodeCalculationTrait, DivideNodeCalculation[X, Y, Z]] =
new KeyValue[Divide[X, Y, Z], NodeCalculationTrait, DivideNodeCalculation[X, Y, Z]]() {
def cast(v: NodeCalculationTrait) = v.asInstanceOf[DivideNodeCalculation[X, Y, Z]];
}
case class Divide[X, Y, Z](x_node: Node[X], y_node: Node[Y])
(implicit divideShape: DivideShape[X, Y, Z], outputShape: Shape[Z])
extends Node[Z] {
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
val cacheElem = cache.getOrElse(this, new DivideNodeCalculation[X, Y, Z]);
cacheElem.output {
cacheElem.x_value = x_node.output(cache);
cacheElem.y_value = y_node.output(cache);
divideShape.divide(cacheElem.x_value, cacheElem.y_value);
}
}
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
val cacheElem = cache.get(this);
cacheElem.backpropagation(derivative) { derivative =>
x_node.backpropagation(divideShape.backX(cacheElem.y_value, derivative), cache);
y_node.backpropagation(divideShape.backY(cacheElem.x_value, cacheElem.y_value, derivative), cache);
}
}
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Divide[X, Y, Z] = {
Divide(x_node.next(cache), y_node.next(cache));
}
}
class DivideNodeCalculation[X, Y, Z] extends NodeCalculation[Z] {
var x_value: X = _;
var y_value: Y = _;
}
//--------------------------------------------------------------------------------------------------
trait SquareShape[X, Z] {
def square(x: X): Z;
def backX(x: X, derivative: Z): X;
}
implicit def squareNodeCalculationKeyValue[X, Z]: KeyValue[Square[X, Z], NodeCalculationTrait, SquareNodeCalculation[X, Z]] =
new KeyValue[Square[X, Z], NodeCalculationTrait, SquareNodeCalculation[X, Z]]() {
def cast(v: NodeCalculationTrait) = v.asInstanceOf[SquareNodeCalculation[X, Z]];
}
case class Square[X, Z](x_node: Node[X])
(implicit squareShape: SquareShape[X, Z], outputShape: Shape[Z])
extends Node[Z] {
def output(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Z = {
val cacheElem = cache.getOrElse(this, new SquareNodeCalculation[X, Z]);
cacheElem.output {
cacheElem.x_value = x_node.output(cache);
squareShape.square(cacheElem.x_value);
}
}
def backpropagation(derivative: Z, cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Unit = {
val cacheElem = cache.get(this);
cacheElem.backpropagation(derivative) { derivative =>
x_node.backpropagation(squareShape.backX(cacheElem.x_value, derivative), cache);
}
}
def next(cache: CalculationCache[NodeTrait, NodeCalculationTrait]): Square[X, Z] = {
Square(x_node.next(cache));
}
}
class SquareNodeCalculation[X, Z] extends NodeCalculation[Z] {
var x_value: X = _;
}
//--------------------------------------------------------------------------------------------------
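// Initializer: supplies the starting (value, derivative, delta) triple for a Parameter.
// SingleRandomInitializer draws the value uniformly from [center - width, center + width], and the
// derivative and delta from [-derivativeWidth, +derivativeWidth] and [-deltaWidth, +deltaWidth].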
trait Initializer[Z] {
def initialize(): (Z, Z, Z);
}
case class SingleRandomInitializer(center: Double, width: Double,
derivativeWidth: Double, deltaWidth: Double) extends Initializer[Double] {
type A = Double;
private[this] val start = center - width;
private[this] val derivativeStart = - derivativeWidth;
private[this] val deltaStart = - deltaWidth;
def initialize(): (A, A, A) = {
(
Random.nextDouble() * 2.0 * width + start,
Random.nextDouble() * 2.0 * derivativeWidth + derivativeStart,
Random.nextDouble() * 2.0 * deltaWidth + deltaStart,
);
}
}
//--------------------------------------------------------------------------------------------------
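// DeltaCalculator: the parameter update rule applied by Parameter.next().
// SingleSimpleDeltaCalculator is a fixed-step rule: newValue = value + derivative * eta
//   (e.g. value = 0.5, incoming derivative = -1.0, eta = 0.2  =>  delta = -0.2, newValue = 0.3).
// SingleBetterDeltaCalculator adapts the step: if the previous delta and previous derivative
// disagree in sign (or either is zero) it restarts with a small step of derivative * 0.01;
// otherwise it rescales the previous delta by a factor in (-0.5, 1.5) driven by the ratio
// derivative / prevDerivative (1.09861229 is ln 3, so a ratio of 1 gives a factor of 1.3
// and a ratio of 0 gives a factor of 0).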
trait DeltaCalculator[Z] {
def calcDelta(value: Z, prevDerivative: Z, prevDelta: Z, derivative: Z): (Z, Z, Z);
}
case class SingleSimpleDeltaCalculator(eta: Double) extends DeltaCalculator[Double] {
type A = Double;
def calcDelta(value: A, prevDerivative: A, prevDelta: A, derivative: A): (A, A, A) = {
val nextDelta = derivative * eta;
val nextValue = value + nextDelta;
(nextValue, derivative, nextDelta);
}
}
case class SingleBetterDeltaCalculator() extends DeltaCalculator[Double] {
type A = Double;
def calcDelta(value: A, prevDerivative: A, prevDelta: A, derivative: A): (A, A, A) = {
val nextDelta = if (prevDelta * prevDerivative <= 0.0) {
derivative * 0.01;
} else {
val a1 = derivative / prevDerivative;
val a2 = 2.0 / (1.0 + Math.exp(1.09861229 * (1 - 3 * a1))) - 0.5;
// https://www.wolframalpha.com/input/?i=plot+2+%2F+%281+%2B+exp%281.1+*+%281+-+3x%29%29%29+-+1%2F2+from+x%3D-5+to+%2B5
val a3 = a2 * prevDelta;
a3;
}
val nextValue = value + nextDelta;
(nextValue, derivative, nextDelta);
}
}
//--------------------------------------------------------------------------------------------------
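// Concrete type-class instances: plain Double arithmetic plus Array[Double] variants,
// including broadcast versions that pair an Array[Double] with a single Double.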
implicit object SingleShape extends Shape[Double] {
type Z = Double;
def sum(it: Iterable[Z]): Z = {
it.foldLeft[Z](0.0) { (s, e) => s + e }
}
}
implicit object ArrayShape extends Shape[Array[Double]] {
type Z = Array[Double];
def sum(it: Iterable[Z]): Z = {
it.reduceLeft { (a, b) => (0 until a.size).toArray.map(i => a(i) + b(i)) }
}
}
//--------------------------------------------------------------------------------------------------
implicit object SingleAddShape extends AddShape[Double, Double, Double] {
type X = Double;
type Y = Double;
type Z = Double;
def add(x: X, y: Y): Z = x + y;
def backX(derivative: Z): X = derivative;
def backY(derivative: Z): Y = derivative;
}
implicit object ArrayBroadcastAddShape extends AddShape[Array[Double], Double, Array[Double]] {
type X = Array[Double];
type Y = Double;
type Z = Array[Double];
def add(x: X, y: Y): Z = x.map(_ + y);
def backX(derivative: Z): X = derivative;
def backY(derivative: Z): Y = derivative.sum;
}
implicit object ArrayAddShape extends AddShape[Array[Double], Array[Double], Array[Double]] {
type X = Array[Double];
type Y = Array[Double];
type Z = Array[Double];
def add(x: X, y: Y): Z = (0 until x.size).toArray.map(i => x(i) + y(i));
def backX(derivative: Z): X = derivative;
def backY(derivative: Z): Y = derivative;
}
//--------------------------------------------------------------------------------------------------
implicit object SingleMultiplyShape extends MultiplyShape[Double, Double, Double] {
type X = Double;
type Y = Double;
type Z = Double;
def multiply(x: X, y: Y): Z = x * y;
def backX(y: Y, derivative: Z): X = y * derivative;
def backY(x: X, derivative: Z): Y = x * derivative;
}
implicit object ArrayBroadcastMultiplyShape extends MultiplyShape[Array[Double], Double, Array[Double]] {
type X = Array[Double];
type Y = Double;
type Z = Array[Double];
def multiply(x: X, y: Y): Z = x.map(x => x * y);
def backX(y: Y, derivative: Z): X = derivative.map(z => y * z);
def backY(x: X, derivative: Z): Y = (0 until x.size).map(i => x(i) * derivative(i)).sum;
}
//--------------------------------------------------------------------------------------------------
implicit object SingleDivideShape extends DivideShape[Double, Double, Double] {
type X = Double;
type Y = Double;
type Z = Double;
def divide(x: X, y: Y): Z = x / y;
def backX(y: Y, derivative: Z): X = derivative / y;
def backY(x: X, y: Y, derivative: Z): Y = - x * derivative / (y * y);
}
implicit object ArrayBroadcastDivideShape extends DivideShape[Array[Double], Double, Array[Double]] {
type X = Array[Double];
type Y = Double;
type Z = Array[Double];
def divide(x: X, y: Y): Z = x.map(x => x / y);
def backX(y: Y, derivative: Z): X = derivative.map(d => d / y);
def backY(x: X, y: Y, derivative: Z): Y = (0 until x.size).map(i => - x(i) * derivative(i) / (y * y)).sum;
}
//--------------------------------------------------------------------------------------------------
implicit object SingleSquareShape extends SquareShape[Double, Double] {
type X = Double;
type Z = Double;
def square(x: X): Z = x * x;
def backX(x: X, derivative: Z): X = 2.0 * x * derivative;
}
implicit object ArraySquareShape extends SquareShape[Array[Double], Array[Double]] {
type X = Array[Double];
type Z = Array[Double];
def square(x: X): Z = x.map(x => x * x);
def backX(x: X, derivative: Z): X = (0 until x.size).toArray.map(i => 2.0 * x(i) * derivative(i));
}
//--------------------------------------------------------------------------------------------------
}
//==================================================================================================
object Util {
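// CalculationCache maps graph nodes to their per-pass calculation state. KeyValue is the small
// type class that recovers the concrete calculation type (BE) for a given node type (AE) from
// the stored base type (B).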
class CalculationCache[A, B] {
import CalculationCache._;
private[this] var cache: Map[A, B] = Map.empty;
def get[AE <: A, BE <: B](key: AE)(implicit kv: KeyValue[AE, B, BE]): BE = {
if (cache.contains(key)) {
kv.cast(cache.get(key).get);
} else {
throw new NoSuchElementException();
}
}
def getBase(key: A): B = {
if (cache.contains(key)) {
cache.get(key).get;
} else {
throw new NoSuchElementException();
}
}
def getOrElse[AE <: A, BE <: B](key: AE, p: => BE)(implicit kv: KeyValue[AE, B, BE]): BE = {
if (cache.contains(key)) {
kv.cast(cache.get(key).get);
} else {
val v = p;
cache = cache + (key -> v);
v;
}
}
}
object CalculationCache {
trait KeyValue[AE, B, BE] {
def cast(v: B): BE;
}
}
}
//==================================================================================================
import GradientDescent._;
import Util._;
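// Driver: each iteration builds a fresh cache, runs the forward pass, prints the step index,
// output and parameter value, backpropagates a top derivative of -1.0 (so parameters step in the
// direction that decreases the output), and materializes the updated graph via next().
// How the gist is meant to be executed is not stated; running it as a Scala 2.13 script
// (scala script runner or Ammonite) is an assumption based on the top-level defs and statements.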
def main1(initialNodes: (Node[Double], Parameter[Double]), loopCount: Int): Unit = {
val finalNode = (0 until loopCount).foldLeft(initialNodes) { (nodes, i) =>
val cache = new CalculationCache[NodeTrait, NodeCalculationTrait]();
val param1 = nodes._2.value;
val output = nodes._1.output(cache);
println("%d %+7.4f %+7.4f".format(i, output, param1));
nodes._1.backpropagation(-1.0, cache);
(nodes._1.next(cache), nodes._2.next(cache));
}
}
def main1_1(deltaCalculator: DeltaCalculator[Double]): Unit = {
val initialNodes = {
val param1 = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), deltaCalculator);
val node1 = Square(param1);
(node1, param1);
}
main1(initialNodes, 10);
}
def main1_1_1(): Unit = {
main1_1(SingleSimpleDeltaCalculator(0.2));
}
def main1_1_3(): Unit = {
main1_1(SingleBetterDeltaCalculator());
}
def main1_2(deltaCalculator: DeltaCalculator[Double]): Unit = {
val initialNodes = {
val param1 = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), deltaCalculator);
val node1 = Multiply(Square(param1), Constant(1000.0));
(node1, param1);
}
main1(initialNodes, 10);
}
def main1_2_1(): Unit = {
main1_2(SingleSimpleDeltaCalculator(0.2));
}
def main1_2_2(): Unit = {
main1_2(SingleSimpleDeltaCalculator(0.0002));
}
def main1_2_3(): Unit = {
main1_2(SingleBetterDeltaCalculator());
}
def main1_3(deltaCalculator: DeltaCalculator[Double]): Unit = {
val initialNodes = {
val param1 = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), deltaCalculator);
val node1 = Divide(Constant(-1.0), Add(Square(param1), Constant(0.001)));
(node1, param1);
}
main1(initialNodes, 20);
}
def main1_3_1(): Unit = {
main1_3(SingleSimpleDeltaCalculator(0.2));
}
def main1_3_3(): Unit = {
main1_3(SingleBetterDeltaCalculator());
}
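// Illustrative sketches, not part of the original gist and not invoked below; the names
// exampleQuadratic / exampleArrayForward and their constants are hypothetical.
def exampleQuadratic(): Unit = {
  // Minimize (2*x + 3)^2 with the same driver; the gradient steps move param1 toward x = -1.5.
  val param1 = Parameter.create(SingleRandomInitializer(0.0, 1.0, 1.0, 1.0), SingleBetterDeltaCalculator());
  val node1 = Square(Add(Multiply(param1, Constant(2.0)), Constant(3.0)));
  main1((node1, param1), 15);
}
def exampleArrayForward(): Unit = {
  // Forward-only use of the Array[Double] broadcast shapes: Array[Double] * Double.
  val a = Parameter.create(SingleRandomInitializer(1.0, 0.5, 1.0, 1.0), SingleBetterDeltaCalculator());
  val node = Multiply(Constant(Array(1.0, 2.0, 3.0)), a);
  println(node.output.mkString(" "));
}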
main1_1_1();
main1_2_1();
main1_2_2();
main1_1_3();
main1_2_3();
main1_3_1();
main1_3_3();
//==================================================================================================