
Last active Nov 29, 2022
```haskell
{-# LANGUAGE TypeSynonymInstances #-}

data Dual d = D Float d deriving Show
type Float' = Float

diff :: (Dual Float' -> Dual Float') -> Float -> Float'
diff f x = y' where D y y' = f (D x 1)

class VectorSpace v where
    zero  :: v
    add   :: v -> v -> v
    scale :: Float -> v -> v

instance VectorSpace Float' where
    zero  = 0
    add   = (+)
    scale = (*)

instance VectorSpace d => Num (Dual d) where
    (+) (D u u') (D v v') = D (u+v) (add u' v')
    (*) (D u u') (D v v') = D (u*v) (add (scale u v') (scale v u'))
    negate (D u u') = D (negate u) (scale (-1) u')
    signum (D u u') = D (signum u) zero
    abs (D u u') = D (abs u) (scale (signum u) u')
    fromInteger n = D (fromInteger n) zero

instance VectorSpace d => Fractional (Dual d) where
    (/) (D u u') (D v v') = D (u/v) (scale (1/v^2) (add (scale v u') (scale (-u) v')))
    fromRational n = D (fromRational n) zero

instance VectorSpace d => Floating (Dual d) where
    pi = D pi zero
    exp (D u u') = D (exp u) (scale (exp u) u')
    log (D u u') = D (log u) (scale (1/u) u')
    sin (D u u') = D (sin u) (scale (cos u) u')
    cos (D u u') = D (cos u) (scale (-sin u) u')
    sinh (D u u') = D (sinh u) (scale (cosh u) u')
    cosh (D u u') = D (cosh u) (scale (sinh u) u')
```

# Automatic Differentiation in 38 lines of Haskell

(See discussion on Hacker News)

You can now differentiate (almost¹) any differentiable hyperbolic, polynomial, exponential, and/or trigonometric function. Let's use the polynomial $f(x) = 2x^3 + 3x^2 + 4x + 2$, whose straightforward derivative, $f'(x) = 6x^2 + 6x + 4$, can be used to verify the AD program.

```haskell
λ> f x = 2 * x^3 + 3 * x^2 + 4 * x + 2 -- our polynomial
λ> f 10
2342
λ> diff f 10 -- evaluate df/dx with x=10
664.0
λ> 2*3 * 10^2 + 3*2 * 10 + 4 -- verify derivative at 10
664
```
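As a sketch of why this works (the standard dual-number argument, not spelled out in the gist): read `D u u'` as $u + u'\varepsilon$ with the rule $\varepsilon^2 = 0$. Multiplying two such numbers yields exactly the product rule in the `(*)` case of the `Num` instance, and the same expansion shows that any polynomial satisfies $f(x + \varepsilon) = f(x) + f'(x)\,\varepsilon$. Since `diff` seeds its input as `D x 1`, i.e. $x + \varepsilon$, the tangent it returns is $f'(x)$:

$$
\begin{aligned}
(u + u'\varepsilon)(v + v'\varepsilon) &= uv + (u v' + v u')\,\varepsilon + u'v'\,\varepsilon^2 = uv + (u v' + v u')\,\varepsilon \\
f(10 + \varepsilon) &= f(10) + f'(10)\,\varepsilon = 2342 + (6 \cdot 10^2 + 6 \cdot 10 + 4)\,\varepsilon = 2342 + 664\,\varepsilon
\end{aligned}
$$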

We can also compose functions:

```haskell
λ> f x = 2 * x^2 + 3 * x + 5
λ> f2 = tanh . exp . sin . f
λ> f2 0.25
0.5865376368439258
λ> diff f2 0.25
1.6192873
```
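Composition is handled by the chain rule: each `Floating` case multiplies the incoming tangent `u'` by the local derivative of its function. `tanh` has no case of its own; the class default `tanh x = sinh x / cosh x` goes through the `sinh`, `cosh` and `(/)` cases, which together produce $\operatorname{sech}^2$. With $f(x) = 2x^2 + 3x + 5$ we have $f(0.25) = 5.875$ and $f'(0.25) = 4$, so the quantity `diff f2 0.25` evaluates is:

$$
f_2'(x) = \operatorname{sech}^2\!\big(e^{\sin f(x)}\big)\, e^{\sin f(x)} \cos f(x)\, f'(x),
\qquad
f_2'(0.25) = \operatorname{sech}^2\!\big(e^{\sin 5.875}\big)\, e^{\sin 5.875} \cos(5.875) \cdot 4
$$

(`diff` works in `Float`, which is why it prints fewer digits than `f2 0.25`, where GHCi defaults the type to `Double`.)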

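We can also differentiate multivariate functions, such as $f: \mathbb{R}^m \rightarrow \mathbb{R}^n$, by doing manually what `diff` does: pass dual numbers in and read the derivative off the result. Seeding every input with derivative 1 yields the directional derivative along $(1, \ldots, 1)$, so each output's derivative is the sum of its partial derivatives: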

```haskell
λ> f x y z = 2 * x^2 + 3 * y + sin z        -- f: R^3 -> R
λ> f (D 3 1) (D 4 1) (D 5 1) :: Dual Float' -- call f with dual numbers, set derivative to 1
D 29.041077 15.283662
λ> f x y z = (2 * x^2, 3 * y + sin z)       -- f: R^3 -> R^2
λ> f (D 3 1) (D 4 1) (D 5 1) :: (Dual Float', Dual Float')
(D 18.0 12.0,D 11.041076 3.2836623)
```
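As a hand check of the printed derivatives, each one equals the Jacobian of `f` applied to the seed $(1,1,1)$, evaluated at $(x, y, z) = (3, 4, 5)$ with $\cos 5 \approx 0.28366$:

$$
\begin{aligned}
f(x,y,z) = 2x^2 + 3y + \sin z &: \quad \nabla f \cdot (1,1,1) = 4x + 3 + \cos z = 15 + \cos 5 \approx 15.2837 \\
f(x,y,z) = (2x^2,\; 3y + \sin z) &: \quad J_f\,(1,1,1)^\top = (4x,\; 3 + \cos z) = (12,\; 3 + \cos 5) \approx (12,\; 3.2837)
\end{aligned}
$$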

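Or get a partial derivative by passing only the input of interest as a dual number with derivative 1; the plain literals `4` and `5` become constants with derivative 0 via `fromInteger`: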

```haskell
λ> f x y z = 2 * x^2 + 3 * y + sin z -- f: R^3 -> R
λ> f (D 3 1) 4 5 :: Dual Float'
D 29.041077 12.0
```
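Because `Dual d` is polymorphic in its tangent type `d`, the per-input seeding can also be done in a single pass: use a 3-vector as the tangent and seed each input with a unit vector. The sketch below is an assumption built on the gist's definitions (trimmed to the methods this example needs); `V3`, `grad3` and `example` are hypothetical names, not part of the gist or of any library.

```haskell
-- Hypothetical tangent type: one component per input (not in the gist).
data V3 = V3 Float Float Float deriving (Show, Eq)

-- Copied from the gist: a primal Float paired with a tangent of type d.
data Dual d = D Float d deriving Show

class VectorSpace v where
  zero  :: v
  add   :: v -> v -> v
  scale :: Float -> v -> v

-- Component-wise vector-space operations on V3.
instance VectorSpace V3 where
  zero = V3 0 0 0
  add (V3 a b c) (V3 a' b' c') = V3 (a + a') (b + b') (c + c')
  scale k (V3 a b c) = V3 (k * a) (k * b) (k * c)

-- Same rules as the gist; they never inspect the tangent type.
instance VectorSpace d => Num (Dual d) where
  D u u' + D v v' = D (u + v) (add u' v')
  D u u' * D v v' = D (u * v) (add (scale u v') (scale v u'))
  negate (D u u') = D (negate u) (scale (-1) u')
  signum (D u _)  = D (signum u) zero
  abs (D u u')    = D (abs u) (scale (signum u) u')
  fromInteger n   = D (fromInteger n) zero

instance VectorSpace d => Fractional (Dual d) where
  D u u' / D v v' = D (u / v) (scale (1 / v^2) (add (scale v u') (scale (-u) v')))
  fromRational r  = D (fromRational r) zero

-- Trimmed Floating instance: only the methods used below (missing ones only warn).
instance VectorSpace d => Floating (Dual d) where
  pi = D pi zero
  exp (D u u') = D (exp u) (scale (exp u) u')
  log (D u u') = D (log u) (scale (1 / u) u')
  sin (D u u') = D (sin u) (scale (cos u) u')
  cos (D u u') = D (cos u) (scale (-sin u) u')

-- Hypothetical helper: seed x, y, z with the unit vectors e1, e2, e3 and read
-- the whole gradient off the tangent of a single evaluation.
grad3 :: (Dual V3 -> Dual V3 -> Dual V3 -> Dual V3)
      -> Float -> Float -> Float -> (Float, V3)
grad3 f x y z = (v, g)
  where D v g = f (D x (V3 1 0 0)) (D y (V3 0 1 0)) (D z (V3 0 0 1))

-- The example function from the text: f x y z = 2x^2 + 3y + sin z.
example :: Floating a => a -> a -> a -> a
example x y z = 2 * x^2 + 3 * y + sin z
```

`grad3 example 3 4 5` should return the value 29.041077 from above together with the gradient $(12, 3, \cos 5)$, whose components sum to the 15.283662 obtained with the all-ones seed. The swap works because the instances only ever use the three `VectorSpace` methods on the tangent.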

If you want to learn more about how this works, read the paper by Conal M. Elliott², or watch the talk titled "Provably correct, asymptotically efficient, higher-order reverse-mode automatic differentiation" by Simon Peyton Jones himself³, or read their paper⁴ of the same name. There's also the `ad` package, which implements this in a usable way; this gist only aims to explain the most basic form, forward-mode AD with dual numbers. Additionally, there's Andrej Karpathy's micrograd, a small reverse-mode (backpropagation) engine written in Python.

## Footnotes

1. The inverse trigonometric (`asin`, `acos`, `atan`) and inverse hyperbolic (`asinh`, `acosh`, `atanh`) functions aren't yet implemented in the `Floating` instance.
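A possible completion of the missing methods, sketched here as an assumption rather than taken from the gist: each follows the pattern of the existing cases, applying the function to the primal part and scaling the tangent by its derivative, e.g. $\frac{d}{du}\operatorname{asin} u = 1/\sqrt{1-u^2}$ and $\frac{d}{du}\operatorname{atanh} u = 1/(1-u^2)$. The rest of the block repeats the gist's definitions so that it compiles on its own.

```haskell
{-# LANGUAGE TypeSynonymInstances #-}

-- Definitions repeated from the gist.
data Dual d = D Float d deriving Show
type Float' = Float

diff :: (Dual Float' -> Dual Float') -> Float -> Float'
diff f x = y' where D y y' = f (D x 1)

class VectorSpace v where
  zero  :: v
  add   :: v -> v -> v
  scale :: Float -> v -> v

instance VectorSpace Float' where
  zero  = 0
  add   = (+)
  scale = (*)

instance VectorSpace d => Num (Dual d) where
  (+) (D u u') (D v v') = D (u+v) (add u' v')
  (*) (D u u') (D v v') = D (u*v) (add (scale u v') (scale v u'))
  negate (D u u') = D (negate u) (scale (-1) u')
  signum (D u u') = D (signum u) zero
  abs (D u u') = D (abs u) (scale (signum u) u')
  fromInteger n = D (fromInteger n) zero

instance VectorSpace d => Fractional (Dual d) where
  (/) (D u u') (D v v') = D (u/v) (scale (1/v^2) (add (scale v u') (scale (-u) v')))
  fromRational n = D (fromRational n) zero

instance VectorSpace d => Floating (Dual d) where
  pi = D pi zero
  exp (D u u') = D (exp u) (scale (exp u) u')
  log (D u u') = D (log u) (scale (1/u) u')
  sin (D u u') = D (sin u) (scale (cos u) u')
  cos (D u u') = D (cos u) (scale (-sin u) u')
  sinh (D u u') = D (sinh u) (scale (cosh u) u')
  cosh (D u u') = D (cosh u) (scale (sinh u) u')
  -- Hypothetical completion (not in the gist): scale the tangent by the
  -- derivative of each inverse function, evaluated at the primal u.
  asin  (D u u') = D (asin u)  (scale (1 / sqrt (1 - u*u)) u')
  acos  (D u u') = D (acos u)  (scale (-1 / sqrt (1 - u*u)) u')
  atan  (D u u') = D (atan u)  (scale (1 / (1 + u*u)) u')
  asinh (D u u') = D (asinh u) (scale (1 / sqrt (u*u + 1)) u')
  acosh (D u u') = D (acosh u) (scale (1 / sqrt (u*u - 1)) u')
  atanh (D u u') = D (atanh u) (scale (1 / (1 - u*u)) u')
```

With this in place, `diff atan 1` should give `0.5`. At the edges of each domain (e.g. $u = \pm 1$ for `asin`, `acos` and `atanh`) the derivative factor divides by zero, so the tangent becomes infinite or NaN.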

### benwaffle commented Sep 17, 2022

The derivative of $f(x) = 2x^3 + 3x^2 + 4x + 2$ should be $f'(x) = 6x^2 + 6x + 4$

### ttesmer commented Sep 17, 2022

> The derivative of $f(x) = 2x^3 + 3x^2 + 4x + 2$ should be $f'(x) = 6x^2 + 6x + 4$

You're right, thanks! Careless mistake.

### SridharRamesh commented Sep 18, 2022

On line 34, the derivative of log(u) should surely be 1/u rather than log(u).

### ttesmer commented Sep 18, 2022

> On line 34, the derivative of log(u) should surely be 1/u rather than log(u).

Thanks, I missed that one! Fixed it now.

### SridharRamesh commented Sep 18, 2022

On line 22, it seems like you should be writing (negate u) rather than scale (-1) u [in the same way that on line 20, you write (u + v) rather than (add u v)].

The code works correctly as written because Float' happens to be a type synonym for Float, and thus VectorSpace methods automatically apply to u as well. But the very fact that you bothered making a different name for Float' suggests you want to be preserving some distinction which this is ignoring.

This is no big deal, of course. I'm just mentioning it in passing. Overall, this is great!

### ImreMD commented Sep 18, 2022

I would also refer to https://youtu.be/q1DUKEOUoxA, a very hands-on presentation.

### sradc commented Sep 18, 2022

Very cool, and a useful reference! Inspired me to do an Autodiff in N lines of Python gist. 😄

### runeksvendsen commented Sep 18, 2022

> `type Float' = Float`

Why this type synonym?