Skip to content

Instantly share code, notes, and snippets.

@amalloy
Forked from VictorTaelin/implementing_fft.md
Created May 3, 2023 20:56
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save amalloy/a2d092b70c3b9fa8b5423a9cddd23f85 to your computer and use it in GitHub Desktop.
Save amalloy/a2d092b70c3b9fa8b5423a9cddd23f85 to your computer and use it in GitHub Desktop.
Implementing complex numbers and FFT with just datatypes (no floats)

Implementing complex numbers and FFT with just datatypes (no floats)

In this article, I'll explain why implementing numbers with just algebraic datatypes is desirable. I'll then talk about common implementations of FFT (Fast Fourier Transform) and why they hide inherent inefficiencies. I'll then show how to implement integers and complex numbers with just algebraic datatypes, in a way that is extremely simple and elegant. I'll conclude by deriving a pure functional implementation of complex FFT with just datatypes, no floats.

Why implement numbers with ADTs?

For most programmers, "real numbers" are a given: they just use floats and call it a day. But, in some cases, it doesn't work well, and I'm not talking about precision issues. When trying to prove statements on real numbers in proof assistants like Coq, we can't use doubles, we must formalize reals using datatypes, which can be very hard. For me, the real issue arises when I'm trying to implement optimal algorithms on HVM. To do that, I need functions to fuse, which is a fancy way of saying that (+ 2) . (+ 2) "morphs into" (+ 4) during execution - yet, machine floats block that mechanism. Because of that, I've been trying to come up with an elegant way to implement numbers with just datatypes. For natural numbers, it is easy: we can just implement them as bitstrings:

-- O stands for bit 0
-- I stands for bit 1
-- E stands for end-of-string
data Nat = O Nat | I Nat | E deriving (Show)

-- Applies 'f' n times to 'x'.
rep :: Nat -> (a -> a) -> a -> a
rep (O x) f = rep x (\x -> f (f x))
rep (I x) f = rep x (\x -> f (f x)) . f
rep E     f = id

-- Increments a Nat.
inc :: Nat -> Nat
inc E     = O E
inc (O x) = I x
inc (I x) = O (inc x)

-- Adds two Nats.
add :: Nat -> Nat -> Nat
add a = rep a inc

main :: IO ()
main = do
  let a = I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I (I E)))))))))))))))))))))))))
  let b = O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O (O E)))))))))))))))))))))))))
  print (add a b)

The program above implements addition by repeated increment, and is the simplest numerical example where fusion makes a big difference, being exponential on Haskell and linear on HVM. But this isn't useful in practice, as we already have fast addition algorithms like add-carry. The question, then, becomes: is it possible to explore optimality to implement numeric algorithms that actually beat existing ones in practice? The idea is that we could implement numbers with datatypes, which are, usually, much slower than floats, but, if the resulting asymptotics are improved, then this would easily pay off. One of the most successful numeric algorithms ever is the FFT. How would an "optimal FFT" look like? Let's find out!

What is wrong with the textbook FFT

Designing optimal functions is extremely tricky, because it takes very little to interrupt fusion and get an exponential slowdown. For example, if we change the inc Haskell function above to let the Nat grow in size:

inc :: Nat -> Nat
inc E     = I E -- allows the bitstring to grow
inc (O x) = I x
inc (I x) = O (inc x)

Then, the fact we're now using the "I" constructor twice will prevent us from fusing the corresponding function on the HVM. A single character will downgrade it from linear to exponential! There are bazillion ways to ruin fusion. Passing a state as an argument, placing a lambda in the wrong location, using an argument more than once... in fact, it feels like anything that detours from a "direct linear recursion" will refuse to fuse. With that in mind, let's examine the textbook FFT:

Note: you do not need to understand FFT to read this article - just view this as a random snippet we want to optimize.

-- FFT receives a polynomial P, represented as a list of complex coefficients,
-- and returns the evaluation of P on len(P) points of the unit circle. Example:
-- - input  = [A,B,C,D,E,F,G,H]
-- - output = [P(e(0/8)),P(e(1/8)),P(e(2/8)),P(e(3/8)),P(e(4/8)),P(e(5/8)),P(e(6/8)),P(e(7/8))]
--          where P(x) = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷
--                e(x) = e^(2πxi)

-- When len=1, we must eval P(e^(0*2πi)) where P(x) = A. That's just A.
fft [x] = [x]

-- When len>1...
fft xs = pts where

  -- If len=8, we'll eval P = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷
  -- on x ∈ [e(0/8),e(1/8),e(2/8),e(3/8),e(4/8),e(5/8),e(6/8),e(7/8)]
  len = length xs

  -- We first split it on two polynomials, based on even/odd exponents:
  eve = split 0 xs -- EVE = A   + Cx² + Ex⁴ + Gx⁶
  odd = split 1 xs -- ODD = Bx¹ + Dx³ + Fx⁵ + Hx⁷

  -- We then call FFT recursively on EVE and ODD
  pt0 = fft eve -- PT0 = A + Cx¹ + Ex² + Gx³ evaluated on x ∈ [0/4𝜏,1/4𝜏,2/4𝜏,3/4𝜏] (by induction)
                -- PT0 = A + Cx² + Ex⁴ + Gx⁶ evaluated on x ∈ [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏] (by equivalence)
  pt1 = fft odd -- PT1 = B + Dx¹ + Fx² + Hx³ evaluated on x ∈ [0/4𝜏,1/4𝜏,2/4𝜏,3/4𝜏] (by induction)
                -- PT1 = B + Dx² + Fx⁴ + Hx⁶ evaluated on x ∈ [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏] (by equivalence)

  -- We then compute e(x) for each angle, which are the "twiddle factors"
  twi = [cPol (2 * pi * fromIntegral k / fromIntegral len) | k <- [0..len `div` 2 - 1]]

  -- We then multiply PT1 by the tiddle factors
  pt2 = zipWith cMul twi pt1 -- PT2 = Bx¹ + Dx³ + Fx⁵ + Hx⁷ evaluated on x ∈ [0/4𝜏,1/4𝜏,2/4𝜏,3/4𝜏]

  -- Finally, we obtain all points as PT0 +- PT2
  -- This exploits the symmetry of polynomials
  ptl = zipWith cAdd pt0 pt2 -- PTL = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷ evaluated on [0/8𝜏,1/8𝜏,2/8𝜏,3/8𝜏]
  ptr = zipWith cSub pt0 pt2 -- PTR = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷ evaluated on [4/8𝜏,5/8𝜏,6/8𝜏,7/8𝜏]

  -- The result just combines PTL and PTR to get all unit circle points
  pts = ptl ++ ptr

Complete code.

So, what is wrong with this? Well, when it comes to optimality, everything. First, it traverses the list just to compute its length. Then, it copies the entire list 2 more times to split even/odd indices. Then, it traverses the odds list twice. Then it copies it. Then it re-generates the same twiddle factors over and over. Then it traverses everything a few more times with zips. All these things block fusion, and, if that wasn't bad enough, it performs arithmetic on machine floats on its its inner loop, which removes any remaining hope we could have. So, how can it be improved?

Improving the structure of FFT

The first insight is to replace lists by balanced binary trees, and store elements on nodes such that, by starting from any element of the tree, walking up to the root, and annotating the branches we passed through as 0 for left and 1 for right, we'll get the binary representation of the index. For example:

[0, 1, 2, 3, 4, 5, 6, 7]

Is represented as:

(B (B (B (L 0) (L 4)) (B (L 2) (L 6))) (B (B (L 1) (L 5)) (B (L 3) (L 7))))

If we start from 6, and move up until the root, we'll pass through right (1), right (1) and left (0) branches of a node, which agrees with the fact 110 is 6 in binary.

Storing elements that way has two benefits. First, we can now split a tree into even/odd indices in O(1) by just taking the left/right branches, removing the no need to call the "split" function, which is O(N), on every recursive call. For example, by taking the left branch of that tree, we get:

(B (B (L 0) (L 4)) (B (L 2) (L 6)))

Which corresponds to the list [0, 2, 4, 6], i.e., the even indices of the original list. And if we take the left branch again, we get:

(B (L 0) (L 4))

Which corresponds to the list [0, 4], which is, once again, the even indices of the list above. The second benefit is that we're now able to replace the 3 calls to zipWith by a single call to mix, which combines evens, odds and the twiddle factors in a single pass. Here is the improved version:

data Complex = C Double Double deriving Show
data Nat     = E | O Nat | I Nat deriving Show
data Tree a  = L a | B (Tree a) (Tree a) deriving Show

fft :: Nat -> Tree Complex -> Tree Complex
fft ang (B eve odd) =
  let pt0 = fft (O ang) eve
      pt1 = fft (O ang) odd
  in mix ang pt0 pt1
fft ang (L x) = L x

mix :: Nat -> Tree Complex -> Tree Complex -> Tree Complex
mix ang (B ax ay) (B bx by) =
  let ptl = mix (O ang) ax bx
      ptr = mix (I ang) ay by
  in B ptl ptr
mix ang (L pt0) (L pt1) =
  let pt2 = cMul pt1 (cPol (2 * pi * natVal ang 0 / 2 ** fromIntegral (natLen ang + 1)))
      ptl = L (cAdd pt0 pt2)
      ptr = L (cSub pt0 pt2)
  in B ptl ptr

Complete code.

The way it works is similar, with all unnecessary work removed. The fft function receives a tree with coefficients, and calls itself recursively on the left and right branches, which corresponds to even and odd coefficients on the original polynomial. Then, it combines the set of recursive points in a single pass by calling the mix function. Finally, it doesn't generate a list of twiddle factors. Instead, it stores an angle (as a natural number, ranging from 0 to N, where N is the number of points) which is then used to generate the twiddle factor at the bottom of the recursion.

While this algorithm is less pedagogical and readable, it is fully linear and way less contrived for the runtime. It is in a healthy shape for optimality, but there is still a problem: it represents complex numbers using floats, thus requires machine numbers that don't fuse. Of course, we could perform FFT over other fields, but in some cases we really need complex FFT.

Implementing integers with datatypes

Our ultimate goal is to come up with an implementation of Complex that can be used on FFT, and that is based purely on ADTs, i.e., no native floats, in such a way that is simple and direct enough to possibly fuse. Let's begin with a simpler goal: integers. Let's recall how we implemented Nat and inc:

data Nat = O Nat | I Nat | E deriving (Show)

inc :: Nat -> Nat
inc E     = O E
inc (O x) = I x
inc (I x) = O (inc x)

So, what about Int? We could try implementing it by using a Nat and a sign:

-- sign = True  correspond to  0,  1,  2,  3, ...
-- sign = False correspond to -1, -2, -3, -4, ...
data Int = Int Bool Nat

intInc :: Int -> Int
intInc (Int True  n) = Int True (inc n)
intInc (Int False n)
    | isMinusOne n = Int True natZero
    | otherwise    = Int False (dec n)

But this is actually a bad idea since, by pattern-matching on the boolean to treat cases separately we inhibit fusion. Also, calling isMinusOne makes it non-linear, which also inhibits fusion. In general, this is exactly the shape of code that does not fuse. We're looking for something more uniform, that doesn't separate the sign.

Fortunatelly, there is a pretty elegant way to do it: balanced ternary. It is similar to ternary numbers, expect its digits are not 0,1,2, but -1,0,1. The -1 digit is usually written as T, so, for example, the string 11T represents 9 + 3 - 1, which is 11. Amazingly, all integers can be uniquely represented in this system. Let's count:

num   trits
--- | -----
−13 |   𝖳𝖳𝖳
−12 |   𝖳𝖳0
−11 |   𝖳𝖳1
−10 |   𝖳0𝖳
 −9 |   𝖳00
 −8 |   𝖳01
 −7 |   𝖳1𝖳
 −6 |   𝖳10
 −5 |   𝖳11
 −4 |    𝖳𝖳 
 −3 |    𝖳0
 −2 |    𝖳1
 −1 |     𝖳
  0 |     0
  1 |     1
  2 |    1𝖳
  3 |    10
  4 |    11
  5 |   1𝖳𝖳
  6 |   1𝖳0
  7 |   1𝖳1
  8 |   10𝖳
  9 |   100
 10 |   101
 11 |   11𝖳
 12 |   110
 13 |   111

Beautiful, isn't it? Its symmetric proporties and simple arithmetic are quite important for fusion. Here is an implementation of Int and inc:

-- T stands for digit -1
-- O stands for digit  0
-- I stands for digit  1
-- E stands for end-of-string
data Int = T Int | O Int | I Int | E

inc :: Int -> Int
inc E     = E
inc (T x) = O x
inc (O x) = I x
inc (I x) = t (Inc x)

As you can imagine, this representation of Int, and its arithmetic operations, can be implemented on HVM in a way that fuses nicely, solving part of the problem. But what about real and complex numbers?

Implementing complex numbers with datatypes

Sadly, it took me a long time to figure out the proper way to do Int, which is astronomically simpler than these. Just to think of the complexity of floats (which include sign, mantissa, exponent, NaNs...), Dedekind Cuts, Cauchy Sequences and the like would give us the feeling we'll never find a representation of real numbers that is as nice and uniform as the Int type above. Yet, for the sake of FFT, there is some light at the end of the tunnel.

The insight is that we don't actually need all complex numbers; we just need enough of them to represent roots of unity. In fact, if we wanted to perform FFT on lists of 2 elements, integers would be enough for us, since we only need to evaluate polynomials on 2 complex points, [1, -1], which are both just ints!

G0

But what if there were 4 elements? In this case, we would actually need 4 points: [1, i, -1, -i]... which is beyond the set of integers. But there is an common extension of integers that could help: Gaussian Integers, which add the square root of -1, i, to Int. It forms sort of a quantized grid over the complex plane, as follows:

We could implement that type as follows:

-- Represents (A + B*i)
data Gauss = G Int Int

This would give us 4 roots of unity, and, thus, the ability to perform FFT on lists of 4 elements!

What about 8 elements though? Well, the idea here is to simply keep going, and add a new constant, j, which is the square root of i. Once we do that, we need 4 ints to represent a number; that's because ij is also part of the set, so any number can be represented as A + B*i + C*j + D*ij. We can implement this extended Gaussian Integer as:

-- Represents (A + B*i + C*j + D*ij)
data ExtGauss = G Int Int Int Int

This would give us 8 roots of unity, and, thus, the ability to perform FFT on lists of 8 elements.

And it would allow us to represent "fractional" complex numbers with just integers. For example, the number 4.652 + 7.816 * i can be approximated as 1 + 2*i + 3*j + 4*ij, which can be constructed as:

num :: ExtGauss
num = G 1 2 3 4

As you can see, we can keep going and generalize this to make the set G(N), with G(0) being Integers, G(1) being Gaussian Integers, G(2) being Extended Gaussian Integers, and so on. To implement it, we can use a tree:

data Tree a = V a | B (Tree a) (Tree a)
type GN     = Tree Int

For G(3), we have 16 roots of unity, and can perform FFT on lists of 16 elements:

To add two complex numbers of GN, we just do so pairwise:

add :: GN -> GN -> GN
add (L z)   (L w)   = L (z + w)
add (B a b) (B c d) = B (add a c) (add b d)

But what about multiplication? A very elegant algorithm, found by T6 on HOC's discord, allows us to multiply a GN number by bases like i, j, k, etc., with just rotations and negation:

rot :: GN -> GN
rot (L z)   = L (-z)
rot (B a b) = B (rot b) a

This simple function, if applied on a GN element, will multiply it by its smallest base. So, for example, on G(3), rot performs multiplication by k. On G(2), it multiplies by j. On G(1), it multiplies by i. On G(0), which is just a scalar, it negates the number. Very nice! We can use Rot to multiply two GN elements as follows (once again, thanks T6):

mul :: GN -> GN -> GN
mul (L z)   (L w)   = L (z * w)
mul (B a b) (B c d) = B (add (mul a c) (rot (mul b d))) (add (mul b c) (mul a d))

Which is, again, very elegant. But, for the sake of FFT, we don't actually need to perform multiplication of two arbitrary numbers during the algorithm; we just need to multiply by twiddle factors on the upper side of the unit circle. As such, we can use this simplified algorithm, which receives an angle on the unit circle, represented by a Nat, and multiplies it by the respective point:

mul :: Nat -> GN -> GN
mul E     x         = x
mul (O p) (L x)     = id  (L x)
mul (I p) (L x)     = rot (L x)
mul (O p) (B ax ay) = id  (B (mul p ax) (mul p ay))
mul (I p) (B ax ay) = rot (B (mul p ax) (mul p ay))

So, for example, mul (I (I (I O))) x will multiply i by ijk, which is approximately -0.923 + 0.382i. This completely removes the need for computing cos, sin, e^ix, and the like, simplifying the pt2 computation from:

pt2 = cMul pt1 (cPol (2 * pi * natVal ang 0 / 2 ** fromIntegral (natLen ang + 1)))

To just:

let pt2 = mul (ang E) pt1

In other words, all the complex floating point arithmetic, including trigonometric and exponentiation functions, are replaced by the mul function above, which is just a simple recursive pass over the tree structure!

Implementing FFT with datatypes

After all this, we're finally able to implement a complete FFT algorithm, in 40 lines of Haskell code, with just plain ADTs:

data Nat    = E | O Nat | I Nat
data Tree a = L a | B (Tree a) (Tree a)
type GN     = Tree Int

add :: GN -> GN -> GN
add (L x)     (L y)     = L (x + y)
add (B ax ay) (B bx by) = B (add ax bx) (add ay by)

sub :: GN -> GN -> GN
sub (L x)     (L y)     = L (x - y)
sub (B ax ay) (B bx by) = B (sub ax bx) (sub ay by)

rot :: GN -> GN
rot (L z)   = L (-z)
rot (B x y) = B (rot y) x

mul :: Nat -> GN -> GN
mul E     x         = x
mul (O p) (L x)     = L x
mul (I p) (L x)     = rot (L x)
mul (O p) (B ax ay) = B (mul p ax) (mul p ay)
mul (I p) (B ax ay) = rot (B (mul p ax) (mul p ay))

fft :: (Nat -> Nat) -> Tree GN -> Tree GN
fft ang (L x)   = L x
fft ang (B x y) =
  let pt0 = fft (ang.O) x
      pt1 = fft (ang.O) y
  in mix ang pt0 pt1

mix :: (Nat -> Nat) -> Tree GN -> Tree GN -> Tree GN
mix ang (B ax ay) (B bx by) =
  let ptl = mix (ang.O) ax bx
      ptr = mix (ang.I) ay by
  in B ptl ptr
mix ang (L pt0) (L pt1) =
  let pt2 = mul (ang E) pt1
      ptl = L (add pt0 pt2)
      ptr = L (sub pt0 pt2)
  in B ptl ptr

Yes, this the a complete implementation! Note in this version I'm actually using the primitive Haskell Int, but, as we've stablished, it can be implemented with just ADTs via balanced ternary. Here is a complete Haskell file which performs FFT using GNs, with some helper functions to convert it from and to lists of normal complex numbers, for visualization.

So, how does this behave on the HVM? I don't know. I've just finished implementing it on Haskell and will spend some time trying to adjust it for the HVM in the future. There are many things to tweak. For example, sub isn't necessary and can easily be removed; Nat angles can be replaced by a twiddle multiplier; mix could be simplified, perhaps. But, as is, this gets rid of most of the inefficiencies with common FFT, and, as far as I know, is the cleanest version of FFT in a pure functional sense. If you like it, feel encouraged to join the Higher Order Community to talk about it and ask any questions. Thanks for reading!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment