Created
March 13, 2023 05:54
-
-
Save yellowflash/897044d499a4f783a1f0c09a952f5357 to your computer and use it in GitHub Desktop.
Automatic differentiation by deferred, cached evaluation.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
object AutoDiff {
  // Automatic differentiation with dual numbers.
  //
  // When computing the derivative of f(x) we carry dx/dt alongside x (t being the independent
  // variable); when t == x the carried derivative is 1. When the input lives in R^k we carry
  // the partial derivative with respect to every coordinate, so the derivative is an element
  // of a vector space V, and the derivatives of all outputs together form the Jacobian.

  /** A value paired with its derivative.
    *
    * @param value the primal value
    * @param diff  the derivative of `value` with respect to the independent variable(s)
    * @tparam K the scalar type of the value
    * @tparam V the vector space the derivative lives in
    */
  case class Dual[K, V](value: K, diff: V)

  /** A vector space V over the field K: a zero vector, vector addition and scaling by K.
    *
    * Fixing K = Double would be sufficient for everything in this file.
    */
  trait VectorSpace[K, V]:
    def zero: V
    extension (first: V)
      def +(another: V): V
      def *(scale: K): V

  /** The single-component vector space: doubles over the doubles. */
  given doubleVector: VectorSpace[Double, Double] with
    def zero: Double = 0.0
    extension (first: Double)
      def +(another: Double): Double = first + another
      def *(scale: Double): Double = first * scale

  /** A `Fractional` that also supports the transcendental functions we know how to differentiate. */
  trait ExtendedNumeric[N] extends Fractional[N]:
    def sin(x: N): N
    def cos(x: N): N
    def log(x: N): N
    def exp(x: N): N

  type Num[T] = ExtendedNumeric[T]

  /** `Double` as an [[ExtendedNumeric]], delegating the transcendental functions to `scala.math`. */
  given doubleIsExtendedNumeric: Num[Double] with Numeric.DoubleIsFractional with Ordering.Double.IeeeOrdering with
    def sin(x: Double) = math.sin(x)
    def cos(x: Double) = math.cos(x)
    def log(x: Double) = math.log(x)
    def exp(x: Double) = math.exp(x)

  import Fractional.Implicits.{given}

  /** Dual numbers form an [[ExtendedNumeric]] whenever K does and V is a vector space over K.
    *
    * Every operation computes the value with K's arithmetic and propagates the derivative with
    * the usual calculus rules (sum, product, quotient and chain rule).
    */
  given numericDual[K, V](using n: Num[K], v: VectorSpace[K, V]): Num[Dual[K, V]] with
    def minusOne: K = n.negate(n.one)
    def plus(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] = Dual(x.value + y.value, x.diff + y.diff)
    def minus(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] = Dual(x.value - y.value, x.diff + (y.diff * minusOne))
    // Product rule: (xy)' = x'y + y'x.
    def times(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] = Dual(x.value * y.value, x.diff * y.value + y.diff * x.value)
    // Negation flips both the value and the derivative; abs/sign inherited from Numeric rely on this.
    def negate(x: Dual[K, V]): Dual[K, V] = Dual(n.negate(x.value), x.diff * minusOne)
    // Quotient rule: (x/y)' = (x'y - y'x) / y^2.
    def div(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] =
      Dual(x.value / y.value, (x.diff * y.value + y.diff * minusOne * x.value) * (n.one / (y.value * y.value)))
    // Constants have a zero derivative.
    def fromInt(x: Int): Dual[K, V] = Dual(n.fromInt(x), v.zero)
    def parseString(str: String): Option[Dual[K, V]] = n.parseString(str).map(Dual(_, v.zero))
    // Conversions and ordering look at the value only; the derivative is dropped.
    def toInt(x: Dual[K, V]): Int = n.toInt(x.value)
    def toLong(x: Dual[K, V]): Long = n.toLong(x.value)
    def toFloat(x: Dual[K, V]): Float = n.toFloat(x.value)
    def toDouble(x: Dual[K, V]): Double = n.toDouble(x.value)
    def compare(left: Dual[K, V], right: Dual[K, V]): Int = n.compare(left.value, right.value)
    // Chain rule for each elementary function.
    def sin(x: Dual[K, V]): Dual[K, V] = Dual(n.sin(x.value), x.diff * n.cos(x.value))
    def cos(x: Dual[K, V]): Dual[K, V] = Dual(n.cos(x.value), x.diff * (minusOne * n.sin(x.value)))
    def log(x: Dual[K, V]): Dual[K, V] = Dual(n.log(x.value), x.diff * (n.one / x.value))
    def exp(x: Dual[K, V]): Dual[K, V] =
      val e = n.exp(x.value)
      Dual(e, x.diff * e)

  // With the instances above a single derivative is already `f(Dual(x, 1.0)).diff`.
  // To handle many independent variables efficiently, the Jacobian is not computed eagerly:
  // derivatives are recorded as a symbolic Delta expression and materialised later by `eval`.
  sealed trait Delta
  /** The zero vector. */
  case object Zero extends Delta
  /** The i-th standard basis vector. */
  case class OneHot(i: Int) extends Delta
  /** `vector` multiplied by `scale`. */
  case class Scale(scale: Double, vector: Delta) extends Delta
  /** The sum of two vectors. */
  case class Add(left: Delta, right: Delta) extends Delta
  /** A reference to a vector bound by an enclosing [[Let]]. */
  case class Var(id: Long) extends Delta
  /** Binds `value` to `id` while evaluating `block`. */
  case class Let(id: Long, value: Delta, block: Delta) extends Delta

  // Delta is a vector space whose operations build expression nodes instead of doing arithmetic.
  given VectorSpace[Double, Delta] with
    def zero: Delta = Zero
    extension (first: Delta)
      def +(another: Delta): Delta = Add(first, another)
      def *(scale: Double): Delta = Scale(scale, first)

  // Numeric.fromInt is never applied implicitly to literals, so this conversion lets plain
  // Double constants (e.g. `x * 2.0`) be used as duals with a zero derivative.
  given Conversion[Double, Dual[Double, Delta]] with
    override def apply(x: Double): Dual[Double, Delta] = Dual(x, Zero)

  type DeltaMap = Map[Long, Array[Double]]

  /** Materialises a [[Delta]] expression into a dense vector.
    *
    * The expression is a DAG: every `times`/`div`/... references its operands' derivatives more
    * than once. Each node is therefore evaluated at most once per binding scope, so the cost is
    * linear in the number of distinct nodes rather than exponential in the nesting depth.
    *
    * @param dimensions length of the resulting vector (number of independent variables)
    * @param delta      the expression to evaluate
    * @param mappings   values of the [[Var]]s that are free in `delta`
    * @return a vector of length `dimensions`
    */
  def eval(dimensions: Int, delta: Delta, mappings: DeltaMap): Array[Double] =
    // Keyed by reference identity: a structural (case-class) hashCode would itself walk the DAG.
    val memo = new java.util.IdentityHashMap[Delta, Array[Double]]()

    def go(node: Delta): Array[Double] =
      Option(memo.get(node)).getOrElse {
        // Results are never mutated after creation, so sharing cached arrays is safe.
        val result = node match
          case Zero => Array.fill(dimensions)(0.0)
          case OneHot(i) =>
            val basis = Array.fill(dimensions)(0.0)
            basis(i) = 1.0
            basis
          case Scale(scale, vector) => go(vector).map(_ * scale)
          case Add(left, right) =>
            val l = go(left)
            val r = go(right)
            val sum = Array.copyOf[Double](l, dimensions)
            for i <- 0 until dimensions do sum(i) += r(i)
            sum
          case Var(id) => mappings(id)
          // The block sees different bindings, so it is evaluated with its own cache.
          case Let(id, value, block) => eval(dimensions, block, mappings + (id -> go(value)))
        memo.put(node, result)
        result
      }

    go(delta)

  /** Jacobian of a vector-valued function.
    *
    * @param f the function, written against dual numbers
    * @return a function mapping a point to its Jacobian, one row per output of `f`
    *         and one column per input coordinate
    */
  def diffV(f: Array[Dual[Double, Delta]] => Array[Dual[Double, Delta]]): Array[Double] => Array[Array[Double]] = { point =>
    // Seed coordinate i with the i-th basis vector as its derivative.
    val seeded = point.zipWithIndex.map((x, i) => Dual(x, OneHot(i)))
    f(seeded).map(output => eval(point.length, output.diff, Map.empty))
  }

  /** Derivative of a scalar function, returned as a function of the evaluation point. */
  def diff(f: Dual[Double, Delta] => Dual[Double, Delta]): Double => Double =
    x => eval(1, f(Dual(x, OneHot(0))).diff, Map.empty)(0)

  def main(args: Array[String]): Unit = {
    import scala.language.implicitConversions
    // d/dx (x^2 + 2x + 3) at x = 2.
    println(diff(x => x * x + x * 2.0 + 3.0)(2))
    // Jacobian of (x^3 + 2y, xy, x + y) at (2, 1), printed one row per output.
    diffV { case Array(x, y) => Array(x * x * x + y * 2.0, x * y, x + y) }(Array(2.0, 1.0))
      .foreach(row => println(row.map(v => f"$v%10.2f").mkString))
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment