Created
May 18, 2023 14:25
-
-
Save adamse/59bb016d1f101ae2e020894d94091300 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env cabal | |
{- cabal: | |
build-depends: base, primitive | |
-} | |
{-# language BangPatterns #-} | |
import Data.IORef (IORef, newIORef, writeIORef, readIORef, atomicModifyIORef) | |
import System.IO.Unsafe (unsafePerformIO) | |
import Data.Primitive.PrimArray (MutablePrimArray, newPrimArray, writePrimArray, readPrimArray, replicatePrimArray, thawPrimArray) | |
import Control.Monad (zipWithM_, when, mapM) | |
-- Running example: a function of three variables, x * y + exp z.
--
-- Its partial derivatives, for reference:
--   dfun/dx = y
--   dfun/dy = x
--   dfun/dz = exp z
fun :: Floating a => a -> a -> a -> a
fun x y z = xy + expz
  where
    xy = x * y
    expz = exp z
-- The same computation as 'fun', taking its three arguments as a list and
-- naming every intermediate result. Only defined for lists of exactly
-- three elements.
fun1 :: Floating a => [a] -> a
fun1 [x, y, z] = res
  where
    xy = x * y
    expz = exp z
    res = xy + expz
-- chain rule | |
-- f = f(a, b, c) | |
-- Df = a' * D_a(f) + b' * D_b(f) + c' * D_c(f) | |
-- D(f + g) = f' * D_f(f + g) + g' * D_g(f + g) | |
-- = f' + g' | |
-- D_f(f + g) = 1 + 0 = 1 | |
-- D_g(f + g) = 0 + 1 = 1 | |
-- D(f * g) = f' * D_f(f * g) + g' * D_g(f * g) | |
-- = f'*g + g'*f | |
-- D_f(f * g) = g | |
-- D_g(f * g) = f | |
-- D(exp f) = f' * exp f | |
-- D_f(exp f) = exp f | |
-- * reverse mode ad | |
-- | |
-- https://web.archive.org/web/20161224122037/https://justindomke.wordpress.com/2009/03/24/a-simple-explanation-of-reverse-mode-automatic-differentiation/ | |
-- Hand-written reverse-mode differentiation of x * y + exp z.
-- Returns the value together with [dres/dx, dres/dy, dres/dz].
-- Only defined for lists of exactly three elements.
funbkw :: Floating a => [a] -> (a, [a])
funbkw [x, y, z] = (res, [dres_dx, dres_dy, dres_dz])
  where
    -- forward pass: name every intermediate expression
    xy = x * y
    expz = exp z
    res = xy + expz

    -- backward pass: start from dres/dres = 1 and, for every expression,
    -- multiply the derivative of the expression that consumes it by the
    -- local partial derivative
    dres_dres = 1
    -- xy is an input of: res (the partial of a sum w.r.t. a summand is 1)
    dres_dxy = dres_dres * pres_pxy
    pres_pxy = 1
    -- expz is an input of: res
    dres_dexpz = dres_dres * pres_pexpz
    pres_pexpz = 1
    -- x is an input of: xy
    dres_dx = dres_dxy * pxy_px
    pxy_px = y
    -- y is an input of: xy
    dres_dy = dres_dxy * pxy_py
    pxy_py = x
    -- z is an input of: expz (exp is its own derivative)
    dres_dz = dres_dexpz * pexpz_pz
    pexpz_pz = expz
    -- Open question: what information is needed to do this automatically?
    -- E.g. when evaluating expz, should something be recorded against z to
    -- note that it was used?
-- Second running example: a function of one variable, 2 * x^2.
gun :: Num a => a -> a
gun x = 2 * square
  where
    square = x ^ 2
-- The same computation as 'gun' with every intermediate result named.
gun1 :: Num a => a -> a
gun1 x = res
  where
    x2 = x ^ 2
    res = 2 * x2
-- Hand-written reverse-mode differentiation of 2 * x^2.
-- Returns the value together with dres/dx (which works out to 2 * 2 * x).
gunbkw :: Num a => a -> (a, a)
gunbkw x = (res, dres_dx)
  where
    -- forward pass
    x2 = x ^ 2
    res = 2 * x2

    -- backward pass, starting from dres/dres = 1
    dres_dres = 1
    -- x2 is an input of: res = 2 * x2
    dres_dx2 = dres_dres * pres_px2
    pres_px2 = 2
    -- x is an input of: x2 = x ^ 2
    dres_dx = dres_dx2 * px2_px
    px2_px = 2 * x
{- | |
the tape strategy works like this: every operation gets a number, for example for | |
exp (Rev var1 val1) | |
- we generate a new var2 | |
- we write to the tape Unary var1 (pexp_pval1 = exp val1), to note the partial derivative | |
- return (Rev var2 (exp val1)) | |
then, on the way back, when we're calculating the gradient,
for each expression (var) we add its contribution to its children
-} | |
-- Toy implementation, not thread safe.

-- A variable is identified by its index on the tape.
type Var = Int
-- One tape entry. Each entry records which variables an expression was
-- computed from, together with the partial derivative of the expression
-- with respect to each of them.
data Cell
  -- an expression with no operands
  = Nullary
  -- a unary computation: the operand and the partial derivative with
  -- respect to it
  | Unary !Var !Double
  -- a binary computation
  | Binary
      !Var !Double -- left operand and the partial derivative w.r.t. it
      !Var !Double -- right operand and the partial derivative w.r.t. it
  deriving (Show)
-- The recorder state: the next unused variable index, and the tape.
data State
  = State
      !Var -- next variable index to hand out
      [Cell] -- the cells recorded so far

-- The empty state: no variables handed out, nothing on the tape.
state0 = State 0 []
-- The global recorder state (an IORef holding a State). In a real
-- implementation this would be a thread local variable.
--
-- unsafePerformIO is used to create a top-level mutable variable; the
-- NOINLINE pragma keeps the definition from being duplicated at its use
-- sites, so every use refers to the same IORef.
{-# NOINLINE state #-}
state = unsafePerformIO (newIORef state0)
-- Record a cell on the global tape and return the variable index assigned
-- to it (type: Cell -> IO Var). Indices are handed out consecutively and
-- the newest cell is put at the head of the tape.
addTape t = do
  -- atomicModifyIORef is lazy: on its own it only stores an unevaluated
  -- application of 'push' in the IORef. Forcing the returned index here
  -- evaluates that application (and, through the bang patterns below, the
  -- new State), so each update is evaluated as it is made instead of being
  -- deferred until somebody looks at the state.
  !var <- atomicModifyIORef state push
  pure var
  where
    push (State current ts) =
      let !next = succ current
          !new = State next (t : ts)
      in (new, current)
-- The number type that differentiable functions are evaluated at.
data Rev
  -- a plain value that carries no tape variable
  = Const !Double
  -- a value together with the tape variable that identifies it
  | Rev !Var !Double
-- Lift a plain number to a Rev constant (nothing is written to the tape).
nilop
  :: Double
  -> Rev
nilop = Const
-- Lift a unary operation and its derivative to Rev.
--
-- A constant stays a constant. For a tracked value a fresh variable is
-- allocated and a Unary cell holding the operand's variable and the
-- derivative at the operand's value is written to the tape.
unop
  :: (Double -> Double) -- op x
  -> (Double -> Double) -- dop/dx (partial derivative)
  -> Rev -> Rev
unop op dop_dx arg = case arg of
  Const x -> Const (op x)
  Rev varx x ->
    let cell = Unary varx (dop_dx x)
        -- Recording on the tape is a side effect hidden behind a pure
        -- interface so that Rev can be an ordinary Num instance. The
        -- variable field of Rev is strict, so the cell is recorded when
        -- the result is evaluated.
        var = unsafePerformIO (addTape cell)
    in Rev var (op x)
-- Lift a binary operation and its two partial derivatives to Rev.
--
-- Two constants give a constant. If only one side is tracked, the
-- operation is treated as a unary operation of that side. If both sides
-- are tracked, a fresh variable is allocated and a Binary cell with both
-- partial derivatives is written to the tape.
binop
  :: (Double -> Double -> Double) -- op x y
  -> (Double -> Double -> Double) -- dop/dx (partial derivative)
  -> (Double -> Double -> Double) -- dop/dy (partial derivative)
  -> Rev -> Rev -> Rev
binop op dop_dx dop_dy lhs rhs = case (lhs, rhs) of
  (Const x, Const y) -> Const (op x y)
  (Rev _ _, Const y) -> unop (`op` y) (`dop_dx` y) lhs
  (Const x, Rev _ _) -> unop (op x) (dop_dy x) rhs
  (Rev varx x, Rev vary y) ->
    let cell = Binary varx (dop_dx x y) vary (dop_dy x y)
        -- Recording on the tape is a side effect hidden behind a pure
        -- interface; see 'unop'. The variable field of Rev is strict, so
        -- the cell is recorded when the result is evaluated.
        var = unsafePerformIO (addTape cell)
    in Rev var (op x y)
-- Arithmetic on Rev: every method is the Double operation paired with its
-- partial derivative(s).
instance Num Rev where
  -- d(x + y)/dx = 1, d(x + y)/dy = 1
  (+) = binop (+) (\_ _ -> 1) (\_ _ -> 1)
  -- d(x - y)/dx = 1, d(x - y)/dy = -1
  (-) = binop (-) (\_ _ -> 1) (\_ _ -> negate 1)
  -- d(x * y)/dx = y, d(x * y)/dy = x
  (*) = binop (*) (\_ y -> y) (\x _ -> x)
  -- d|x|/dx = signum x
  abs = unop abs signum
  -- signum is piecewise constant, so its derivative is 0 wherever it is
  -- defined (everywhere except at 0; 0 is reported there as well).
  -- This used to be 'error "signum"', which crashed any differentiated
  -- function that called signum.
  signum = unop signum (const 0)
  fromInteger = nilop . fromInteger
-- Division on Rev, built from the reciprocal.
instance Fractional Rev where
  -- d(1/x)/dx = -1 / x^2
  recip = unop recip (\v -> negate (recip (v * v)))
  -- division is multiplication by the reciprocal
  num / den = num * recip den
  fromRational = nilop . fromRational
-- Elementary functions on Rev: every method is the Double function paired
-- with its derivative.
instance Floating Rev where
  pi = nilop pi
  -- d(exp x)/dx = exp x
  exp = unop exp exp
  -- d(log x)/dx = 1 / x
  log = unop log recip
  -- d(sqrt x)/dx = 1 / (2 * sqrt x)
  sqrt = unop sqrt (recip . (2 *) . sqrt)
  -- d(b ** e)/db = e * b ** (e - 1), d(b ** e)/de = log b * b ** e
  (**) = binop (**) dBase dExponent
    where
      dBase b e = e * (b ** (e - 1))
      dExponent b e = log b * (b ** e)
  -- trigonometric functions
  sin = unop sin cos
  cos = unop cos (negate . sin)
  -- d(asin x)/dx = 1 / sqrt (1 - x^2)
  asin = unop asin (\v -> recip (sqrt (1 - v * v)))
  -- d(acos x)/dx = -1 / sqrt (1 - x^2)
  acos = unop acos (\v -> negate (recip (sqrt (1 - v * v))))
  -- d(atan x)/dx = 1 / (1 + x^2)
  atan = unop atan (\v -> recip (1 + v * v))
  -- hyperbolic functions
  sinh = unop sinh cosh
  cosh = unop cosh sinh
  -- d(asinh x)/dx = 1 / sqrt (x^2 + 1)
  asinh = unop asinh (\v -> recip (sqrt (v * v + 1)))
  -- d(acosh x)/dx = 1 / sqrt (x^2 - 1)
  acosh = unop acosh (\v -> recip (sqrt (v * v - 1)))
  -- d(atanh x)/dx = 1 / (1 - x^2)
  atanh = unop atanh (\v -> recip (1 - v * v))
-- Do nothing when the condition holds; otherwise fail with 'error' and the
-- given message.
assert m bool
  | bool = pure ()
  | otherwise = error m
-- Value and derivative of a function of one variable at the given point.
-- This is 'grad' applied to a one-element list.
{-# noinline grad1 #-}
grad1
  :: (Rev -> Rev)
  -> Double
  -> IO (Double, Double)
grad1 op x = do
  -- a one-element input gives back a one-element gradient
  (value, [slope]) <- grad (\[v] -> op v) [x]
  pure (value, slope)
-- Value and gradient of a function at the given point.
--
-- The function is run once on tracked inputs, which records every
-- operation on the global tape. The tape is then walked from the newest
-- cell to the oldest, pushing each expression's derivative into the
-- expressions it was computed from.
--
-- Returns the function value and, in the same container shape as the
-- input, the partial derivative of the value with respect to each input.
--
-- Uses (and resets) the global 'state', so it is not thread safe and must
-- not be called from inside the function being differentiated.
{-# noinline grad #-}
grad
  :: Traversable f
  => (f Rev -> Rev)
  -> f Double
  -> IO (Double, f Double)
grad op x = do
  -- reset the state
  writeIORef state state0
  -- vars for input
  inputs <- mapM (\v -> do var <- addTape Nullary; pure (Rev var v)) x
  -- compute the primal; matching on the result forces it, which is what
  -- actually runs the computation and fills the tape
  case op inputs of
    -- the result does not depend on any input (constant function, empty
    -- input, ...): the gradient is zero everywhere
    Const out -> pure (out, 0 <$ x)
    Rev varout out -> do
      -- n is the number of expressions on the tape
      State n tape <- readIORef state
      -- The result is usually the last variable allocated, but it does not
      -- have to be (the function may return an input or an earlier
      -- intermediate). It must be on this tape though, otherwise the array
      -- write below would be out of bounds.
      assert "grad: result variable is not on the tape" (0 <= varout && varout < n)
      -- compute the gradient
      -- array to keep the partial derivatives dout/dvar, initially all 0
      arr <- thawPrimArray (replicatePrimArray n 0.0) 0 n
      -- dout/dout = 1
      writePrimArray arr varout 1.0
      let add var val = readPrimArray arr var >>= \val1 -> writePrimArray arr var (val + val1)
      -- push the contribution of this var into its children
      let update varz cell = do
            dres_dz <- readPrimArray arr varz
            case cell of
              Nullary -> pure ()
              -- z = op x
              -- add dres_dz * dz/dx to arr[varx]
              Unary varx pz_px ->
                add varx (dres_dz * pz_px)
              -- z = op x y
              -- add dres_dz * dz/dx to arr[varx]
              -- add dres_dz * dz/dy to arr[vary]
              Binary varx pz_px vary pz_py ->
                add varx (dres_dz * pz_px) >>
                add vary (dres_dz * pz_py)
      -- for every expression n-1..0; the head of the tape is the newest
      -- cell, so it pairs with index n - 1
      zipWithM_ update [n - 1, n - 2 .. 0] tape
      dres_dx <- mapM (\(Rev varx _) -> readPrimArray arr varx) inputs
      pure (out, dres_dx)
-- Demo: differentiate exp at two points and a sum of exps, then compare
-- the hand-written gradient of the running example with the automatic one.
main :: IO ()
main = do
  print =<< grad1 exp 12
  print =<< grad1 exp 10
  print =<< grad (sum . map exp) [12, 10]
  putStrLn "Manual 'ad'"
  print (funbkw [1, 2, 3])
  putStrLn "Automatic 'ad'"
  print =<< grad fun1 [1, 2, 3]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment