Skip to content

Instantly share code, notes, and snippets.

@athas
Created March 15, 2019 09:30
Show Gist options
  • Save athas/aeeb830d1d2781237eeae363bdb1fcb4 to your computer and use it in GitHub Desktop.
Save athas/aeeb830d1d2781237eeae363bdb1fcb4 to your computer and use it in GitHub Desktop.
-- | Forward-mode automatic differentiation via dual numbers.
--
-- A value of type `t` is a pair `(x, x')` holding a primal value `x` and
-- its tangent (derivative) `x'`, both drawn from the underlying real
-- module `T`.  Every arithmetic operation computes the primal result and
-- propagates the tangent with the corresponding differentiation rule.
module FwdAD(T: real): {
  type r = T.t
  type t = (r, r)
  val inject: r -> t
  val set_deriv: t -> r -> t
  val get_deriv: t -> r
  val make_dual: r -> r -> t
  include from_prim with t = (r,r)
  include numeric with t = (r,r)
  include real with t = (r,r)
} = {
  type r = T.t
  type t = (r, r)

  -- | Lift a constant: the tangent of a constant is zero.
  let inject x = (x, T.i32 0)

  -- Conversions from primitive types all produce constants.
  let i8 (x : i8) = inject (T.i8 x)
  let i16 (x : i16) = inject (T.i16 x)
  let i32 (x : i32) = inject (T.i32 x)
  let i64 (x : i64) = inject (T.i64 x)
  let f32 (x : f32) = inject (T.f32 x)
  let f64 (x : f64) : t = inject (T.f64 x)
  let u8 (x : u8) = inject (T.u8 x)
  let u16 x = inject (T.u16 x)
  let u32 x = inject (T.u32 x)
  let u64 x = inject (T.u64 x)
  let bool x = inject (T.bool x)

  -- Sum, difference, product and quotient rules.
  let (x,x') + (y,y') = T.( (x + y, x' + y') )
  let (x,x') - (y,y') = T.( (x - y, x' - y') )
  let (x,x') * (y,y') = T.( (x * y, x' * y + x * y') )
  let (x,x') / (y,y') = T.( (x / y, (x' * y - x * y') / y ** (i32 2)) )

  -- General power rule:
  --   d(x ** y) = x' * y * x ** (y - 1)  +  y' * x ** y * log x
  -- The second term is skipped when the exponent carries no tangent, so
  -- that a constant exponent on a non-positive base does not pick up a
  -- NaN from `log x`.
  let (x,x') ** (y,y') =
    T.( let p = x ** y
        let dx = x' * y * x ** (y - i32 1)
        let dy = if y' == i32 0 then i32 0 else y' * p * log x
        in (p, dx + dy) )

  -- Comparisons look only at the primal parts.
  let (x,_) == (y,_) = T.( x == y )
  let (x,_) < (y,_) = T.( x < y )
  let (x,_) > (y,_) = T.( x > y )
  let (x,_) <= (y,_) = T.( x <= y )
  let (x,_) >= (y,_) = T.( x >= y )
  let (x,_) != (y,_) = T.( x != y )

  let negate (x,x') = T.( (negate x, negate x') )

  -- The chosen operand is returned whole, tangent included.
  let max x y = if x >= y then x else y
  let min x y = if x <= y then x else y

  -- d|x| = sgn x * x'
  let abs (x,x') = T.( (abs x, sgn x * x') )
  -- `sgn` is piecewise constant, so its tangent is zero.
  let sgn (x,_) = inject (T.sgn x)

  let highest = inject T.highest
  let lowest = inject T.lowest

  -- | Returns zero on empty input.
  let sum = reduce (+) (inject (T.i32 0))
  -- | Returns one on empty input.
  let product = reduce (*) (inject (T.i32 1))
  -- | Returns `lowest` on empty input.
  let maximum = reduce max lowest
  -- | Returns `highest` on empty input.
  let minimum = reduce min highest

  -- val from_fraction: i32 -> i32 -> t
  let from_fraction x y = inject (T.from_fraction x y)

  -- val to_i32: t -> i32
  let to_i32 (x,_) = T.to_i32 x
  let to_i64 (x,_) = T.to_i64 x
  let to_f64 (x,_) = T.to_f64 x

  -- val sqrt: t -> t
  let sqrt (x,x') = T.( (sqrt x, x' / (i32 2 * sqrt x)) )
  -- val exp: t -> t
  let exp (x,x') = T.( (exp x, x' * exp x) )
  -- val cos: t -> t
  let cos (x, x') = T.( (cos x, negate x' * sin x) )
  -- val sin: t -> t
  let sin (x, x') = T.( (sin x, x' * cos x) )
  -- Uses the dual-number `sin`, `cos` and `/` defined above.
  let tan x = sin x / cos x
  -- val asin: t -> t
  let asin (x, x') = T.( (asin x, x' / (sqrt (i32 1 - x ** i32 2))) )
  -- val acos: t -> t
  let acos (x, x') = T.( (acos x, negate x' / (sqrt (i32 1 - x ** i32 2))) )
  -- val atan: t -> t
  let atan (x, x') = T.( (atan x, x' / (i32 1 + x ** i32 2)) )
  -- val atan2: t -> t -> t
  -- With `atan2 y x` being the angle of the point (x, y):
  --   d(atan2 y x) = (y' * x - y * x') / (x*x + y*y)
  let atan2 (y,y') (x,x') =
    T.( (atan2 y x, (y' * x - y * x') / (x * x + y * y)) )

  -- val log: t -> t
  let log (x, x') = T.( (log x, x' / x) )
  -- d(log_b x) = x' / (x * ln b)
  let log2 (x, x') = T.( (log2 x, x' / (x * log (i32 2))) )
  let log10 (x, x') = T.( (log10 x, x' / (x * log (i32 10))) )

  -- Rounding functions are piecewise constant, hence have zero tangent.
  -- val ceil : t -> t
  let ceil (x,_) = inject (T.ceil x)
  -- val floor : t -> t
  let floor (x,_) = inject (T.floor x)
  -- val trunc : t -> t
  let trunc (x,_) = inject (T.trunc x)
  -- val round : t -> t
  let round (x,_) = inject (T.round x)

  -- val isinf: t -> bool
  let isinf (x,_) = T.isinf x
  -- val isnan: t -> bool
  let isnan (x,_) = T.isnan x

  -- val inf: t
  let inf = inject T.inf
  -- val nan: t
  let nan = inject T.nan
  -- val pi: t
  let pi = inject T.pi
  -- val e: t
  let e = inject T.e

  -- | Extract the tangent.
  let get_deriv (_,x') = x'
  -- | Replace the tangent, keeping the primal value.
  let set_deriv (x,_) x'= (x,x')
  -- | Build a dual number from a primal value and a tangent.
  let make_dual x x' = (x,x')
}
import "lib/github.com/diku-dk/linalg/linalg"
-- Dual numbers over `f64`: the scalar type used for differentiation below.
module d = FwdAD f64
-- `mk_linalg` applied to the dual-number module (presumably the parametric
-- module from the linalg package imported above; confirm).
module l = mk_linalg d
-- | Tangents of `l.dotprod v1 v2` with respect to each element of `v1`,
-- computed with forward-mode AD.
--
-- Runs one forward pass per coordinate: pass `i` seeds the tangent of
-- `v1[i]` with 1 and all other tangents with 0, while `v2` is lifted as a
-- constant.  Element `i` of the result is the tangent produced by pass `i`.
entry paper_dotprod [n] (v1: [n]f64) (v2: [n]f64): []f64 =
  tabulate n
           (\i ->
              let v1' = map2 d.make_dual v1 (tabulate n (\j -> if i == j then 1 else 0))
              let v2' = map d.inject v2
              in l.dotprod v1' v2')
  -- Use the module's accessor rather than a positional tuple projection,
  -- so the code does not depend on how tuple fields are numbered.
  |> map d.get_deriv
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment