Skip to content

Instantly share code, notes, and snippets.

@elizarov
Last active February 26, 2023 15:18
Show Gist options
  • Star 5 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d to your computer and use it in GitHub Desktop.
Save elizarov/1ad3a8583e88cb6ea7a0ad09bb591d3d to your computer and use it in GitHub Desktop.
Automatic Differentiation with Kotlin
/*
* Implementation of backward-mode automatic differentiation.
*/
/**
 * A differentiable variable.
 *
 * @property x the current value of the variable.
 * @property d the derivative of the [grad] result with respect to this variable;
 *   it starts at `0.0` and is mutable so that it can be accumulated into.
 */
data class D(var x: Double, var d: Double = 0.0) {
    /** Convenience constructor that widens an [Int] value to [Double]; [d] keeps its default. */
    constructor(x: Int) : this(x = x.toDouble())
}
/**
 * Evaluates [body] inside a fresh [AD] context and then runs the backward pass,
 * so that the `d` field of every [D] that took part in the formula is updated.
 *
 * Example:
 * ```
 * val x = D(2) // define variable(s) and their values
 * val y = grad { sqr(x) + 5 * x + 3 } // write the formula in grad context
 * assertEquals(17.0, y.x) // the value of result (y)
 * assertEquals(9.0, x.d) // dy/dx
 * ```
 *
 * @param body the formula to differentiate, written with [AD] as its receiver.
 * @return the [D] produced by [body]; its `d` has been set to `1.0` before the backward pass.
 */
fun grad(body: AD.() -> D): D {
    val context = ADImpl()
    val result = context.body()
    // Seed the backward pass: the derivative of the result w.r.t. itself is 1.
    result.d = 1.0
    context.runBackwardPass()
    return result
}
/**
 * Automatic Differentiation context.
 *
 * The arithmetic operators on [D] are declared here as member extensions, so they are
 * only available while an [AD] instance is in scope as a receiver. Every operator
 * computes its value immediately and hands the chain-rule update for its operands to [derive].
 */
abstract class AD {
    /**
     * Returns [value] and arranges for [block] to be applied to it during the back-pass,
     * i.e. after the rest of the formula has been processed.
     *
     * For example, the implementation of the `sin` function is:
     *
     * ```
     * fun AD.sin(x: D): D = derive(D(sin(x.x))) { z -> // call derive with function result
     *     x.d += z.d * cos(x.x) // update derivative using chain rule and derivative of the function
     * }
     * ```
     *
     * @param value the result of the operation being recorded.
     * @param block the derivative update to run on [value] in the back-pass.
     * @return [value] itself.
     */
    abstract fun <R> derive(value: R, block: (R) -> Unit): R

    // ---- Basic math on two differentiable operands (+, -, *, /) ----

    /** `this + that`; the incoming derivative flows unchanged into both operands. */
    operator fun D.plus(that: D): D {
        val lhs = this
        return derive(D(lhs.x + that.x)) { out ->
            lhs.d += out.d
            that.d += out.d
        }
    }

    /** `this - that`; the incoming derivative is added to `this` and subtracted from [that]. */
    operator fun D.minus(that: D): D {
        val lhs = this
        return derive(D(lhs.x - that.x)) { out ->
            lhs.d += out.d
            that.d -= out.d
        }
    }

    /** `this * that`; each operand receives the incoming derivative scaled by the other operand's value. */
    operator fun D.times(that: D): D {
        val lhs = this
        return derive(D(lhs.x * that.x)) { out ->
            lhs.d += out.d * that.x
            that.d += out.d * lhs.x
        }
    }

    /** `this / that`; applies the quotient rule to both operands. */
    operator fun D.div(that: D): D {
        val lhs = this
        return derive(D(lhs.x / that.x)) { out ->
            lhs.d += out.d / that.x
            that.d -= out.d * lhs.x / (that.x * that.x)
        }
    }

    // ---- Overloads mixing a D with a Double constant ----

    /** `constant + that`. */
    operator fun Double.plus(that: D): D {
        val c = this
        return derive(D(c + that.x)) { out ->
            that.d += out.d
        }
    }

    /** `this + constant`; addition commutes, so this forwards to `constant + this`. */
    operator fun D.plus(b: Double): D = b + this

    /** `constant - that`. */
    operator fun Double.minus(that: D): D {
        val c = this
        return derive(D(c - that.x)) { out ->
            that.d -= out.d
        }
    }

    /** `this - constant`. */
    operator fun D.minus(that: Double): D {
        val lhs = this
        return derive(D(lhs.x - that)) { out ->
            lhs.d += out.d
        }
    }

    /** `constant * that`. */
    operator fun Double.times(that: D): D {
        val c = this
        return derive(D(c * that.x)) { out ->
            that.d += out.d * c
        }
    }

    /** `this * constant`; multiplication commutes, so this forwards to `constant * this`. */
    operator fun D.times(b: Double): D = b * this

    /** `constant / that`. */
    operator fun Double.div(that: D): D {
        val c = this
        return derive(D(c / that.x)) { out ->
            that.d -= out.d * c / (that.x * that.x)
        }
    }

    /** `this / constant`. */
    operator fun D.div(that: Double): D {
        val lhs = this
        return derive(D(lhs.x / that)) { out ->
            lhs.d += out.d / that
        }
    }

    // ---- Overloads for Int constants: widen to Double and reuse the overloads above ----

    operator fun Int.plus(b: D): D = toDouble() + b

    operator fun D.plus(b: Int): D = this + b.toDouble()

    operator fun Int.minus(b: D): D = toDouble() - b

    operator fun D.minus(b: Int): D = this - b.toDouble()

    operator fun Int.times(b: D): D = toDouble() * b

    operator fun D.times(b: Int): D = this * b.toDouble()

    operator fun Int.div(b: D): D = toDouble() / b

    operator fun D.div(b: Int): D = this / b.toDouble()
}
// ---------------------------------------- ENGINE IMPLEMENTATION ----------------------------------------
/**
 * Private engine behind [grad]: records every [derive] call on a stack during the
 * forward evaluation and replays the recorded blocks in reverse order in [runBackwardPass].
 */
private class ADImpl : AD() {
    // Flat stack of (block, value) pairs: each block is stored directly below the value it applies to.
    private var stack = arrayOfNulls<Any?>(8)

    // Index of the next free slot; it only ever moves in steps of two.
    private var sp = 0

    /** Pushes [block] and [value] for the backward pass and returns [value] unchanged. */
    override fun <R> derive(value: R, block: (R) -> Unit): R {
        // Double the capacity once the stack is full.
        if (sp >= stack.size) {
            stack = stack.copyOf(stack.size * 2)
        }
        stack[sp] = block
        stack[sp + 1] = value
        sp += 2
        return value
    }

    /** Pops the recorded pairs from last to first, applying each block to its value. */
    @Suppress("UNCHECKED_CAST")
    fun runBackwardPass() {
        while (sp > 0) {
            sp -= 2
            val value = stack[sp + 1]
            val block = stack[sp] as (Any?) -> Unit
            block(value)
        }
    }
}
import kotlin.math.*
// Extensions for differentiation of various basic mathematical functions
/**
 * Square: `x ^ 2`.
 *
 * The backward step adds `2 * x` times the incoming derivative to `x.d`.
 */
fun AD.sqr(x: D): D {
    val squared = D(x.x * x.x)
    return derive(squared) { out ->
        x.d += out.d * 2 * x.x
    }
}
/**
 * Square root: `x ^ 1/2`.
 *
 * The backward step reuses the already computed root (`out.x`): d(sqrt(x))/dx = 0.5 / sqrt(x).
 */
fun AD.sqrt(x: D): D {
    val root = D(sqrt(x.x))
    return derive(root) { out ->
        x.d += out.d * 0.5 / out.x
    }
}
/**
 * Power with a constant exponent: `x ^ y`.
 *
 * The backward step applies d(x^y)/dx = y * x^(y - 1).
 *
 * @param x the differentiable base.
 * @param y the constant exponent; it does not receive a derivative.
 * @return a new [D] holding `x.x.pow(y)`.
 */
fun AD.pow(x: D, y: Double): D = derive(D(x.x.pow(y))) { z ->
    // For y == 0 the function is the constant 1, so its contribution to x.d is exactly zero.
    // Skipping the update avoids evaluating 0 * x^(-1), which is 0 * Infinity = NaN at x == 0.
    if (y != 0.0) x.d += z.d * y * x.x.pow(y - 1)
}
/** Power with a constant [Int] exponent; widens [y] and forwards to the [Double]-exponent overload. */
fun AD.pow(x: D, y: Int): D {
    val exponent = y.toDouble()
    return pow(x, exponent)
}
/**
 * Exponential: `exp(x)`.
 *
 * The derivative of `exp` is `exp` itself, so the backward step reuses the computed value (`out.x`).
 */
fun AD.exp(x: D): D {
    val value = D(exp(x.x))
    return derive(value) { out ->
        x.d += out.d * out.x
    }
}
/**
 * Natural logarithm: `ln(x)`.
 *
 * The backward step applies d(ln(x))/dx = 1 / x.
 */
fun AD.ln(x: D): D {
    val value = D(ln(x.x))
    return derive(value) { out ->
        x.d += out.d / x.x
    }
}
/**
 * Power where both base and exponent are differentiable: `x ^ y`.
 *
 * Computed as `exp(y * ln(x))`, so derivatives for both [x] and [y] come from
 * the `ln`, `*` and `exp` steps it is composed of.
 */
fun AD.pow(x: D, y: D): D {
    val logOfBase = ln(x)
    val scaled = y * logOfBase
    return exp(scaled)
}
/**
 * Sine: `sin(x)`.
 *
 * The backward step applies d(sin(x))/dx = cos(x).
 */
fun AD.sin(x: D): D {
    val value = D(sin(x.x))
    return derive(value) { out ->
        x.d += out.d * cos(x.x)
    }
}
/**
 * Cosine: `cos(x)`.
 *
 * The backward step applies d(cos(x))/dx = -sin(x).
 */
fun AD.cos(x: D): D {
    val value = D(cos(x.x))
    return derive(value) { out ->
        x.d -= out.d * sin(x.x)
    }
}
import org.junit.*
import kotlin.math.*
import kotlin.test.*
/**
 * Tests for [grad]: each test builds a small formula from [D] variables and checks both
 * the computed value and the derivatives accumulated in the variables' `d` fields.
 */
class ADTest {
    @Test
    fun testPlusX2() {
        val x = D(3) // diff w.r.t this x at 3
        val y = grad { x + x }
        assertEquals(6.0, y.x) // y = x + x = 6
        assertEquals(2.0, x.d) // dy/dx = 2
    }

    @Test
    fun testPlus() {
        // two variables
        val x = D(2)
        val y = D(3)
        val z = grad { x + y }
        assertEquals(5.0, z.x) // z = x + y = 5
        assertEquals(1.0, x.d) // dz/dx = 1
        assertEquals(1.0, y.d) // dz/dy = 1
    }

    @Test
    fun testMinus() {
        // two variables
        val x = D(7)
        val y = D(3)
        val z = grad { x - y }
        assertEquals(4.0, z.x) // z = x - y = 4
        assertEquals(1.0, x.d) // dz/dx = 1
        assertEquals(-1.0, y.d) // dz/dy = -1
    }

    @Test
    fun testMulX2() {
        val x = D(3) // diff w.r.t this x at 3
        val y = grad { x * x }
        assertEquals(9.0, y.x) // y = x * x = 9
        assertEquals(6.0, x.d) // dy/dx = 2 * x = 6
    }

    @Test
    fun testSqr() {
        val x = D(3)
        val y = grad { sqr(x) }
        assertEquals(9.0, y.x) // y = x ^ 2 = 9
        assertEquals(6.0, x.d) // dy/dx = 2 * x = 6
    }

    @Test
    fun testSqrSqr() {
        val x = D(2)
        val y = grad { sqr(sqr(x)) }
        assertEquals(16.0, y.x) // y = x ^ 4 = 16
        assertEquals(32.0, x.d) // dy/dx = 4 * x^3 = 32
    }

    @Test
    fun testX3() {
        val x = D(2) // diff w.r.t this x at 2
        val y = grad { x * x * x }
        assertEquals(8.0, y.x) // y = x * x * x = 8
        assertEquals(12.0, x.d) // dy/dx = 3 * x * x = 12
    }

    @Test
    fun testDiv() {
        val x = D(5)
        val y = D(2)
        val z = grad { x / y }
        assertEquals(2.5, z.x) // z = x / y = 2.5
        assertEquals(0.5, x.d) // dz/dx = 1 / y = 0.5
        assertEquals(-1.25, y.d) // dz/dy = -x / y^2 = -1.25
    }

    @Test
    fun testPow3() {
        val x = D(2) // diff w.r.t this x at 2
        val y = grad { pow(x, 3) }
        assertEquals(8.0, y.x) // y = x ^ 3 = 8
        assertEquals(12.0, x.d) // dy/dx = 3 * x ^ 2 = 12
    }

    @Test
    fun testPowFull() {
        val x = D(2)
        val y = D(3)
        val z = grad { pow(x, y) }
        assertApprox(8.0, z.x) // z = x ^ y = 8
        assertApprox(12.0, x.d) // dz/dx = y * x ^ (y - 1) = 12
        assertApprox(8.0 * ln(2.0), y.d) // dz/dy = x ^ y * ln(x)
    }

    @Test
    fun testFromPaper() {
        val x = D(3)
        val y = grad { 2 * x + x * x * x }
        assertEquals(33.0, y.x) // y = 2 * x + x * x * x = 33
        assertEquals(29.0, x.d) // dy/dx = 2 + 3 * x * x = 29
    }

    @Test
    fun testLongChain() {
        val n = 10_000
        val x = D(1)
        val y = grad {
            var pow = D(1)
            for (i in 1..n) pow *= x
            pow
        }
        assertEquals(1.0, y.x) // y = x ^ n = 1
        assertEquals(n.toDouble(), x.d) // dy/dx = n * x ^ (n - 1) = n
    }

    @Test
    fun testExample() {
        val x = D(2)
        val y = grad { sqr(x) + 5 * x + 3 }
        assertEquals(17.0, y.x) // the value of result (y)
        assertEquals(9.0, x.d) // dy/dx
    }

    @Test
    fun testSqrt() {
        val x = D(16)
        val y = grad { sqrt(x) }
        assertEquals(4.0, y.x) // y = x ^ 1/2 = 4
        assertEquals(1.0 / 8, x.d) // dy/dx = 1/2 / x ^ 1/2 = 1/8
    }

    @Test
    fun testSin() {
        val x = D(PI / 6)
        val y = grad { sin(x) }
        assertApprox(0.5, y.x) // y = sin(PI/6) = 0.5
        assertApprox(sqrt(3.0) / 2, x.d) // dy/dx = cos(PI/6) = sqrt(3)/2
    }

    @Test
    fun testCos() {
        val x = D(PI / 6)
        val y = grad { cos(x) }
        assertApprox(sqrt(3.0) / 2, y.x) // y = cos(PI/6) = sqrt(3)/2
        assertApprox(-0.5, x.d) // dy/dx = -sin(PI/6) = -0.5
    }

    /**
     * Asserts that [a] (expected) and [b] (actual) differ by at most `1e-10`.
     *
     * When they do not, the check falls through to `assertEquals(a, b)` so that the
     * failure message shows both values.
     */
    private fun assertApprox(a: Double, b: Double) {
        // Compare the absolute difference: without abs(), an actual value larger than the
        // expected one makes (a - b) negative and would never trip the check.
        // The negated "<=" form also sends a NaN difference to assertEquals, because
        // every ordered comparison with NaN is false.
        if (!(abs(a - b) <= 1e-10)) assertEquals(a, b)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment