Created
March 15, 2019 09:30
-
-
Save athas/aeeb830d1d2781237eeae363bdb1fcb4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
-- | Forward-mode automatic differentiation with dual numbers.
--
-- A value of type `t` is a pair `(x, x')` of a primal value `x` and its
-- tangent `x'`. Every operation propagates the tangent by the chain
-- rule. Evaluating a function on `make_dual x 1` therefore yields the
-- function's value and its derivative at `x`. Comparisons and
-- conversions inspect only the primal component.
module FwdAD(T: real): {
  type r = T.t
  type t = (r, r)
  -- | Lift a constant: the given primal with a zero tangent.
  val inject: r -> t
  -- | Replace the tangent of a dual number, keeping its primal.
  val set_deriv: t -> r -> t
  -- | Extract the tangent of a dual number.
  val get_deriv: t -> r
  -- | Build a dual number from a primal and a tangent.
  val make_dual: r -> r -> t
  include from_prim with t = (r,r)
  include numeric with t = (r,r)
  include real with t = (r,r)
} = {
  type r = T.t
  type t = (r, r)

  let inject x = (x, T.i32 0)

  -- Conversions from primitive types produce constants (zero tangent).
  let i8 (x : i8) = inject (T.i8 x)
  let i16 (x : i16) = inject (T.i16 x)
  let i32 (x : i32) = inject (T.i32 x)
  let i64 (x : i64) = inject (T.i64 x)
  let f32 (x : f32) = inject (T.f32 x)
  let f64 (x : f64) : t = inject (T.f64 x)
  let u8 (x : u8) = inject (T.u8 x)
  let u16 x = inject (T.u16 x)
  let u32 x = inject (T.u32 x)
  let u64 x = inject (T.u64 x)
  let bool x = inject (T.bool x)

  let (x,x') + (y,y') = T.( (x + y, x' + y') )
  let (x,x') - (y,y') = T.( (x - y, x' - y') )
  -- Product rule.
  let (x,x') * (y,y') = T.( (x * y, x' * y + x * y') )
  -- Quotient rule.
  let (x,x') / (y,y') = T.( (x / y, (x' * y - x * y') / y ** (i32 2)) )
  -- General power rule:
  --   d(x ** y) = y * x ** (y - 1) * x'  +  x ** y * log x * y'
  let (x,x') ** (y,y') =
    T.( let z = x ** y
        let dx = y * x ** (y - i32 1) * x'
        -- When the exponent is constant (y' = 0) the log term vanishes.
        -- Skip it explicitly, because for a non-positive base `log x` is
        -- NaN (or -inf), and NaN * 0 would poison the tangent.
        let dy = if y' == i32 0 then i32 0 else z * log x * y'
        in (z, dx + dy) )

  -- Comparisons look only at the primal component.
  let (x,_) == (y,_) = T.( x == y )
  let (x,_) < (y,_) = T.( x < y )
  let (x,_) > (y,_) = T.( x > y )
  let (x,_) <= (y,_) = T.( x <= y )
  let (x,_) >= (y,_) = T.( x >= y )
  let (x,_) != (y,_) = T.( x != y )

  let negate (x,x') = T.( (negate x, negate x') )
  let max x y = if x >= y then x else y
  let min x y = if x <= y then x else y
  -- d|x| = sgn x * x'. At x = 0 this uses sgn 0 = 0.
  let abs (x,x') = (T.abs x, T.(sgn x * x'))
  -- sgn is piecewise constant, so its derivative is zero.
  let sgn (x,_) = (T.sgn x, T.i32 0)
  let highest = inject T.highest
  let lowest = inject T.lowest

  -- | Returns zero on empty input.
  let sum = reduce (+) (inject (T.i32 0))
  -- | Returns one on empty input.
  let product = reduce (*) (inject (T.i32 1))
  -- | Returns `lowest` on empty input.
  let maximum = reduce max lowest
  -- | Returns `highest` on empty input.
  let minimum = reduce min highest

  -- val from_fraction: i32 -> i32 -> t
  let from_fraction x y = inject (T.from_fraction x y)
  -- Conversions to primitive types discard the tangent.
  let to_i32 (x,_) = T.to_i32 x
  let to_i64 (x,_) = T.to_i64 x
  let to_f64 (x,_) = T.to_f64 x

  let sqrt (x,x') = T.( (sqrt x, x' / (i32 2 * sqrt x)) )
  let exp (x,x') = T.( (exp x, x' * exp x) )
  let cos (x,x') = T.( (cos x, negate x' * sin x) )
  let sin (x,x') = T.( (sin x, x' * cos x) )
  let tan x = sin x / cos x
  let asin (x,x') = T.( (asin x, x' / (sqrt (i32 1 - x ** i32 2))) )
  let acos (x,x') = T.( (acos x, negate x' / (sqrt (i32 1 - x ** i32 2))) )
  let atan (x,x') = T.( (atan x, x' / (i32 1 + x ** i32 2)) )
  -- `atan2 x y` follows C's argument order, i.e. it is atan (x / y), so
  --   d(atan2 x y) = (y * x' - x * y') / (x**2 + y**2)
  let atan2 (x,x') (y,y') =
    T.( (atan2 x y, (y * x' - x * y') / (x * x + y * y)) )

  let log (x,x') = T.( (log x, x' / x) )
  -- d(log_b x) = x' / (x * log b).
  let log2 (x,x') = T.( (log2 x, x' / (x * log (i32 2))) )
  let log10 (x,x') = T.( (log10 x, x' / (x * log (i32 10))) )

  -- Rounding functions are piecewise constant: their derivative is zero
  -- (almost everywhere).
  let ceil (x,_) = (T.ceil x, T.i32 0)
  let floor (x,_) = (T.floor x, T.i32 0)
  let trunc (x,_) = (T.trunc x, T.i32 0)
  let round (x,_) = (T.round x, T.i32 0)

  let isinf (x,_) = T.isinf x
  let isnan (x,_) = T.isnan x
  let inf = inject T.inf
  let nan = inject T.nan
  let pi = inject T.pi
  let e = inject T.e

  let get_deriv (_,x') = x'
  let set_deriv (x,_) x' = (x,x')
  let make_dual x x' = (x,x')
}
import "lib/github.com/diku-dk/linalg/linalg"

-- Dual numbers over f64, plus the linalg routines instantiated on those
-- dual numbers, so that e.g. `l.dotprod` operates on (value, tangent)
-- pairs.
module d = FwdAD(f64)
module l = mk_linalg(d)
-- | Gradient of `dotprod v1 v2` with respect to `v1`, computed with
-- forward-mode AD. For each coordinate `i`, the tangent of `v1[i]` is
-- seeded with 1 (all other tangents 0), `v2` is treated as a constant,
-- and the tangent of the resulting dot product is element `i` of the
-- output.
entry paper_dotprod [n] (v1: [n]f64) (v2: [n]f64): []f64 =
  -- v2 does not depend on i, so lift it to dual numbers once.
  let v2_const = map d.inject v2
  let unit_vector i = tabulate n (\j -> if j == i then 1 else 0)
  let directional i = l.dotprod (map2 d.make_dual v1 (unit_vector i)) v2_const
  in tabulate n directional |> map d.get_deriv
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment