Skip to content

Instantly share code, notes, and snippets.

@adamse
Created May 18, 2023 14:25
Show Gist options
  • Save adamse/59bb016d1f101ae2e020894d94091300 to your computer and use it in GitHub Desktop.
Save adamse/59bb016d1f101ae2e020894d94091300 to your computer and use it in GitHub Desktop.
#!/usr/bin/env cabal
{- cabal:
build-depends: base, primitive
-}
{-# language BangPatterns #-}
import Control.Monad (mapM, when, zipWithM_)
import Data.IORef (IORef, atomicModifyIORef, atomicModifyIORef', newIORef, readIORef, writeIORef)
import System.IO.Unsafe (unsafePerformIO)

import Data.Primitive.PrimArray (MutablePrimArray, newPrimArray, readPrimArray, replicatePrimArray, thawPrimArray, writePrimArray)
-- | Example function of three variables: @x * y + exp z@.
--
-- Its partial derivatives, for reference:
--
-- * @dfun/dx = y@
-- * @dfun/dy = x@
-- * @dfun/dz = exp z@
fun :: Floating a => a -> a -> a -> a
fun x y z = scaled + grown
  where
    scaled = x * y
    grown = exp z
-- | The same computation as 'fun', taking its three arguments as a list and
-- naming every intermediate result (one name per elementary operation).
--
-- Only defined for lists of exactly three elements.
fun1 :: Floating a => [a] -> a
fun1 [x, y, z] = xy + expz
  where
    xy = x * y
    expz = exp z
-- chain rule
-- f = f(a, b, c)
-- Df = a' * D_a(f) + b' * D_b(f) + c' * D_c(f)
-- D(f + g) = f' * D_f(f + g) + g' * D_g(f + g)
-- = f' + g'
-- D_f(f + g) = 1 + 0 = 1
-- D_g(f + g) = 0 + 1 = 1
-- D(f * g) = f' * D_f(f * g) + g' * D_g(f * g)
-- = f'*g + g'*f
-- D_f(f * g) = g
-- D_g(f * g) = f
-- D(exp f) = f' * exp f
-- D_f(exp f) = exp f
-- * reverse mode ad
--
-- https://web.archive.org/web/20161224122037/https://justindomke.wordpress.com/2009/03/24/a-simple-explanation-of-reverse-mode-automatic-differentiation/
-- | Hand-written reverse-mode differentiation of 'fun1'.
--
-- Returns the primal result together with the gradient
-- @[dres/dx, dres/dy, dres/dz]@. The forward pass names every intermediate
-- value; the backward pass then walks from the result towards the inputs,
-- multiplying the adjoint of each expression by the local partial derivative
-- with respect to the input being visited.
--
-- Open question this example is meant to illustrate: what would have to be
-- recorded while evaluating e.g. @expz@ (namely that @z@ was used, and with
-- which local partial) so that the backward pass can be derived automatically?
--
-- Only defined for lists of exactly three elements.
funbkw :: Floating a => [a] -> (a, [a])
funbkw [x, y, z] = (res, [adjX, adjY, adjZ])
  where
    -- forward pass
    xy = x * y
    expz = exp z
    res = xy + expz

    -- backward pass, seeded with dres/dres
    adjRes = 1

    -- xy is an input of: res (local partial 1)
    adjXy = adjRes * 1
    -- expz is an input of: res (local partial 1)
    adjExpz = adjRes * 1

    -- x is an input of: xy (local partial y)
    adjX = adjXy * y
    -- y is an input of: xy (local partial x)
    adjY = adjXy * x
    -- z is an input of: expz (local partial exp z)
    adjZ = adjExpz * expz
-- | Example function of one variable: @2 * x ^ 2@.
gun :: Num a => a -> a
gun = (2 *) . (^ 2)
-- | The same computation as 'gun', with the intermediate square named.
gun1 :: Num a => a -> a
gun1 x = 2 * x2
  where
    x2 = x ^ 2
-- | Hand-written reverse-mode differentiation of 'gun1'.
--
-- Returns @(res, dres/dx)@; multiplied out, the derivative is @2 * 2 * x@.
gunbkw :: Num a => a -> (a, a)
gunbkw x = (res, adjX)
  where
    -- forward pass
    x2 = x ^ 2
    res = 2 * x2

    -- backward pass, seeded with dres/dres
    adjRes = 1
    -- res = 2 * x2, local partial 2
    adjX2 = adjRes * 2
    -- x2 = x ^ 2, local partial 2 * x
    adjX = adjX2 * (2 * x)
{-
the tape strategy works like this: every operation gets a number, for example for
exp (Rev var1 val1)
- we generate a new var2
- we write to the tape Unary var1 (pexp_pval1 = exp val1), to note the partial derivative
- return (Rev var2 (exp val1))
then on the way out, when we're calculating the grad,
each expression (var) adds its contribution to its children
-}
-- toy implementation, not thread safe

-- | Identifier of a recorded expression.
type Var = Int

-- | One tape entry: which variables an expression was computed from, and the
-- local partial derivative of the expression with respect to each of them.
data Cell
  = Nullary
    -- ^ an expression with no recorded inputs
  | Unary
      !Var     -- ^ the argument's variable
      !Double  -- ^ contribution to that variable from this unary computation
  | Binary
      !Var     -- ^ the left-hand side's variable
      !Double  -- ^ contribution to the left-hand side from this binary computation
      !Var     -- ^ the right-hand side's variable
      !Double  -- ^ contribution to the right-hand side from this binary computation
  deriving (Show)
-- | Recording state: the next unused 'Var', and the tape recorded so far.
data State
  = State
      !Var    -- ^ the variable the next tape entry will receive
      [Cell]  -- ^ the tape

-- | The empty recording: numbering starts at 0 and the tape is empty.
state0 :: State
state0 = State 0 []
-- this would be a thread local variable
{-# NOINLINE state #-}
-- | The single, process-wide recording state, initialised to 'state0'.
--
-- This is the usual "global 'IORef' via 'unsafePerformIO'" idiom; the
-- NOINLINE pragma is what keeps the reference from being created more than
-- once should the definition get inlined at its use sites.
state :: IORef State
state = unsafePerformIO (newIORef state0)
-- | Push a cell onto the global tape and return the 'Var' assigned to it.
--
-- Variables are handed out consecutively, and the new cell is consed onto the
-- front of the tape, so the head of the tape always belongs to the most
-- recently assigned variable.
--
-- Uses the strict 'atomicModifyIORef'' so that the updated 'State' is
-- evaluated right away. With the lazy variant the reference would hold an
-- unevaluated update until some caller happened to demand the returned 'Var'.
addTape :: Cell -> IO Var
addTape t = atomicModifyIORef' state record
  where
    record (State current ts) = (State (succ current) (t : ts), current)
-- | A number taking part in a differentiable computation.
data Rev
  = Const !Double
    -- ^ a plain value with no tape variable attached
  | Rev !Var !Double
    -- ^ a tracked value: its tape variable and its primal value
-- | Lift a plain number into 'Rev' as a constant; nothing is written to the tape.
nilop
  :: Double
  -> Rev
nilop = Const
-- | Lift a unary operation and its derivative to 'Rev'.
--
-- A constant argument stays a constant. For a tracked argument a 'Unary' cell
-- holding the local partial derivative is written to the tape, and the result
-- carries the variable assigned to that cell.
--
-- The tape write happens through 'unsafePerformIO'; because the 'Var' field of
-- 'Rev' is strict, it runs when the result is evaluated.
unop
  :: (Double -> Double) -- ^ the operation, @op x@
  -> (Double -> Double) -- ^ dop/dx (partial derivative)
  -> Rev
  -> Rev
unop op dop_dx arg = case arg of
  Const c -> Const (op c)
  Rev src v -> Rev (unsafePerformIO (addTape (Unary src (dop_dx v)))) (op v)
-- | Lift a binary operation and its two partial derivatives to 'Rev'.
--
-- * Two constants give a constant.
-- * One tracked and one constant operand reduce to 'unop' on the tracked
--   operand, with the constant fixed.
-- * Two tracked operands write a 'Binary' cell holding both local partials
--   (via 'unsafePerformIO', run when the result is evaluated).
--
-- The left operand is inspected before the right one.
binop
  :: (Double -> Double -> Double) -- ^ the operation, @op x y@
  -> (Double -> Double -> Double) -- ^ dop/dx (partial derivative)
  -> (Double -> Double -> Double) -- ^ dop/dy (partial derivative)
  -> Rev
  -> Rev
  -> Rev
binop op dop_dx dop_dy lhs rhs = case (lhs, rhs) of
  (Const a, Const b) -> Const (op a b)
  (Rev _ _, Const b) -> unop (`op` b) (`dop_dx` b) lhs
  (Const a, Rev _ _) -> unop (op a) (dop_dy a) rhs
  (Rev va a, Rev vb b) ->
    Rev (unsafePerformIO (addTape (Binary va (dop_dx a b) vb (dop_dy a b)))) (op a b)
-- | Arithmetic on 'Rev': each method pairs the primal operation with its
-- partial derivative(s) and delegates to 'unop' / 'binop'.
instance Num Rev where
  (+) = binop (+) (\_ _ -> 1) (\_ _ -> 1)
  (-) = binop (-) (\_ _ -> 1) (\_ _ -> negate 1)
  (*) = binop (*) (\_ y -> y) (\x _ -> x)
  -- d|x|/dx = signum x
  abs = unop abs signum
  -- signum is piecewise constant, so its derivative is 0 (taken to be 0 at
  -- the jump at the origin as well). Previously this method called 'error',
  -- which made every use of 'signum' on a 'Rev' crash.
  signum = unop signum (const 0)
  fromInteger = nilop . fromInteger
-- | Division on 'Rev', built from multiplication and 'recip'.
instance Fractional Rev where
  fromRational = nilop . fromRational
  -- d(1/x)/dx = -1 / x^2
  recip = unop recip (negate . recip . square)
    where
      square v = v * v
  numer / denom = numer * recip denom
-- | Elementary functions on 'Rev': every method pairs the primal function
-- with its derivative. Methods not listed here fall back to the class
-- defaults.
instance Floating Rev where
  pi = nilop pi

  -- exponentials, logarithms, powers
  exp = unop exp exp
  log = unop log recip
  sqrt = unop sqrt (recip . (2 *) . sqrt)
  (**) = binop (**) dBase dExponent
    where
      -- d(b ** e)/db = e * b ** (e - 1)
      dBase b e = e * (b ** (e - 1))
      -- d(b ** e)/de = log b * b ** e
      dExponent b e = log b * (b ** e)

  -- trigonometric
  sin = unop sin cos
  cos = unop cos (negate . sin)
  asin = unop asin (recip . sqrt . (\v -> 1 - v * v))
  acos = unop acos (negate . recip . sqrt . (\v -> 1 - v * v))
  atan = unop atan (recip . (\v -> 1 + v * v))

  -- hyperbolic
  sinh = unop sinh cosh
  cosh = unop cosh sinh
  asinh = unop asinh (recip . sqrt . (\v -> v * v + 1))
  acosh = unop acosh (recip . sqrt . (\v -> v * v - 1))
  atanh = unop atanh (recip . (\v -> 1 - v * v))
-- | Do nothing when the condition holds; otherwise call 'error' with the
-- given message.
assert :: Applicative f => String -> Bool -> f ()
assert m bool
  | bool = pure ()
  | otherwise = error m
{-# noinline grad1 #-}
-- | Single-variable convenience wrapper around 'grad'.
--
-- Evaluates @op@ at @x@ and returns @(op x, d(op x)/dx)@ by running 'grad'
-- on a one-element list.
grad1
  :: (Rev -> Rev)
  -> Double
  -> IO (Double, Double)
grad1 op x = do
  (primal, [slope]) <- grad onSingleton [x]
  pure (primal, slope)
  where
    -- 'grad' hands back a list of the same shape it was given: one element
    onSingleton [v] = op v
{-# noinline grad #-}
-- | Evaluate @op@ at the point @x@ and compute its gradient there by
-- reverse-mode automatic differentiation.
--
-- Returns the primal result together with a container of the same shape as
-- @x@ holding the partial derivative of the result with respect to each input.
--
-- Works by resetting the global tape, registering one 'Nullary' entry per
-- input, evaluating @op@ (which records every operation on the tape), and then
-- sweeping the tape from the output back to the first entry, pushing each
-- expression's adjoint into the expressions it was computed from.
--
-- Not thread safe and not reentrant: the tape is a single global.
--
-- Compared with the earlier version this also handles:
--
-- * a result that is a 'Const' (the function does not depend on its inputs,
--   or there are no inputs): the gradient is all zeros instead of a
--   pattern-match failure;
-- * a result that is not the newest tape entry (e.g. @\\[a, _] -> a@): the
--   sweep simply starts at the output's entry instead of aborting.
grad
  :: Traversable f
  => (f Rev -> Rev)
  -> f Double
  -> IO (Double, f Double)
grad op x = do
  -- reset the state
  writeIORef state state0
  -- vars for input
  inputs <- mapM (\v -> do var <- addTape Nullary; pure (Rev var v)) x
  -- compute the primal; the result must be evaluated before the tape is read,
  -- since evaluating it is what performs the tape writes
  case op inputs of
    -- nothing tracked reached the output, so every partial derivative is 0
    Const out -> pure (out, 0 <$ x)
    Rev varout out -> do
      -- n is the number of expressions recorded
      State n tape <- readIORef state
      -- array to keep the partial derivatives
      arr <- thawPrimArray (replicatePrimArray n 0.0) 0 n
      -- dout/dout = 1
      writePrimArray arr varout 1.0
      let add var val =
            readPrimArray arr var >>= \val1 -> writePrimArray arr var (val + val1)
          -- push the contribution of this var into its children
          update varz cell = do
            dres_dz <- readPrimArray arr varz
            case cell of
              Nullary -> pure ()
              -- z = op x
              -- write dres_dx into arr[varx]
              Unary varx pz_px ->
                add varx (dres_dz * pz_px)
              -- z = op x y
              -- write dres_dx into arr[varx]
              -- write dres_dy into arr[vary]
              Binary varx pz_px vary pz_py ->
                add varx (dres_dz * pz_px) >>
                add vary (dres_dz * pz_py)
      -- The tape is newest-first, so its head belongs to var n - 1. Entries
      -- newer than the output cannot influence it; skip them rather than
      -- multiply their partials by a zero adjoint (0 * Infinity would put a
      -- NaN into the gradient).
      let reachable = drop (n - 1 - varout) tape
      -- for every expression varout..0
      zipWithM_ update [varout, varout - 1 .. 0] reachable
      dres_dx <- mapM (\(Rev varx _) -> readPrimArray arr varx) inputs
      pure (out, dres_dx)
-- | Demo: differentiate a few functions automatically and print each result
-- next to the hand-derived version for comparison.
main :: IO ()
main = do
  -- exp at two points, one variable at a time
  mapM_ (\point -> grad1 exp point >>= print) [12, 10]
  -- the same two points as a single two-input function
  print =<< grad (sum . map exp) [12, 10]
  putStrLn "Manual 'ad'"
  print (funbkw [1, 2, 3])
  putStrLn "Automatic 'ad'"
  print =<< grad fun1 [1, 2, 3]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment