Skip to content

Instantly share code, notes, and snippets.

@jasigal
Last active January 5, 2024 19:59
Show Gist options
  • Save jasigal/7afc5fa8d0fc8422dc48a92dba2668ae to your computer and use it in GitHub Desktop.
Save jasigal/7afc5fa8d0fc8422dc48a92dba2668ae to your computer and use it in GitHub Desktop.
Reverse-mode automatic differentiation (AD) in Koka, implemented with effect handlers
pub module approx-recip
import smooth
// Approximates 1/x with the truncated geometric series
//   1/x = 1/(1 - (1 - x)) = sum_{k>=0} (1 - x)^k,
// summing the terms k = 0 .. iters (the series only converges for 0 < x < 2).
// All arithmetic is performed through the `smooth` effect, so the enclosing
// handler decides how the computation is interpreted.
fun approx-recip(iters : int, x : a) : <smooth<a>,div> a
  var total := c(1.0)
  var term := c(1.0)
  repeat(iters) {
    // next term: term * (1 - x), written as term * -(x - 1)
    term := term *. (~.)(x -. c(1.0))
    total := total +. term
  }
  total
pub module evaluate
import std/num/float64
import smooth
// Interprets `smooth<float64>` by plain numeric evaluation: each operation is
// computed directly on float64 values and the result is passed to the continuation.
val evaluate = handler
  ctl ap0(n)
    match n
      Const(i) -> resume(i)
  ctl ap1(u, x)
    match u
      Negate -> resume(~x : float64)
      Sin    -> resume(sin(x) : float64)
      Cos    -> resume(cos(x) : float64)
      Exp    -> resume(exp(x) : float64)
  ctl ap2(b, x, y)
    match b
      Plus     -> resume(x + y : float64)
      Subtract -> resume(x - y : float64)
      Times    -> resume(x * y : float64)
      Divide   -> resume(x / y : float64)
pub module reverse-taylor-recip-benchmark
import std/os/env
import std/text/parse
import smooth
import evaluate
import reverse
import approx-recip
// Benchmark entry point: computes d/dx approx-recip(sz, x) at x = 0.5 with
// reverse-mode AD (primal arithmetic handled by `evaluate`) and prints it.
// `sz` is taken from the first command-line argument; 500 is used when the
// argument is missing or does not parse as an integer.
fun main()
  val fallback = 500
  val sz = match get-args()
    Cons(arg0, _) -> match parse-int(arg0)
      Just(n) -> n
      Nothing -> fallback
    Nil -> fallback
  with evaluate
  val res = grad(fn(x) { approx-recip(sz, x) }, c(0.5))
  println(res : float64)
pub module reverse
import smooth
// A node of the reverse-mode computation: the primal value `v` together with a
// mutable reference `dv` into which this node's adjoint (accumulated derivative)
// is summed by the `reverse` handler and later read back by `grad`.
value type prop<h,a> {
  Prop(v : a, dv : ref<h, a>)
}
// Reverse-mode AD handler for `smooth<prop<h,a>>`.
// Each clause:
//   1. computes the primal value with op0/op1/op2, whose `smooth` operations are
//      performed from inside the clause and are therefore handled by the enclosing
//      `smooth<a>` handler;
//   2. allocates a fresh node whose adjoint starts at zero;
//   3. resumes the rest of the program with that node.
// When `resume` returns, the rest of the program has finished, and with it the
// backward steps of every node created later. `!r.dv` therefore holds this node's
// final adjoint, which is then pushed into the inputs' adjoints (chain rule).
val reverse = handler {
  ctl ap0(n) -> {
    val r = Prop(op0(n), ref(c(0.0)))
    // constants have no inputs, so there is nothing to propagate after resuming
    resume(r)
  }
  ctl ap1(u,x) -> {
    val r = Prop(op1(u,x.v), ref(c(0.0)))
    resume(r)
    // backward step: dx += f'(x) * dr
    set(x.dv, !x.dv +. (der1(u,x.v) *. !r.dv))
  }
  ctl ap2(b,x,y) -> {
    val r = Prop(op2(b,x.v,y.v), ref(c(0.0)))
    resume(r)
    // backward step: dx += (d r / d x) * dr and dy += (d r / d y) * dr
    set(x.dv, !x.dv +. (der2(b,L,x.v,y.v) *. !r.dv))
    set(y.dv, !y.dv +. (der2(b,R,x.v,y.v) *. !r.dv))
  }
}
// Reverse-mode derivative of `f` at `x`.
// `x` is wrapped in a node with a zero adjoint, and `f` runs under `reverse`.
// The output node's adjoint is seeded with 1.0: `mask<smooth>` skips the
// `reverse` handler so the constant is built by the outer `smooth<a>` handler.
// Once `reverse` returns, backpropagation is complete and the input's adjoint is returned.
fun grad(f, x)
  val seed = Prop(x, ref(c(0.0)))
  reverse {
    val out = f(seed)
    set(out.dv, mask<smooth>{ c(1.0) })
  }
  !seed.dv
pub module smooth
import std/num/float64
// Fixities for the `smooth` arithmetic operators, mirroring ordinary arithmetic:
// additive operators bind looser than multiplicative ones.
// `(-.)` previously had no declaration, so `a *. b -. c` would not parse as
// `(a *. b) -. c`. It now shares precedence 6 with `(+.)`.
pub infixl 6 (+.)
pub infixl 6 (-.)
pub infixl 7 (*.)
// Tags for the nullary operations of the `smooth` effect: float64 constants.
type nullary {
  Const(x : float64)
}
// Tags for the unary operations of the `smooth` effect.
type unary {
  Negate
  Sin
  Cos
  Exp
}
// Tags for the binary operations of the `smooth` effect.
type binary {
  Plus
  Subtract
  Times
  Divide
}
// Selects which argument of a binary operation `der2` differentiates with
// respect to: `L` is the first argument, `R` the second.
type arg {
  L
  R
}
// Arithmetic over an abstract number type `a`. Code written against this effect
// is given meaning by a handler: `evaluate` computes on float64, and `reverse`
// computes on `prop` nodes to obtain derivatives.
effect smooth<a> {
  // constant
  ctl ap0(n : nullary) : a
  // unary operation `u` applied to `arg`
  ctl ap1(u : unary, arg : a) : a
  // binary operation `b` applied to `arg1` and `arg2`
  ctl ap2(b : binary, arg1 : a, arg2 : a) : a
}
// Lifts the float64 constant `i` into the current `smooth` interpretation.
inline fun c(i : float64) : smooth<a> a
  ap0(Const(i))
// Smart constructors for the unary `smooth` operations; each performs `ap1`
// with the matching tag.

// negation
inline fun (~.)(x : a) : smooth<a> a
  ap1(Negate, x)

// sine
inline fun sin_(x : a) : smooth<a> a
  ap1(Sin, x)

// cosine
inline fun cos_(x : a) : smooth<a> a
  ap1(Cos, x)

// natural exponential
inline fun exp_(x : a) : smooth<a> a
  ap1(Exp, x)
// Smart constructors for the binary `smooth` operations; each performs `ap2`
// with the matching tag, keeping the argument order (x first, then y).

// addition
inline fun (+.)(x : a, y : a) : smooth<a> a
  ap2(Plus, x, y)

// subtraction
inline fun (-.)(x : a, y : a) : smooth<a> a
  ap2(Subtract, x, y)

// multiplication
inline fun (*.)(x : a, y : a) : smooth<a> a
  ap2(Times, x, y)

// division
inline fun div_(x : a, y : a) : smooth<a> a
  ap2(Divide, x, y)
// Performs the nullary `smooth` operation described by the tag `n`
// (the `reverse` handler uses this to compute primal values).
inline fun op0(n)
  match n
    Const(k) -> c(k)
// Performs the unary `smooth` operation described by the tag `u` on `x`.
inline fun op1(u, x)
  match u
    Negate -> (~.)(x)
    Sin    -> sin_(x)
    Cos    -> cos_(x)
    Exp    -> exp_(x)
// Performs the binary `smooth` operation described by the tag `b` on `x` and `y`.
inline fun op2(b, x, y)
  match b
    Plus     -> x +. y
    Subtract -> x -. y
    Times    -> x *. y
    Divide   -> div_(x, y)
// Derivative of the unary operation `u`, evaluated at `x`, built with `smooth`
// operations so it is interpreted by whichever handler is in scope.
inline fun der1(u, x) {
  match(u) {
    Negate -> (~.)(c(1.0))   // d/dx (-x)    = -1
    Sin -> cos_(x)           // d/dx sin(x)  = cos(x)
    Cos -> (~.)(sin_(x))     // d/dx cos(x)  = -sin(x)
    Exp -> exp_(x)           // d/dx exp(x)  = exp(x)  (previously returned -1)
  }
}
// Partial derivative of the binary operation `b` at (x, y), taken with respect
// to the first argument when `a` is `L` and the second when `a` is `R`.
inline fun der2(b, a, x, y)
  match b
    Plus -> c(1.0)   // both partials of x + y are 1
    Subtract -> match a
      L -> c(1.0)
      R -> c(~1.0)
    Times -> match a
      L -> y
      R -> x
    Divide -> match a
      L -> div_(c(1.0), y)
      R -> div_((~.)(x), y *. y)
@TimWhiting
Copy link

Seems to be missing approx-recip?

@jasigal
Copy link
Author

jasigal commented Jan 5, 2024

Whoops! Added it now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment