Skip to content

Instantly share code, notes, and snippets.

@bitonic
Created March 20, 2021 11:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bitonic/035d0a4b7d69dbfbda7cf9954bd90b0d to your computer and use it in GitHub Desktop.
Save bitonic/035d0a4b7d69dbfbda7cf9954bd90b0d to your computer and use it in GitHub Desktop.
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}
import Data.List (transpose)
-- | A dual number: a value bundled with a derivative that is propagated
-- alongside it. Both fields are strict.
data Forward a = Forward
  { _value :: !a -- ^ The primal value.
  , _grad :: !a -- ^ The derivative carried with the value.
  }
  deriving (Eq, Show)
-- | Embed a constant: the value is kept as-is and its derivative is zero.
lift :: Num a => a -> Forward a
lift c = Forward c 0
-- | Introduce the variable being differentiated against: the value is kept
-- as-is and its derivative is seeded with one.
var :: Num a => a -> Forward a
var x = Forward { _value = x, _grad = 1 }
-- | Apply the chain rule for a unary function.
--
-- The first argument is the function itself and the second its derivative;
-- the result carries @f x@ as its value and @f' x * d@ as its derivative,
-- where @x@ and @d@ are the value and derivative of the input.
chain :: Num a => (a -> a) -> (a -> a) -> Forward a -> Forward a
chain f f' (Forward x d) = Forward { _value = y, _grad = dy }
  where
    y = f x
    dy = f' x * d
-- | Arithmetic on dual numbers, propagating derivatives with the sum and
-- product rules.
--
-- Only @Num a@ is required: the derivative of 'abs' is expressed with
-- 'signum', so no division is needed.
instance Num a => Num (Forward a) where
  -- Integer literals are constants, so their derivative is zero.
  fromInteger n = Forward (fromInteger n) 0
  Forward x1 d1 + Forward x2 d2 = Forward (x1 + x2) (d1 + d2)
  Forward x1 d1 - Forward x2 d2 = Forward (x1 - x2) (d1 - d2)
  -- Product rule.
  Forward x1 d1 * Forward x2 d2 = Forward (x1 * x2) (x1 * d2 + x2 * d1)
  -- Negate component-wise rather than multiplying by a lifted -1: going
  -- through the product rule computes @x * 0@, which is NaN when @x@ is
  -- infinite.
  negate (Forward x d) = Forward (negate x) (negate d)
  -- d|x|/dx = signum x. Unlike @abs x / x@ this does not divide by zero at
  -- the origin; there it yields @signum 0 * d@.
  abs (Forward x d) = Forward (abs x) (signum x * d)
  -- 'signum' is piecewise constant, so its derivative is taken to be zero.
  signum (Forward x _) = Forward (signum x) 0
-- | Division on dual numbers via the quotient rule; rational literals are
-- constants with zero derivative.
instance Fractional a => Fractional (Forward a) where
  fromRational q = Forward { _value = fromRational q, _grad = 0 }
  Forward u du / Forward v dv = Forward quotient slope
    where
      quotient = u / v
      -- Quotient rule: (v * u' - u * v') / v^2.
      slope = (v * du - u * dv) / (v * v)
-- | Elementary functions on dual numbers. Each method pairs the function
-- with its analytic derivative through 'chain'.
--
-- Every method in the minimal complete definition of 'Floating' is provided,
-- so none of them falls through to a missing-method runtime error. '**' and
-- 'logBase' use the class defaults, which are built from 'exp' and 'log'.
instance Floating a => Floating (Forward a) where
  pi = lift pi
  exp = chain exp exp
  log = chain log recip
  -- Defined directly: the class default is @x ** 0.5 = exp (log x * 0.5)@,
  -- which is less accurate and produces a NaN derivative at zero.
  sqrt = chain sqrt (\x -> recip (2 * sqrt x))
  sin = chain sin cos
  cos = chain cos (\x -> - (sin x))
  tan = chain tan (\x -> let secx = recip (cos x) in secx * secx)
  asin = chain asin (\x -> 1 / sqrt (1 - x * x))
  acos = chain acos (\x -> -1 / sqrt (1 - x * x))
  atan = chain atan (\x -> recip (1 + x * x))
  sinh = chain sinh cosh
  cosh = chain cosh sinh
  tanh = chain tanh (\x -> let t = tanh x in 1 - t * t)
  asinh = chain asinh (\x -> recip (sqrt (x * x + 1)))
  acosh = chain acosh (\x -> recip (sqrt (x * x - 1)))
  atanh = chain atanh (\x -> recip (1 - x * x))
-- | Differentiate a function @f : R -> R^m@ at a point.
--
-- The input is seeded with 'var', so every output component carries its
-- derivative with respect to @x@. Returns the output values (first
-- component) and the derivatives @df/dx@ (second component), in the order
-- produced by @f@.
grad :: Num a => (Forward a -> [Forward a]) -> a -> ([a], [a])
grad f x = unzip [(y, dy) | Forward y dy <- f (var x)]
-- | Jacobian of a function @f : R^n -> R^m@ at the point @xs@.
--
-- @f@ is evaluated once per input. In each evaluation exactly one input is
-- seeded with 'var' and all others are passed through 'lift', so the
-- derivatives read off the outputs form one column of the Jacobian. The
-- columns are then transposed, giving one row per output and one entry per
-- input within each row.
jacobian :: Num a => ([Forward a] -> [Forward a]) -> [a] -> [[a]]
jacobian f xs = transpose (map column positions)
  where
    -- One index per input, without forcing the length of the list up front.
    positions = zipWith const [0 :: Int ..] xs
    -- Derivatives of every output with respect to the input at index i.
    column i = map _grad (f (zipWith (seed i) [0 ..] xs))
    -- Only the input at the active index is treated as the variable.
    seed i j x
      | i == j = var x
      | otherwise = lift x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment