Last active
November 16, 2018 20:21
-
-
Save 2torus/3a2ca111f6c1991d6ca366ec1bd5b39d to your computer and use it in GitHub Desktop.
Automatic Differentiation in OCaml, similar to https://www.kaggle.com/borisettinger/gentle-introduction-to-automatic-differentiation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** Arithmetic interface shared by the number representations in this file.
    Code written against it can run unchanged on any implementation. *)
module type Numeric = sig
  (** The number representation. *)
  type t

  (** Binary arithmetic operators on [t]. *)

  val ( + ) : t -> t -> t

  val ( - ) : t -> t -> t

  val ( * ) : t -> t -> t

  val ( / ) : t -> t -> t

  (** [c x] turns the float constant [x] into a [t]. *)
  val c : float -> t

  (** [abs x] is a non-negative float measuring the size of [x]
      (the implementations in this file use [abs_float]). *)
  val abs : t -> float
end
(** Plain floating-point arithmetic exposed through unsuffixed operators,
    providing every item listed in [Numeric].

    Note: on OCaml >= 4.07 this definition shadows the standard library's
    [Float] module for the remainder of the file. *)
module Float = struct
  type t = float

  (* Each operator is simply the corresponding float primitive. *)
  let ( + ) = ( +. )
  let ( - ) = ( -. )
  let ( * ) = ( *. )
  let ( / ) = ( /. )

  (* A float constant already is a [t]; nothing to convert. *)
  let c v = v

  let abs = abs_float
end
(** A dual number: a float [value] paired with a float [eps] coefficient. *)
type dual = {
  value : float;
  eps : float;
}
(** Arithmetic on [dual] numbers. The [value] fields are combined with the
    ordinary float operation, and the [eps] fields with the matching
    differentiation rule, so [eps] carries a derivative along with [value]. *)
module Dual = struct
  type t = dual

  (* Sum rule: both components add. *)
  let ( + ) (a : t) (b : t) : t =
    { value = a.value +. b.value; eps = a.eps +. b.eps }

  (* Difference rule: both components subtract. *)
  let ( - ) (a : t) (b : t) : t =
    { value = a.value -. b.value; eps = a.eps -. b.eps }

  (* Product rule: (a * b)' = a' * b + b' * a. *)
  let ( * ) (a : t) (b : t) : t =
    {
      value = a.value *. b.value;
      eps = (a.eps *. b.value) +. (b.eps *. a.value);
    }

  (* Quotient rule: (a / b)' = a' / b - b' * a / b^2. *)
  let ( / ) (a : t) (b : t) : t =
    {
      value = a.value /. b.value;
      eps = (a.eps /. b.value) -. (b.eps *. a.value /. (b.value *. b.value));
    }

  (* A constant has a zero [eps] component. *)
  let c x : t = { value = x; eps = 0. }

  (* Only the [value] component is measured; [eps] is ignored. *)
  let abs (d : t) = abs_float d.value
end
(** Square root by Newton's iteration [x <- (x + ysq / x) * 0.5], written
    against [Numeric] so the same code runs on every implementation of that
    signature. *)
module CustomSqrt (Dl : Numeric) = struct
  open Dl

  (** Convergence tolerance on the change in [x * x] between two successive
      iterates, as measured by [Dl.abs]. It is used as-is for inputs of
      magnitude at least 1 and scaled down by the magnitude for smaller
      inputs. *)
  let epsilon = 1e-10

  (** Upper bound on the number of Newton steps, so that inputs on which the
      iteration cannot converge fail instead of recursing forever. *)
  let max_iterations = 10_000

  (** [custom_sqrt ysq] approximates the square root of [ysq].

      An input whose magnitude ([Dl.abs ysq]) is exactly zero is returned
      unchanged.

      @raise Failure if the iteration has not converged after
      [max_iterations] steps, e.g. for a NaN input (every comparison with
      NaN is false) or a negative input (no real root to converge to). *)
  let custom_sqrt ysq =
    let magnitude = abs ysq in
    if magnitude = 0. then
      (* Newton's step divides by the iterate, so zero cannot be refined. *)
      ysq
    else
      (* Scale the tolerance for small inputs: with an absolute 1e-10 on
         squared values, any input below ~1e-5 would count as converged far
         from its root. Inputs >= 1 keep the plain [epsilon]. *)
      let tolerance = epsilon *. min 1. magnitude in
      let next_iter xn = (xn + (ysq / xn)) * c 0.5 in
      let rec custom_sqrt_iter steps xnext xprev =
        let err = abs ((xnext * xnext) - (xprev * xprev)) in
        if err < tolerance then xnext
        else if steps >= max_iterations then
          failwith
            "CustomSqrt.custom_sqrt: no convergence (negative or non-finite \
             input?)"
        else
          (* [succ], not [+]: [open Dl] rebinds [+] to operate on [Dl.t]. *)
          custom_sqrt_iter (succ steps) (next_iter xnext) xnext
      in
      (* Take the first Newton step before testing convergence. Comparing
         the initial guess [ysq] against a zero seed would accept [ysq]
         itself whenever [ysq * ysq] is below the tolerance. *)
      custom_sqrt_iter 0 (next_iter ysq) ysq
end
(* The same square-root code instantiated on plain floats and on dual
   numbers. *)
module FlSq = CustomSqrt (Float)
module DlSq = CustomSqrt (Dual)

(* Demo: square root of 4 on both instances. The dual input carries
   [eps = 1.], so the [eps] of the result is the derivative of the square
   root at 4; the expected results are approximately [2.] and
   [{value = 2.; eps = 0.25}]. The results are bound with [let] so this is a
   valid structure item in a compiled file as well as in the toplevel. *)
let float_root, dual_root =
  (FlSq.custom_sqrt 4., DlSq.custom_sqrt { value = 4.; eps = 1. })
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment