Skip to content

Instantly share code, notes, and snippets.

@VictorTaelin
Last active May 5, 2023 05:23
Show Gist options
  • Save VictorTaelin/7a7aba09275f88d159d9fb3ecf948860 to your computer and use it in GitHub Desktop.
Save VictorTaelin/7a7aba09275f88d159d9fb3ecf948860 to your computer and use it in GitHub Desktop.
-- | A complex number @C re im@ with real part @re@ and imaginary part @im@.
--
-- Both fields are strict: evaluating a 'Complex' value also evaluates its two
-- components, instead of storing them as unevaluated thunks.
data Complex = C !Double !Double deriving Show

-- | Multiply both components by a real scalar.
cScale :: Double -> Complex -> Complex
cScale s (C ar ai) = C (ar * s) (ai * s)

-- | Complex addition (componentwise).
cAdd :: Complex -> Complex -> Complex
cAdd (C ar ai) (C br bi) = C (ar + br) (ai + bi)

-- | Complex subtraction (componentwise).
cSub :: Complex -> Complex -> Complex
cSub (C ar ai) (C br bi) = C (ar - br) (ai - bi)

-- | Complex multiplication: (a + bi)(c + di) = (ac - bd) + (ad + bc)i.
cMul :: Complex -> Complex -> Complex
cMul (C ar ai) (C br bi) = C (ar * br - ai * bi) (ar * bi + ai * br)

-- | The point on the unit circle at angle @ang@ (radians): cos ang + i sin ang.
cPol :: Double -> Complex
cPol ang = C (cos ang) (sin ang)
-- | Keep every other element of a list, choosing the phase by parity.
--
-- @split 0 xs@ keeps the elements at even indices (0, 2, 4, ...) and
-- @split 1 xs@ keeps those at odd indices (1, 3, 5, ...). Any other integer
-- is interpreted by its parity, so the function is total. Previously, any
-- value other than 0 or 1 crashed on a non-empty list.
--
-- >>> split 0 "abcdef"
-- "ace"
-- >>> split 1 "abcdef"
-- "bdf"
split :: Int -> [a] -> [a]
split _ [] = []
split n (x:xs)
  | even n    = x : split 1 xs
  | otherwise = split 0 xs
-- | Evaluate a polynomial on equally spaced points of the unit circle
-- (a discrete Fourier transform).
--
-- The input is the list of complex coefficients of a polynomial P, lowest
-- degree first. The output is P evaluated on len(P) points of the unit
-- circle. Example:
--
-- - input  = [A,B,C,D,E,F,G,H]
-- - output = [P(e(0/8)),P(e(1/8)),P(e(2/8)),P(e(3/8)),P(e(4/8)),P(e(5/8)),P(e(6/8)),P(e(7/8))]
--
-- where P(x) = A + Bx¹ + Cx² + Dx³ + Ex⁴ + Fx⁵ + Gx⁶ + Hx⁷ and e(x) = e^(2πxi).
--
-- The exponent is positive. This matches the unnormalised /inverse/ DFT in
-- the sign convention used by e.g. NumPy.
--
-- Lengths:
--
-- - Even lengths are halved with the radix-2 (Cooley–Tukey) step below, so
--   power-of-two lengths run in O(n log n).
-- - An odd length greater than one cannot be halved and is evaluated directly
--   in O(n²).
-- - The empty polynomial yields the empty list.
fft :: [Complex] -> [Complex]
-- No coefficients means no evaluation points. (Without this clause the general
-- case would split [] into [] and [] and recurse on [] forever.)
fft [] = []
-- When len=1, we must eval P(e(0)) where P(x) = A. That's just A.
fft [x] = [x]
fft xs
  | odd len = naive
  | otherwise = ptl ++ ptr
  where
    len = length xs

    -- Radix-2 step. If len=8, we'll eval P = A + Bx¹ + ... + Hx⁷
    -- on x ∈ [e(0/8),e(1/8),...,e(7/8)].
    -- We first split P into two polynomials by even/odd exponent:
    evens = split 0 xs -- EVE = A + Cx² + Ex⁴ + Gx⁶
    odds = split 1 xs -- ODD = Bx¹ + Dx³ + Fx⁵ + Hx⁷ (stored as [B,D,F,H])
    -- We then call FFT recursively on each half:
    -- PT0 = A + Cx + Ex² + Gx³ on x ∈ [e(0/4),e(1/4),e(2/4),e(3/4)] (by induction)
    --     = A + Cx² + Ex⁴ + Gx⁶ on x ∈ [e(0/8),e(1/8),e(2/8),e(3/8)] (as e(k/8)² = e(k/4))
    pt0 = fft evens
    -- PT1 = B + Dx² + Fx⁴ + Hx⁶ on x ∈ [e(0/8),e(1/8),e(2/8),e(3/8)], likewise.
    pt1 = fft odds
    -- The "twiddle factors" e(k/len) for the first half of the points.
    twi =
      [ cPol (2 * pi * fromIntegral k / fromIntegral len)
      | k <- [0 .. len `div` 2 - 1]
      ]
    -- Multiplying PT1 by x = e(k/len) restores the odd exponents:
    -- PT2 = Bx¹ + Dx³ + Fx⁵ + Hx⁷ on x ∈ [e(0/8),e(1/8),e(2/8),e(3/8)].
    pt2 = zipWith cMul twi pt1
    -- Since e((k + len/2)/len) = -e(k/len), the even part is unchanged on the
    -- second half of the circle while the odd part flips sign:
    ptl = zipWith cAdd pt0 pt2 -- P on [e(0/8),e(1/8),e(2/8),e(3/8)]
    ptr = zipWith cSub pt0 pt2 -- P on [e(4/8),e(5/8),e(6/8),e(7/8)]

    -- Direct evaluation for odd lengths: P(e(k/len)) = Σⱼ xs[j]·e(jk/len).
    -- Reducing (j*k) `mod` len keeps every angle in [0, 2π).
    naive =
      [ foldr cAdd (C 0 0) (zipWith cMul xs (powers k)) | k <- [0 .. len - 1] ]
    powers k =
      [ cPol (2 * pi * fromIntegral ((j * k) `mod` len) / fromIntegral len)
      | j <- [0 .. len - 1]
      ]
-- | Demo: run 'fft' on the purely real coefficients 0, 1, ..., 7 and print
-- the eight resulting evaluation points.
main :: IO ()
main = print (fft coefficients)
  where
    coefficients = [C re 0.0 | re <- [0.0 .. 7.0]]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment