@yellowflash
Created March 13, 2023 05:54
Automatic differentiation, by deferred, cached evaluation.
object AutoDiff {
  // When we compute the derivative of, say, f(x) we would like to keep track of dx/dt (t being the independent variable);
  // in case t == x then diff = 1.
  // Suppose x is not a scalar, i.e. it lives in R^k; then we would like to keep track of the partial derivative along each co-ordinate.
  // Hence V is a vector space (which is just enough for our case); this becomes the final Jacobian.
  case class Dual[K, V](value: K, diff: V)
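  // For instance (an illustrative note, not part of the original gist): with K = Double and V = Double,
  // Dual(3.0, 1.0) is the value x = 3 seeded with dx/dx = 1, and pushing it through f(x) = x * x
  // using the arithmetic defined below yields Dual(9.0, 6.0), i.e. f(3) together with f'(3).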
  // We define a vector space V over the field K; we could drop the generalization by fixing K == Double and it would mostly work too.
  trait VectorSpace[K, V]:
    def zero: V
    extension (first: V)
      def +(another: V): V
      def *(scale: K): V
  // Single-component trivial vector space.
  given doubleVector: VectorSpace[Double, Double] with
    def zero = 0
    extension (first: Double)
      def +(another: Double) = first + another
      def *(scale: Double) = first * scale
  // We define Dual to be a number, or rather something substitutable for a number.
  // We need to generalize Numeric to cover the transcendental functions too.
  trait ExtendedNumeric[N] extends Fractional[N]:
    def sin(x: N): N
    def cos(x: N): N
    def log(x: N): N
    def exp(x: N): N
  type Num[T] = ExtendedNumeric[T]
  // We just generalize Doubles here.
  given doubleIsExtendedNumeric: Num[Double] with Numeric.DoubleIsFractional with Ordering.Double.IeeeOrdering with
    def sin(x: Double) = math.sin(x)
    def cos(x: Double) = math.cos(x)
    def log(x: Double) = math.log(x)
    def exp(x: Double) = math.exp(x)
  import Fractional.Implicits.{given}
  // Dual is an extended number in all its glory, given that K is an extended number and V is a vector space over K.
  given numericDual[K, V](using n: Num[K], v: VectorSpace[K, V]): Num[Dual[K, V]] with
    def minusOne: K = n.negate(n.one)
    def plus(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] = Dual(x.value + y.value, x.diff + y.diff)
    def minus(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] = Dual(x.value - y.value, x.diff + (y.diff * minusOne))
    def times(x: Dual[K, V], y: Dual[K, V]): Dual[K, V] = Dual(x.value * y.value, x.diff * y.value + y.diff * x.value)
    def negate(x: Dual[K, V]): Dual[K, V] = Dual(n.negate(x.value), x.diff * minusOne)
    def div(x: Dual[K, V], y: Dual[K, V]) = Dual(x.value / y.value, (x.diff * y.value + y.diff * minusOne * x.value) * (n.one / (y.value * y.value)))
    def fromInt(x: Int): Dual[K, V] = Dual(n.fromInt(x), v.zero)
    def parseString(str: String): Option[Dual[K, V]] = n.parseString(str).map(Dual(_, v.zero))
    def toInt(x: Dual[K, V]): Int = n.toInt(x.value)
    def toLong(x: Dual[K, V]): Long = n.toLong(x.value)
    def toFloat(x: Dual[K, V]): Float = n.toFloat(x.value)
    def toDouble(x: Dual[K, V]): Double = n.toDouble(x.value)
    def compare(left: Dual[K, V], right: Dual[K, V]) = n.compare(left.value, right.value)
    def sin(x: Dual[K, V]): Dual[K, V] = Dual(n.sin(x.value), x.diff * n.cos(x.value))
    def cos(x: Dual[K, V]): Dual[K, V] = Dual(n.cos(x.value), x.diff * (minusOne * n.sin(x.value)))
    def log(x: Dual[K, V]): Dual[K, V] = Dual(n.log(x.value), x.diff * (n.one / x.value))
    def exp(x: Dual[K, V]): Dual[K, V] = Dual(n.exp(x.value), x.diff * n.exp(x.value))
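  // With just the instance above and the trivial Double vector space we can already take scalar
  // derivatives by seeding diff = 1, as the next comment describes. A minimal sketch of that idea
  // (added here for illustration; diffScalar is not part of the original gist):
  def diffScalar(f: Dual[Double, Double] => Dual[Double, Double]): Double => Double =
    x => f(Dual(x, 1.0)).diff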
  // We could already compute derivatives like
  //   diff(f) = x => f(Dual(x, 1)).diff
  // In order to compute derivatives with multiple independent variables efficiently, we delay building the Jacobian with the expression tree below.
  sealed trait Delta
  case object Zero extends Delta
  case class OneHot(i: Int) extends Delta
  case class Scale(scale: Double, vector: Delta) extends Delta
  case class Add(left: Delta, right: Delta) extends Delta
  case class Var(id: Long) extends Delta
  case class Let(id: Long, value: Delta, block: Delta) extends Delta
  // Delta is a vector space.
  given VectorSpace[Double, Delta] with
    def zero: Delta = Zero
    extension (first: Delta)
      def +(another: Delta): Delta = Add(first, another)
      def *(scale: Double): Delta = Scale(scale, first)
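  // For example (an illustrative note, not part of the original gist): multiplying the seeds
  // Dual(x, OneHot(0)) * Dual(y, OneHot(1)) builds, via the product rule above, the delta
  // Add(Scale(y, OneHot(0)), Scale(x, OneHot(1))), i.e. the gradient (y, x) of x * y,
  // recorded as a tree rather than evaluated eagerly.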
  // It looks like Scala doesn't call fromInt automatically for literals; not sure why it is there in the first place if it's never called.
  // Sigh !!!
  given Conversion[Double, Dual[Double, Delta]] with
    override def apply(x: Double): Dual[Double, Delta] = Dual(x, Zero)
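  // This conversion is what lets the Double literals in main (2.0 and 3.0) take part in Dual
  // arithmetic: they become Dual(literal, Zero), i.e. constants with zero derivative.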
  type DeltaMap = Map[Long, Array[Double]]
  def eval(dimensions: Int, delta: Delta, mappings: DeltaMap): Array[Double] = delta match
    case Zero => Array.fill(dimensions)(0.0)
    case OneHot(i) => {
      val array = Array.fill(dimensions)(0.0)
      array.update(i, 1.0)
      array
    }
    case Scale(scale, vector) => eval(dimensions, vector, mappings).map(_ * scale)
    case Add(left, right) => {
      val l = eval(dimensions, left, mappings)
      val r = eval(dimensions, right, mappings)
      val result = Array.copyOf[Double](l, dimensions)
      for (i <- 0.until(dimensions)) result.update(i, result(i) + r(i))
      result
    }
    case Var(id) => mappings(id) // If we clone here we can happily mutate stuff above.
    case Let(id, value, block) => eval(dimensions, block, mappings + (id -> eval(dimensions, value, mappings)))
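  // Let and Var are the "cached" part of the deferred evaluation (an illustrative note, not part of
  // the original gist): eval(n, Let(0, d, Add(Var(0), Var(0))), Map.empty) evaluates the shared
  // sub-delta d once, binds the result to id 0 and reuses it for both Vars.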
  def diffV(f: Array[Dual[Double, Delta]] => Array[Dual[Double, Delta]]): Array[Double] => Array[Array[Double]] = {
    value => {
      val fresult = f(value.zipWithIndex.map { (v, i) => Dual(v, OneHot(i)) })
      for { fval <- fresult } yield eval(value.length, fval.diff, Map.empty)
    }
  }
  def diff(f: Dual[Double, Delta] => Dual[Double, Delta]): Double => Double = v => eval(1, f(Dual(v, OneHot(0))).diff, Map.empty)(0)
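  // A possible convenience on top of diffV (an assumption, not part of the original gist): the
  // gradient of a scalar-valued function of several variables is the single row of its 1 x n Jacobian.
  def grad(f: Array[Dual[Double, Delta]] => Dual[Double, Delta]): Array[Double] => Array[Double] =
    value => diffV(xs => Array(f(xs)))(value)(0)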
  def main(args: Array[String]) = {
    import scala.language.implicitConversions
    println(diff(x => x * x + x * 2.0 + 3.0)(2))
    diffV { case Array(x, y) => Array(x * x * x + y * 2.0, x * y, x + y) }(Array(2.0, 1.0))
      .foreach(row => println(row.map(String.format("%10.2f", _)).mkString))
  }
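  // For reference (values computed by hand, not output captured from a run): the first println
  // prints 6.0, the derivative of x^2 + 2x + 3 at x = 2, and the Jacobian rows printed afterwards
  // are [12.00, 2.00], [1.00, 2.00] and [1.00, 1.00] for (x, y) = (2, 1).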
}