Last active: January 5, 2024 19:59
Save jasigal/7afc5fa8d0fc8422dc48a92dba2668ae to your computer and use it in GitHub Desktop.
Reverse-mode AD in Koka
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub module approx-recip

import smooth

// Approximates 1/x by the truncated geometric series
//   1 + (1 - x) + (1 - x)^2 + ... + (1 - x)^iters
// built solely from `smooth` operations, so whichever `smooth` handler is
// installed decides how the arithmetic is interpreted. Mathematically the
// series converges to 1/x for 0 < x < 2.
fun approx-recip(iters : int, x : a) : <smooth<a>,div> a
  var total := c(1.0)   // running partial sum; starts at the k = 0 term
  var term  := c(1.0)   // current power (1 - x)^k
  repeat(iters) {
    // Multiply in one more factor of (1 - x), formed as -(x - 1).
    term  := term *. (~.)(x -. c(1.0))
    total := total +. term
  }
  total
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub module evaluate

import std/num/float64
import smooth

// Handler interpreting `smooth<float64>` operations as plain float64
// arithmetic.
//
// Every clause hands its result straight back to the caller (it is
// tail-resumptive). It is therefore declared with `fun` rather than `ctl`.
// In Koka, `fun op(..) e` behaves like `ctl op(..) resume(e)`, but it does
// not capture the continuation on every arithmetic operation.
val evaluate = handler {
  // Constants evaluate to their float64 payload.
  fun ap0(n) {
    match(n) {
      Const(i) -> i
    }
  }
  // Unary primitives map to the float64 operation of the same name.
  fun ap1(u, x) {
    match(u) {
      Negate -> (~x : float64)
      Sin    -> (sin(x) : float64)
      Cos    -> (cos(x) : float64)
      Exp    -> (exp(x) : float64)
    }
  }
  // Binary primitives map to ordinary float64 arithmetic.
  fun ap2(b, x, y) {
    match(b) {
      Plus     -> (x + y : float64)
      Subtract -> (x - y : float64)
      Times    -> (x * y : float64)
      Divide   -> (x / y : float64)
    }
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub module reverse-taylor-recip-benchmark

import std/os/env
import std/text/parse
import smooth
import evaluate
import reverse
import approx-recip

// Benchmark entry point: takes the reverse-mode gradient of `approx-recip`
// at x = 0.5, with float64 evaluation as the outer interpretation, and
// prints it. The iteration count comes from the first command-line
// argument when that parses as an integer; otherwise it defaults to 500.
fun main()
  val sz = match get-args().head
    Just(arg0) -> parse-int(arg0).default(500)
    Nothing    -> 500
  with evaluate
  val res = grad(fn(x) approx-recip(sz, x), c(0.5))
  println(res : float64)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub module reverse

import smooth

// A primal value `v` paired with a mutable cell `dv` holding its adjoint:
// the accumulated derivative of the final output with respect to `v`.
value type prop<h,a> {
  Prop(v : a, dv : ref<h, a>)
}

// Reverse-mode handler for `smooth<prop<h,a>>`.
//
// Each clause computes the primal result through the enclosing
// `smooth<a>` handler and wraps it with a zero adjoint. It then resumes
// the rest of the computation. The statements after `resume` run only
// once that continuation has returned; when used through `grad`, that is
// after the output adjoint has been seeded. They therefore perform the
// backward pass, propagating adjoints in reverse order of the forward
// operations.
val reverse = handler
  ctl ap0(n)
    val out = Prop(op0(n), ref(c(0.0)))
    resume(out)
  ctl ap1(u, x)
    val out = Prop(op1(u, x.v), ref(c(0.0)))
    resume(out)
    // Chain rule: x.dv += f'(x.v) * out.dv
    set(x.dv, !x.dv +. (der1(u, x.v) *. !out.dv))
  ctl ap2(b, x, y)
    val out = Prop(op2(b, x.v, y.v), ref(c(0.0)))
    resume(out)
    // The cells are re-read before each update. So if x and y are the
    // same Prop (e.g. `z *. z`), both contributions accumulate.
    set(x.dv, !x.dv +. (der2(b, L, x.v, y.v) *. !out.dv))
    set(y.dv, !y.dv +. (der2(b, R, x.v, y.v) *. !out.dv))

// Gradient of `f` at `x`.
// Wraps `x` with a zero adjoint and runs `f` under `reverse`. It seeds
// the output's adjoint with 1 (the constant is produced by the enclosing
// handler via `mask`), then returns the adjoint accumulated for the input.
fun grad(f, x)
  val seed = Prop(x, ref(c(0.0)))
  reverse { f(seed).dv.set(mask<smooth>{ c(1.0) }) }
  !seed.dv
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
pub module smooth

import std/num/float64

// Operator fixities for the lifted arithmetic. `(-.)` has the same
// precedence and left associativity as `(+.)`, so a mixed expression such as
// `a -. b +. c` groups as `(a -. b) +. c`. `(*.)` binds tighter than both.
pub infixl 6 (+.)
pub infixl 6 (-.)
pub infixl 7 (*.)

// Primitive with no arguments: a float64 constant.
type nullary {
  Const(x : float64)
}

// Primitives with one argument.
type unary {
  Negate
  Sin
  Cos
  Exp
}

// Primitives with two arguments.
type binary {
  Plus
  Subtract
  Times
  Divide
}

// Which argument of a binary primitive a partial derivative is taken with
// respect to: `L` is the first argument, `R` the second (see `der2`).
type arg {
  L
  R
}

// Abstract arithmetic over a carrier type `a`. Each operation only names a
// primitive and its operands; the installed handler decides what applying
// it means.
effect smooth<a> {
  ctl ap0(n : nullary) : a
  ctl ap1(u : unary, arg : a) : a
  ctl ap2(b : binary, arg1 : a, arg2 : a) : a
}

// The float64 constant `i` as a value of the carrier type.
inline fun c(i : float64) : smooth<a> a {
  ap0(Const(i))
}

// Negation: -x.
inline fun (~.)(x : a) : smooth<a> a {
  ap1(Negate, x)
}

// Sine.
inline fun sin_(x : a) : smooth<a> a {
  ap1(Sin, x)
}

// Cosine.
inline fun cos_(x : a) : smooth<a> a {
  ap1(Cos, x)
}

// Exponential.
inline fun exp_(x : a) : smooth<a> a {
  ap1(Exp, x)
}

// Addition: x + y.
inline fun (+.)(x : a, y : a) : smooth<a> a {
  ap2(Plus, x, y)
}

// Subtraction: x - y.
inline fun (-.)(x : a, y : a) : smooth<a> a {
  ap2(Subtract, x, y)
}

// Multiplication: x * y.
inline fun (*.)(x : a, y : a) : smooth<a> a {
  ap2(Times, x, y)
}

// Division: x / y.
inline fun div_(x : a, y : a) : smooth<a> a {
  ap2(Divide, x, y)
}

// Performs the nullary primitive `n` as a `smooth` operation.
inline fun op0(n) {
  match(n) {
    Const(x) -> c(x)
  }
}

// Performs the unary primitive `u` on `x` as a `smooth` operation.
inline fun op1(u, x) {
  match(u) {
    Negate -> (~.)(x)
    Sin    -> sin_(x)
    Cos    -> cos_(x)
    Exp    -> exp_(x)
  }
}

// Performs the binary primitive `b` on `x` and `y` as a `smooth` operation.
inline fun op2(b, x, y) {
  match(b) {
    Plus     -> x +. y
    Subtract -> x -. y
    Times    -> x *. y
    Divide   -> div_(x, y)
  }
}

// Derivative of the unary primitive `u` evaluated at `x`. It is built from
// `smooth` operations, so the result has the same carrier type as `x`.
inline fun der1(u, x) {
  match(u) {
    Negate -> (~.)(c(1.0))    // d/dx -x     = -1
    Sin    -> cos_(x)         // d/dx sin(x) = cos(x)
    Cos    -> (~.)(sin_(x))   // d/dx cos(x) = -sin(x)
    Exp    -> exp_(x)         // d/dx exp(x) = exp(x)
  }
}

// Partial derivative of the binary primitive `b` with respect to argument
// `a` (`L` = first, `R` = second), evaluated at (x, y). It is built from
// `smooth` operations.
inline fun der2(b, a, x, y) {
  match(b) {
    // d(x+y)/dx = 1, d(x+y)/dy = 1
    Plus     -> match(a) {L -> c(1.0); R -> c(1.0)}
    // d(x-y)/dx = 1, d(x-y)/dy = -1
    Subtract -> match(a) {L -> c(1.0); R -> c(~1.0)}
    // d(x*y)/dx = y, d(x*y)/dy = x
    Times    -> match(a) {L -> y; R -> x}
    // d(x/y)/dx = 1/y, d(x/y)/dy = -x/y^2
    Divide   -> match(a) {L -> div_(c(1.0), y); R -> div_((~.)(x), y *. y)}
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Seems to be missing approx-recip?