@ttesmer
Last active October 29, 2024 15:35
Automatic Differentiation in 38 lines of Haskell using Operator Overloading and Dual Numbers. Inspired by conal.net/papers/beautiful-differentiation
```haskell
{-# LANGUAGE TypeSynonymInstances #-}

-- A dual number pairs a value with its derivative (tangent) of type d.
data Dual d = D Float d deriving Show

-- A synonym for Float, used here for derivative values.
type Float' = Float

-- Derivative of f at x: seed the input's tangent with 1 and read off the output's tangent.
diff :: (Dual Float' -> Dual Float') -> Float -> Float'
diff f x = y'
  where D y y' = f (D x 1)

-- Derivatives only need a zero, addition, and scaling by a Float.
class VectorSpace v where
  zero  :: v
  add   :: v -> v -> v
  scale :: Float -> v -> v

instance VectorSpace Float' where
  zero  = 0
  add   = (+)
  scale = (*)

-- Sum and product rules applied to the tangent component.
instance VectorSpace d => Num (Dual d) where
  (+) (D u u') (D v v') = D (u+v) (add u' v')
  (*) (D u u') (D v v') = D (u*v) (add (scale u v') (scale v u'))
  negate (D u u')       = D (negate u) (scale (-1) u')
  signum (D u u')       = D (signum u) zero
  abs (D u u')          = D (abs u) (scale (signum u) u')
  fromInteger n         = D (fromInteger n) zero

-- Quotient rule: (u/v)' = (v u' - u v') / v^2.
instance VectorSpace d => Fractional (Dual d) where
  (/) (D u u') (D v v') = D (u/v) (scale (1/v^2) (add (scale v u') (scale (-u) v')))
  fromRational n        = D (fromRational n) zero

-- Chain rule: g(u)' = g'(u) * u'. tanh, tan, sqrt and (**) fall back to the
-- class defaults; asin, acos, atan, asinh, acosh and atanh are not defined.
instance VectorSpace d => Floating (Dual d) where
  pi            = D pi zero
  exp  (D u u') = D (exp u) (scale (exp u) u')
  log  (D u u') = D (log u) (scale (1/u) u')
  sin  (D u u') = D (sin u) (scale (cos u) u')
  cos  (D u u') = D (cos u) (scale (-sin u) u')
  sinh (D u u') = D (sinh u) (scale (cosh u) u')
  cosh (D u u') = D (cosh u) (scale (sinh u) u')
```
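One way to see why these instances are correct, using the standard algebraic presentation of dual numbers (this $\varepsilon$ notation is not used in the gist itself): read `D u u'` as $u + u'\varepsilon$ with $\varepsilon^2 = 0$. Expanding and dropping every $\varepsilon^2$ term reproduces the rules in `Num`, `Fractional`, and `Floating`:

$$
\begin{aligned}
(u + u'\varepsilon) + (v + v'\varepsilon) &= (u + v) + (u' + v')\,\varepsilon \\
(u + u'\varepsilon)(v + v'\varepsilon) &= uv + (u v' + v u')\,\varepsilon \\
\frac{u + u'\varepsilon}{v + v'\varepsilon} &= \frac{u}{v} + \frac{v u' - u v'}{v^2}\,\varepsilon \\
g(u + u'\varepsilon) &= g(u) + g'(u)\,u'\,\varepsilon
\end{aligned}
$$

The last line is a first-order Taylor expansion, which is exact here because $\varepsilon^2 = 0$. Seeding the input as `D x 1` in `diff` therefore yields $f(x) + f'(x)\,\varepsilon$, and `diff` returns the $\varepsilon$ coefficient.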

Automatic Differentiation in 38 lines of Haskell

(See discussion on Hacker News)

You can now differentiate (almost[1]) any differentiable hyperbolic, polynomial, exponential, and/or trigonometric function. Let's use the polynomial $f(x) = 2x^3 + 3x^2 + 4x + 2$, whose straightforward derivative, $f'(x) = 6x^2 + 6x + 4$, can be used to verify the AD program.

```haskell
λ> f x = 2 * x^3 + 3 * x^2 + 4 * x + 2 -- our polynomial
λ> f 10
2342
λ> diff f 10 -- evaluate df/dx with x=10
664.0
λ> 2*3 * 10^2 + 3*2 * 10 + 4 -- verify derivative at 10
664
```
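To check the result at more than one point, here is a minimal standalone sketch. It assumes a simplified `Dual` with plain `Float` tangents and no `VectorSpace` class, so that it compiles on its own; `maxError` is a hypothetical helper, not part of the gist:

```haskell
-- Simplified standalone Dual (assumption): Float tangents, no VectorSpace class.
data Dual = D Float Float deriving Show

instance Num Dual where
  D u u' + D v v' = D (u + v) (u' + v')
  D u u' * D v v' = D (u * v) (u * v' + v * u')   -- product rule
  negate (D u u') = D (negate u) (negate u')
  abs (D u u')    = D (abs u) (signum u * u')
  signum (D u _)  = D (signum u) 0
  fromInteger n   = D (fromInteger n) 0           -- constants have derivative 0

-- Seed the input's tangent with 1 and return the output's tangent.
diff :: (Dual -> Dual) -> Float -> Float
diff h x = let D _ y' = h (D x 1) in y'

-- The polynomial from the text; polymorphic, so it runs on Float and on Dual.
f :: Num a => a -> a
f x = 2 * x^3 + 3 * x^2 + 4 * x + 2

-- Hand-derived derivative, used as the reference.
f' :: Float -> Float
f' x = 6 * x^2 + 6 * x + 4

-- Hypothetical helper: largest gap between AD and f' over the given points.
maxError :: [Float] -> Float
maxError xs = maximum [abs (diff f x - f' x) | x <- xs]
```

In GHCi, `maxError [-3, -1.5, 0, 0.5, 2, 10]` should return a value at or very near 0.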

We can also compose functions:

```haskell
λ> f x = 2 * x^2 + 3 * x + 5
λ> f2 = tanh . exp . sin . f
λ> f2 0.25
0.5865376368439258
λ> diff f2 0.25
1.6192873
```
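Each `Floating` method in the gist has the same shape, `D (g u) (scale (g' u) u')`, which is the chain rule; composing functions therefore multiplies their local derivatives. The standalone sketch below makes that explicit with a hypothetical `lift` helper, again assuming `Float` tangents instead of `VectorSpace`. It differentiates `tanh` directly via $1 - \tanh^2 u$, whereas the gist's `tanh` goes through the class default `sinh u / cosh u`, so the last digits may differ slightly:

```haskell
-- Simplified standalone Dual (assumption): Float tangents, no VectorSpace class.
data Dual = D Float Float deriving Show

instance Num Dual where
  D u u' + D v v' = D (u + v) (u' + v')
  D u u' * D v v' = D (u * v) (u * v' + v * u')
  negate (D u u') = D (negate u) (negate u')
  abs (D u u')    = D (abs u) (signum u * u')
  signum (D u _)  = D (signum u) 0
  fromInteger n   = D (fromInteger n) 0

-- Hypothetical helper: lift a scalar function g, given its derivative g',
-- to dual numbers. This is the chain rule: (g . u)' = g'(u) * u'.
lift :: (Float -> Float) -> (Float -> Float) -> Dual -> Dual
lift g g' (D u u') = D (g u) (g' u * u')

dSin, dExp, dTanh :: Dual -> Dual
dSin  = lift sin cos
dExp  = lift exp exp
dTanh = lift tanh (\u -> 1 - tanh u ^ 2)

diff :: (Dual -> Dual) -> Float -> Float
diff h x = let D _ y' = h (D x 1) in y'

f :: Num a => a -> a
f x = 2 * x^2 + 3 * x + 5

-- Each stage multiplies the incoming tangent by its own local derivative.
f2 :: Dual -> Dual
f2 = dTanh . dExp . dSin . f
```

`diff f2 0.25` should agree with the value above to within `Float` rounding.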

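And differentiate functions of several variables, such as $f: \mathbb{R}^m \rightarrow \mathbb{R}^n$, by doing manually what `diff` does. Seeding every input with derivative 1 gives the derivative along the direction $(1, \dots, 1)$, i.e. the sum of the partial derivatives of each output: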

```haskell
λ> f x y z = 2 * x^2 + 3 * y + sin z        -- f: R^3 -> R
λ> f (D 3 1) (D 4 1) (D 5 1) :: Dual Float' -- call `f` with dual numbers, set derivative to 1
D 29.041077 15.283662
λ> f x y z = (2 * x^2, 3 * y + sin z)       -- f: R^3 -> R^2
λ> f (D 3 1) (D 4 1) (D 5 1) :: (Dual Float', Dual Float')
(D 18.0 12.0,D 11.041076 3.2836623)
```
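Seeding every input with 1 is a single Jacobian-vector product: the derivative along the direction $(1, 1, 1)$, which is why the results above are $12 + 3 + \cos 5 \approx 15.28$ and $(12,\ 3 + \cos 5)$. Any direction can be seeded the same way. A minimal standalone sketch, assuming `Float` tangents and list-based inputs; `jvp` and `sinD` are hypothetical helpers:

```haskell
-- Simplified standalone Dual (assumption): Float tangents, no VectorSpace class.
data Dual = D Float Float deriving Show

instance Num Dual where
  D u u' + D v v' = D (u + v) (u' + v')
  D u u' * D v v' = D (u * v) (u * v' + v * u')
  negate (D u u') = D (negate u) (negate u')
  abs (D u u')    = D (abs u) (signum u * u')
  signum (D u _)  = D (signum u) 0
  fromInteger n   = D (fromInteger n) 0

-- Hypothetical helper: sine on dual numbers via the chain rule.
sinD :: Dual -> Dual
sinD (D u u') = D (sin u) (cos u * u')

-- The R^3 -> R^2 example from the text, written over a list of inputs.
g :: [Dual] -> [Dual]
g [x, y, z] = [2 * x^2, 3 * y + sinD z]
g _         = error "g expects exactly three inputs"

-- Hypothetical helper: directional derivative of h at xs along vs,
-- from one forward pass (input i gets tangent vs !! i).
jvp :: ([Dual] -> [Dual]) -> [Float] -> [Float] -> [Float]
jvp h xs vs = [t | D _ t <- h (zipWith D xs vs)]
```

For example, `jvp g [3, 4, 5] [1, 1, 1]` reproduces the tangents above, while `jvp g [3, 4, 5] [1, 0, 0]` isolates the column of partial derivatives with respect to `x`.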

Or get a single partial derivative by seeding only the variable of interest with derivative 1; the literals `4` and `5` become constants with derivative 0 through `fromInteger`:

```haskell
λ> f x y z = 2 * x^2 + 3 * y + sin z -- f: R^3 -> R
λ> f (D 3 1) 4 5 :: Dual Float'
D 29.041077 12.0
```
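Repeating this once per variable gives the full gradient of $f: \mathbb{R}^n \rightarrow \mathbb{R}$ at a cost of $n$ forward passes, which is the cost that reverse-mode AD (the subject of the talk and paper cited below) avoids. A minimal standalone sketch, assuming `Float` tangents; `grad3` and `sinD` are hypothetical helpers:

```haskell
-- Simplified standalone Dual (assumption): Float tangents, no VectorSpace class.
data Dual = D Float Float deriving Show

instance Num Dual where
  D u u' + D v v' = D (u + v) (u' + v')
  D u u' * D v v' = D (u * v) (u * v' + v * u')
  negate (D u u') = D (negate u) (negate u')
  abs (D u u')    = D (abs u) (signum u * u')
  signum (D u _)  = D (signum u) 0
  fromInteger n   = D (fromInteger n) 0

-- Hypothetical helper: sine on dual numbers via the chain rule.
sinD :: Dual -> Dual
sinD (D u u') = D (sin u) (cos u * u')

-- The f: R^3 -> R example from the text.
f3 :: Dual -> Dual -> Dual -> Dual
f3 x y z = 2 * x^2 + 3 * y + sinD z

-- Hypothetical helper: gradient of a three-argument function, one forward
-- pass per variable (that variable seeded with 1, the others with 0).
grad3 :: (Dual -> Dual -> Dual -> Dual) -> (Float, Float, Float) -> (Float, Float, Float)
grad3 f (x, y, z) =
  ( tangent (f (D x 1) (D y 0) (D z 0))
  , tangent (f (D x 0) (D y 1) (D z 0))
  , tangent (f (D x 0) (D y 0) (D z 1)) )
  where
    tangent (D _ t) = t
```

`grad3 f3 (3, 4, 5)` should give approximately $(12,\ 3,\ \cos 5)$; its first component is the `12.0` obtained above.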

If you want to learn more about how this works, read the paper by Conal M. Elliott[2], watch the talk "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation" by Simon Peyton Jones himself[3], or read the accompanying paper of the same name[4]. There's also the `ad` package on Hackage, which implements this in a usable way; this gist only illustrates the most basic form, forward-mode AD with dual numbers. Additionally, there's Andrej Karpathy's micrograd, written in Python.

Footnotes

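  1. Only the inverse trigonometric and inverse hyperbolic functions (`asin`, `acos`, `atan`, `asinh`, `acosh`, `atanh`) aren't yet implemented in the `Floating` instance; `tanh`, `tan`, and `sqrt` work through the class's default definitions.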

  2. http://conal.net/papers/beautiful-differentiation/beautiful-differentiation-long.pdf

  3. https://www.youtube.com/watch?v=EPGqzkEZWyw

  4. https://richarde.dev/papers/2022/ad/higher-order-ad.pdf
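The inverse functions listed in footnote 1 follow the same chain-rule pattern as the existing `Floating` methods. The sketch below repeats the gist's definitions so it compiles on its own and adds the six missing methods using their standard derivatives; these additions are not part of the original gist:

```haskell
{-# LANGUAGE TypeSynonymInstances #-}

-- Same definitions as the gist, repeated so this block compiles on its own.
data Dual d = D Float d deriving Show

type Float' = Float

diff :: (Dual Float' -> Dual Float') -> Float -> Float'
diff f x = y'
  where D _ y' = f (D x 1)

class VectorSpace v where
  zero  :: v
  add   :: v -> v -> v
  scale :: Float -> v -> v

instance VectorSpace Float' where
  zero  = 0
  add   = (+)
  scale = (*)

instance VectorSpace d => Num (Dual d) where
  (+) (D u u') (D v v') = D (u+v) (add u' v')
  (*) (D u u') (D v v') = D (u*v) (add (scale u v') (scale v u'))
  negate (D u u')       = D (negate u) (scale (-1) u')
  signum (D u _)        = D (signum u) zero
  abs (D u u')          = D (abs u) (scale (signum u) u')
  fromInteger n         = D (fromInteger n) zero

instance VectorSpace d => Fractional (Dual d) where
  (/) (D u u') (D v v') = D (u/v) (scale (1/v^2) (add (scale v u') (scale (-u) v')))
  fromRational n        = D (fromRational n) zero

instance VectorSpace d => Floating (Dual d) where
  pi            = D pi zero
  exp  (D u u') = D (exp u) (scale (exp u) u')
  log  (D u u') = D (log u) (scale (1/u) u')
  sin  (D u u') = D (sin u) (scale (cos u) u')
  cos  (D u u') = D (cos u) (scale (-sin u) u')
  sinh (D u u') = D (sinh u) (scale (cosh u) u')
  cosh (D u u') = D (cosh u) (scale (sinh u) u')
  -- Added in this sketch (not in the gist): inverse trigonometric functions.
  asin (D u u') = D (asin u) (scale (1 / sqrt (1 - u*u)) u')
  acos (D u u') = D (acos u) (scale (-1 / sqrt (1 - u*u)) u')
  atan (D u u') = D (atan u) (scale (1 / (1 + u*u)) u')
  -- Added in this sketch (not in the gist): inverse hyperbolic functions.
  asinh (D u u') = D (asinh u) (scale (1 / sqrt (u*u + 1)) u')
  acosh (D u u') = D (acosh u) (scale (1 / sqrt (u*u - 1)) u')  -- defined for u > 1
  atanh (D u u') = D (atanh u) (scale (1 / (1 - u*u)) u')       -- defined for |u| < 1
```

With these in place, for example, `diff atan 1` should be close to `0.5` and `diff acosh 2` close to $1/\sqrt{3}$.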

@sradc

sradc commented Sep 18, 2022

Very cool, and a useful reference! Inspired me to do an Autodiff in N lines of Python gist. 😄

@runeksvendsen

> `type Float' = Float`

Why this type synonym?
