Skip to content

Instantly share code, notes, and snippets.

@bitonic
Last active March 22, 2021 12:20
Show Gist options
  • Save bitonic/640f789a6d879c8186ca71b237e633fa to your computer and use it in GitHub Desktop.
Save bitonic/640f789a6d879c8186ca71b237e633fa to your computer and use it in GitHub Desktop.
Simple reverse AD
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE BangPatterns #-}
{-# OPTIONS_GHC -Wall #-}
import Data.IORef
import Data.Reflection
import qualified Data.Vector as V
import Data.Proxy
import System.IO.Unsafe (unsafePerformIO)
import Control.Exception (evaluate)
import qualified Data.Vector.Unboxed.Mutable as VUM
import Control.Monad
import Data.Foldable
import Control.DeepSeq
-- | The one-argument primitives that can be recorded on the tape.
data UnaryOp
  = Sin
  | Abs
  | Signum
  | Exp
  | Log
  deriving (Eq, Show)
-- | Map a 'UnaryOp' tag to the numeric function it stands for.
evalUnaryOp :: Floating a => UnaryOp -> a -> a
evalUnaryOp Sin = sin
evalUnaryOp Abs = abs
evalUnaryOp Signum = signum
evalUnaryOp Exp = exp
evalUnaryOp Log = log
-- | Local derivative of a unary primitive: @unaryGradientWeight op x@ is the
-- derivative of @evalUnaryOp op@ evaluated at the operand value @x@.
--
-- The derivative of 'abs' is taken to be @signum x@. Unlike the quotient
-- @abs x / x@ this stays finite at @x = 0@ (where it yields the subgradient
-- 0) and at infinite operands, instead of producing NaN.
unaryGradientWeight :: Floating a => UnaryOp -> a -> a
unaryGradientWeight op x = case op of
  Sin -> cos x
  Abs -> signum x
  -- signum is piecewise constant, so its derivative is taken as zero.
  Signum -> 0
  Exp -> exp x
  Log -> 1 / x
-- | The two-argument primitives that can be recorded on the tape.
data BinaryOp
  = Plus
  | Times
  | Divide
  deriving (Eq, Show)
-- | Map a 'BinaryOp' tag to the arithmetic operator it stands for.
evalBinaryOp :: Fractional a => BinaryOp -> a -> a -> a
evalBinaryOp Plus = (+)
evalBinaryOp Times = (*)
evalBinaryOp Divide = (/)
-- | Local partial derivatives of a binary primitive at operands @x@ and @y@:
-- the first component is the derivative with respect to @x@, the second the
-- derivative with respect to @y@.
binaryGradientWeights :: Floating a => BinaryOp -> a -> a -> (a, a)
binaryGradientWeights Plus _ _ = (1, 1)
binaryGradientWeights Times x y = (y, x)
binaryGradientWeights Divide x y = (1 / y, negate (x / (y * y)))
-- | The recorded operations, kept as a linked list in which every
-- constructor carries the previously recorded cells as its tail (so the most
-- recent operation is the outermost constructor).
--
-- 'Unary' and 'Binary' store, for each operand, its tape index and the value
-- it had when the operation was recorded.
data Cells a
  = Nil
  | Lift
      { _liftTail :: Cells a
      }
  | Unary
      { _unaryIndex :: {-# UNPACK #-} !Int
      , _unaryValue :: a
      , _unaryOp :: {-# UNPACK #-} !UnaryOp
      , _unaryTail :: Cells a
      }
  | Binary
      { _binaryIndex1 :: {-# UNPACK #-} !Int
      , _binaryValue1 :: a
      , _binaryIndex2 :: {-# UNPACK #-} !Int
      , _binaryValue2 :: a
      , _binaryOp :: {-# UNPACK #-} !BinaryOp
      , _binaryTail :: Cells a
      }
  deriving (Eq, Show)
-- | Current state of the tape: the index that the next recorded cell will
-- receive, together with the cells recorded so far.
data Head a = Head
  { headCounter :: {-# UNPACK #-} !Int
  , headCells :: Cells a
  }
  deriving (Eq, Show)
-- | A mutable reference to the tape state that operations append to.
newtype Tape a = Tape
  { unTape :: IORef (Head a)
  }
-- | Create an empty tape whose index counter starts at @numVars@, leaving
-- the indices below @numVars@ free for the input variables.
newTape :: Int -> IO (Tape a)
newTape numVars = do
  ref <- newIORef (Head { headCounter = numVars, headCells = Nil })
  pure (Tape ref)
-- | A value paired with the tape index under which it was recorded. The
-- phantom parameter @s@ identifies the tape the index belongs to.
data Reverse s a = Reverse
  { _index :: {-# UNPACK #-} !Int
  , _value :: a
  }
  deriving (Eq, Show)
-- | Append a cell to the tape reflected by @p@ and return the index assigned
-- to it.
--
-- @cell@ receives the cells recorded so far and must return them with the
-- new cell in front. The index handed back is the counter value before the
-- update; the stored counter is one greater.
--
-- The tape mutation is hidden behind 'unsafePerformIO', so it happens when
-- the returned index is first demanded.
newReflect :: Reifies s (Tape a) => Proxy s -> (Cells a -> Cells a) -> Int
newReflect p cell = unsafePerformIO (atomicModifyIORef tapeRef push)
  where
    tapeRef = unTape (reflect p)

    push (Head ix older) =
      let next = ix + 1
          newHead = Head next (cell older)
      in newHead `seq` next `seq` (newHead, ix)
-- | Record a constant on the tape. The value is forced to weak head normal
-- form before the cell is allocated.
lift :: Reifies s (Tape a) => Proxy s -> a -> Reverse s a
lift p x = Reverse (newReflect p Lift) $! x
-- | Apply a binary primitive to two taped values: the result is computed
-- eagerly and a 'Binary' cell remembering both operands (index and value) is
-- recorded for it.
binary ::
  forall s a.
  (Floating a, Reifies s (Tape a))
  => BinaryOp
  -> Reverse s a
  -> Reverse s a
  -> Reverse s a
binary op (Reverse lhsIx lhs) (Reverse rhsIx rhs) =
  let !out = evalBinaryOp op lhs rhs
  in Reverse (newReflect (Proxy @s) (Binary lhsIx lhs rhsIx rhs op)) out
-- | Apply a unary primitive to a taped value: the result is computed eagerly
-- and a 'Unary' cell remembering the operand (index and value) is recorded
-- for it.
unary ::
  forall s a.
  (Floating a, Reifies s (Tape a))
  => UnaryOp
  -> Reverse s a
  -> Reverse s a
unary op (Reverse srcIx src) =
  let !out = evalUnaryOp op src
  in Reverse (newReflect (Proxy @s) (Unary srcIx src op)) out
-- | Arithmetic on taped values. Every operation records a cell; literals are
-- recorded as constants, and negation is expressed as multiplication by the
-- constant @-1@.
instance (Reifies s (Tape a), Floating a) => Num (Reverse s a) where
  fromInteger n = lift (Proxy @s) (fromInteger n)
  x + y = binary Plus x y
  x * y = binary Times x y
  abs x = unary Abs x
  signum x = unary Signum x
  negate x = binary Times x (lift (Proxy @s) (-1))
-- | Division on taped values; rational literals are recorded as constants.
instance (Reifies s (Tape a), Floating a) => Fractional (Reverse s a) where
  fromRational r = lift (Proxy @s) (fromRational r)
  x / y = binary Divide x y
-- | Transcendental functions on taped values.
--
-- 'sin', 'exp' and 'log' are tape primitives. 'cos', 'sinh' and 'cosh' are
-- expressed through those primitives, which in turn makes the class defaults
-- for 'tan', 'tanh', '**', 'sqrt' and 'logBase' usable.
--
-- The inverse trigonometric and inverse hyperbolic functions are still not
-- implemented; supporting them needs additional 'UnaryOp' primitives.
instance (Reifies s (Tape a), Floating a) => Floating (Reverse s a) where
  pi = lift (Proxy @s) pi
  sin = unary Sin
  exp = unary Exp
  log = unary Log
  -- cos x = sin (x + pi/2); the phase shift is rounded to the precision of
  -- @a@, so results can differ from a native 'cos' in the last few bits.
  cos x = sin (x + lift (Proxy @s) (pi / 2))
  sinh x = (exp x - exp (negate x)) / 2
  cosh x = (exp x + exp (negate x)) / 2
-- | A single tape entry without the list tail, as stored in the vector that
-- 'grad' sweeps over. 'Var'' fills the slots of the input variables; the
-- other constructors mirror those of 'Cells'.
data Cell a
  = Var'
  | Lift'
  | Unary' {-# UNPACK #-} !Int a {-# UNPACK #-} !UnaryOp
  | Binary' {-# UNPACK #-} !Int a {-# UNPACK #-} !Int a {-# UNPACK #-} !BinaryOp
  deriving (Eq, Show)
-- | Given a function @f : R^n -> R^m@ and a point @args@, return the value
-- of @f@ at that point together with its Jacobian, computed by reverse-mode
-- differentiation.
--
-- The first component holds the @m@ outputs. The second holds one Jacobian
-- row per output, each with one entry per element of @args@.
--
-- Calls 'error' when @args@ is empty.
grad ::
  forall a.
  (VUM.Unbox a, Floating a, NFData a)
  => (forall s. Reifies s (Tape a) => [Reverse s a] -> [Reverse s a])
  -> [a]
  -> ([a], [[a]])
grad f args = unsafePerformIO $ do
  when (null args) $
    error "No variables provided"
  tape <- newTape numVars
  -- Fully evaluate the outputs so that every operation of @f@ is on the tape
  -- before the tape is read back. The inputs take indices 0 .. numVars - 1.
  (result, resultIndices) <-
    evaluate
      (force
        (reify tape (\(_ :: Proxy s) ->
          unzip
            (map
              (\Reverse{..} -> (_value, _index))
              (f (zipWith (Reverse @s) [0..] args))))))
  head' <- readIORef (unTape tape)
  let
    -- Tape in recording order, addressable by index. It does not depend on
    -- the output being differentiated, so it is built once for all rows.
    cells :: V.Vector (Cell a)
    cells = V.fromList (replicate numVars Var' ++ reverse (go (headCells head')))

    -- One row of the Jacobian: the derivative of the output stored at
    -- @resultIndex@ with respect to every input.
    getGradients :: Int -> IO [a]
    getGradients resultIndex = do
      gradients :: VUM.IOVector a <- VUM.replicate (V.length cells) 0
      -- propagate backwards, fixing the result to gradient 1
      VUM.write gradients resultIndex 1
      -- Start at the output itself: cells recorded after it cannot feed into
      -- it, and multiplying their zero adjoint by an infinite or NaN local
      -- derivative would otherwise contaminate the row with NaN. Stop above
      -- the input slots, which have nothing to propagate.
      for_ [resultIndex, resultIndex - 1 .. numVars] $ \cellIx -> do
        cellGrad <- VUM.read gradients cellIx
        case cells V.! cellIx of
          Lift' -> return ()
          Var' -> return ()
          Unary' ix x op ->
            VUM.modify gradients ((cellGrad * unaryGradientWeight op x) +) ix
          Binary' ixx x ixy y op -> do
            let (gradwx, gradwy) = binaryGradientWeights op x y
            VUM.modify gradients ((cellGrad * gradwx) +) ixx
            VUM.modify gradients ((cellGrad * gradwy) +) ixy
      mapM (VUM.read gradients) [0 .. numVars - 1]
  jacobian <- mapM getGradients resultIndices
  return (result, jacobian)
  where
    numVars = length args

    -- Flatten the linked tape into a list of cells, most recent first.
    go = \case
      Nil -> []
      Lift cells -> Lift' : go cells
      Unary{..} -> Unary' _unaryIndex _unaryValue _unaryOp : go _unaryTail
      Binary{..} -> Binary' _binaryIndex1 _binaryValue1 _binaryIndex2 _binaryValue2 _binaryOp : go _binaryTail
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment