Skip to content

Instantly share code, notes, and snippets.

@kayceesrk
Last active February 21, 2018 13:15
Show Gist options
  • Star 3 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kayceesrk/f005d115f7015e31f5e977a9678e9672 to your computer and use it in GitHub Desktop.
Save kayceesrk/f005d115f7015e31f5e977a9678e9672 to your computer and use it in GitHub Desktop.
(* Reverse-mode algorithmic differentiation using effect handlers.
Adapted from https://twitter.com/tiarkrompf/status/963314799521222656.
See https://openreview.net/forum?id=SJxJtYkPG for more information. *)
module F = struct
  (** A node of the computation: [v] is its (immutable) value and [d] is its
      adjoint, i.e. the derivative of the final output with respect to this
      node, accumulated during the backward pass. *)
  type t = { v : float; mutable d : float }

  (** [mk v] creates a node with value [v] and a zero adjoint. *)
  let mk v = { v; d = 0.0 }

  (* Primitive operations are effects so that [grad]'s handler can intercept
     each one, run the remainder of the program, and propagate adjoints
     backwards once that remainder has finished. *)
  type _ Effect.t += Plus : t * t -> t Effect.t | Mult : t * t -> t Effect.t

  (** [grad f x] is the derivative of [f] at [x], computed by reverse-mode
      algorithmic differentiation. [f] should build its result from its
      argument, nodes made with [mk], and this module's [( +. )] / [( *. )];
      every such operation performed while [f] runs is handled here (the
      handler is deep, so operations performed after resumption are handled
      too). Exceptions raised by [f] propagate unchanged. *)
  let grad f x =
    let open Effect.Deep in
    let input = mk x in
    (* Resume the suspended program with [res] as the result of the current
       operation. When [continue] returns, every later operation (and the
       return clause that seeds the output) has already added its
       contribution to [res.d], so the chain rule can push it into the
       operands using their local partial derivatives [dlhs] and [drhs].
       Each continuation is resumed exactly once, as OCaml 5 requires. *)
    let resume_then_backprop k res (lhs, dlhs) (rhs, drhs) =
      ignore (continue k res : t);
      lhs.d <- lhs.d +. (dlhs *. res.d);
      rhs.d <- rhs.d +. (drhs *. res.d);
      res
    in
    let handler : (t, t) handler =
      {
        (* Seed the output's adjoint: d(out)/d(out) = 1. *)
        retc = (fun r -> r.d <- 1.0; r);
        exnc = raise;
        effc =
          (fun (type a) (eff : a Effect.t) ->
            match eff with
            | Plus (lhs, rhs) ->
                Some
                  (fun (k : (a, t) continuation) ->
                    (* d(l+r)/dl = 1, d(l+r)/dr = 1 *)
                    resume_then_backprop k
                      (mk (lhs.v +. rhs.v))
                      (lhs, 1.0) (rhs, 1.0))
            | Mult (lhs, rhs) ->
                Some
                  (fun (k : (a, t) continuation) ->
                    (* d(l*r)/dl = r, d(l*r)/dr = l *)
                    resume_then_backprop k
                      (mk (lhs.v *. rhs.v))
                      (lhs, rhs.v) (rhs, lhs.v))
            | _ -> None);
      }
    in
    (* The handled result is the output node; only the input's adjoint is
       needed, so the result is explicitly discarded. *)
    ignore (match_with f input handler : t);
    input.d

  (** Differentiable addition and multiplication of nodes. They must only be
      called (directly or indirectly) from a function passed to [grad];
      performed anywhere else, [Effect.perform] raises [Effect.Unhandled].
      They are defined last so that [grad] above uses the float operators. *)
  let ( +. ) a b = Effect.perform (Plus (a, b))

  let ( *. ) a b = Effect.perform (Mult (a, b))
end;;
(* Sanity check: d/dx (x + x^3) = 1 + 3x^2 at every integer point 0..10.
   All intermediate values are small integers, so exact float equality is
   safe here. *)
let () =
  List.init 11 float_of_int
  |> List.iter (fun x ->
         let expected = 1.0 +. (3.0 *. x *. x) in
         assert (F.(grad (fun y -> y +. (y *. y *. y)) x) = expected))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment