@martindale
Forked from ttesmer/AD.hs
Created September 18, 2022 02:52
Automatic Differentiation in 38 lines of Haskell using Operator Overloading and Dual Numbers. Inspired by conal.net/papers/beautiful-differentiation
```haskell
{-# LANGUAGE TypeSynonymInstances #-}

-- A dual number: the primal value paired with its tangent (derivative) part.
data Dual d = D Float d

-- Synonym used to mark tangent (derivative) values.
type Float' = Float

-- Seed the tangent with 1 (dx/dx = 1) and read off the derivative.
diff :: (Dual Float' -> Dual Float') -> Float -> Float'
diff f x = y'
  where D _ y' = f (D x 1)

-- Tangents only need to be added together and scaled by a Float.
class VectorSpace v where
  zero  :: v
  add   :: v -> v -> v
  scale :: Float -> v -> v

instance VectorSpace Float' where
  zero  = 0
  add   = (+)
  scale = (*)

-- Each operator propagates the tangent with the sum, product, or chain rule.
instance VectorSpace d => Num (Dual d) where
  (+) (D u u') (D v v') = D (u + v) (add u' v')
  (*) (D u u') (D v v') = D (u * v) (add (scale u v') (scale v u'))
  negate (D u u')       = D (scale (-1) u) (scale (-1) u')
  signum (D u _)        = D (signum u) zero
  abs (D u u')          = D (abs u) (scale (signum u) u')
  fromInteger n         = D (fromInteger n) zero  -- constants have zero derivative

instance VectorSpace d => Fractional (Dual d) where
  -- quotient rule: (u/v)' = (v u' - u v') / v^2
  (/) (D u u') (D v v') = D (u / v) (scale (1 / v^2) (add (scale v u') (scale (-u) v')))
  fromRational n        = D (fromRational n) zero

instance VectorSpace d => Floating (Dual d) where
  pi            = D pi zero
  exp  (D u u') = D (exp u)  (scale (exp u) u')
  log  (D u u') = D (log u)  (scale (recip u) u')  -- (log u)' = u' / u
  sin  (D u u') = D (sin u)  (scale (cos u) u')
  cos  (D u u') = D (cos u)  (scale (-sin u) u')
  sinh (D u u') = D (sinh u) (scale (cosh u) u')
  cosh (D u u') = D (cosh u) (scale (sinh u) u')
```
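Why these rules produce derivatives, as a short derivation: read `D u u'` as the dual number $u + u'\varepsilon$, where $\varepsilon$ is a formal symbol with $\varepsilon^2 = 0$. Multiplying two dual numbers reproduces the product rule used in `(*)`, Taylor-expanding any smooth $g$ gives the chain rule that each `Floating` method follows (for `log`, $g'(u) = 1/u$), and seeding the tangent with 1 makes `diff` return $f'(x)$:

$$
\begin{aligned}
(u + u'\varepsilon)(v + v'\varepsilon) &= uv + (u v' + v u')\,\varepsilon + u'v'\,\varepsilon^2 = uv + (u v' + v u')\,\varepsilon \\
g(u + u'\varepsilon) &= g(u) + g'(u)\,u'\,\varepsilon \qquad \text{(all higher-order terms vanish because } \varepsilon^2 = 0\text{)} \\
f(x + 1\cdot\varepsilon) &= f(x) + f'(x)\,\varepsilon \quad\Longrightarrow\quad \mathrm{diff}\ f\ x = f'(x)
\end{aligned}
$$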

# Automatic Differentiation in 38 lines of Haskell

You can now differentiate (almost¹) any differentiable polynomial, exponential, trigonometric, or hyperbolic function of a single input (multiple inputs aren't supported yet). Let's use the polynomial $f(x) = 2x^3 + 3x^2 + 4x + 2$, whose derivative $f'(x) = 6x^2 + 6x + 4$ is easy to compute by hand, to verify the AD program.

```haskell
λ> f x = 2 * x^3 + 3 * x^2 + 4 * x + 2  -- our polynomial
λ> f 10
2342
λ> diff f 10                            -- evaluate df/dx at x = 10
664.0
λ> 2*3 * 10^2 + 3*2 * 10 + 4            -- verify the derivative at 10
664
```
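To see the mechanics, here is the same evaluation traced by hand, writing `D value tangent` as $\langle \text{value}, \text{tangent}\rangle$. `diff` seeds $x = \langle 10, 1\rangle$, each integer literal becomes $\langle n, 0\rangle$ through `fromInteger`, and `(*)` and `(+)` apply the product and sum rules from the `Num` instance:

$$
\begin{aligned}
x^2 &= \langle 10 \cdot 10,\ 10 \cdot 1 + 10 \cdot 1\rangle = \langle 100, 20\rangle \\
x^3 &= \langle 100 \cdot 10,\ 100 \cdot 1 + 10 \cdot 20\rangle = \langle 1000, 300\rangle \\
2x^3 + 3x^2 + 4x + 2 &= \langle 2000, 600\rangle + \langle 300, 60\rangle + \langle 40, 4\rangle + \langle 2, 0\rangle = \langle 2342, 664\rangle
\end{aligned}
$$

The value component is `f 10` and the tangent component is the `664.0` that `diff f 10` returns.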

We can also compose functions:

```haskell
λ> f x = 2 * x^2 + 3 * x + 5
λ> f2 = tanh . exp . sin . f
λ> f2 0.25
0.5865376368439258
λ> diff f2 0.25
1.6192873
```
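As an independent cross-check of the composed result, the chain rule gives $f_2'(x) = \left(1 - \tanh^2 e\right) \cdot e \cdot \cos f(x) \cdot f'(x)$ with $e = \exp(\sin f(x))$. The sketch below codes that closed form in plain `Double` arithmetic, without any dual numbers; the names `f'` and `f2'` are hypothetical helpers introduced only for this check.

```haskell
-- Hand-derived derivative of f2 = tanh . exp . sin . f (chain rule),
-- written without the Dual machinery so it can cross-check `diff f2`.

f :: Double -> Double
f x = 2 * x^2 + 3 * x + 5

-- Derivative of the inner polynomial: d/dx (2x^2 + 3x + 5) = 4x + 3.
f' :: Double -> Double
f' x = 4 * x + 3

-- tanh' e = 1 - tanh e ^ 2, exp' = exp, sin' = cos, multiplied outside-in.
f2' :: Double -> Double
f2' x = (1 - tanh e ^ 2) * e * cos (f x) * f' x
  where
    e = exp (sin (f x))
```

Evaluating `f2' 0.25` gives roughly 1.6193, which agrees with `diff f2 0.25` up to `Float` rounding.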

If you want to learn more about how this works, read the paper by Conal M. Elliott², or watch the talk "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation" by Simon Peyton Jones himself³, or read the paper of the same name⁴, which carries these ideas over to reverse mode (this gist implements forward mode). There's also the `ad` package on Hackage, which implements automatic differentiation in a complete, practical form; this gist only aims to explain the most basic version of the technique. Additionally, there's Andrej Karpathy's micrograd, a small reverse-mode autodiff engine written in Python.

Footnotes

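  1. The inverse trigonometric and inverse hyperbolic functions (`asin`, `acos`, `atan`, `asinh`, `acosh`, `atanh`) aren't implemented in the `Floating` instance yet. Functions such as `tan`, `tanh`, and `sqrt` work through the class's default definitions.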

  2. http://conal.net/papers/beautiful-differentiation/beautiful-differentiation-long.pdf

  3. https://www.youtube.com/watch?v=EPGqzkEZWyw

  4. https://richarde.dev/papers/2022/ad/higher-order-ad.pdf
