-- Rader's algorithm for prime-length FFT (see http://www.skybluetrades.net/blog/posts/2013/12/31/data-analysis-fft-9.html).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{-# LANGUAGE ScopedTypeVariables, GeneralizedNewtypeDeriving #-} | |
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-} | |
module PrimeFFT where | |
import Prelude hiding (length, sum, map, zipWith, (++), foldr, foldr1, or, and, | |
concat, concatMap, replicate, scanl, scanl1, scanr, null, | |
init, last, tail, head, filter, reverse, product, | |
maximum, zip, dropWhile, enumFromTo, all, take) | |
import Data.List (sort, nub) | |
import qualified Prelude as P | |
import qualified Data.Map as M | |
import Data.Complex | |
import Data.Vector | |
-- Packages used only for testing. | |
import Data.Bits | |
import Control.Applicative ((<$>)) | |
import Test.QuickCheck | |
-- Short names for the complex types used throughout this module;
-- spelling out Vector (Complex Double) everywhere quickly gets tedious.
type CD = Complex Double
type VCD = Vector CD
type VVCD = Vector VCD
-- | Forward and inverse prime-length transforms built on 'raderFFT''.
-- The forward transform is unscaled.  The inverse uses the opposite
-- exponent sign and divides by the input length, so
-- @raderIFFT . raderFFT@ is the identity up to rounding.
raderFFT, raderIFFT :: VCD -> VCD
raderFFT xs = raderFFT' 1 1 xs
raderIFFT xs = raderFFT' (-1) (recip (fromIntegral (length xs))) xs
-- | Rader's algorithm for a prime-length complex-to-complex FFT.
--
-- @raderFFT' sign scale xs@ computes, for @k = 0 .. p-1@,
--
-- > X_k = scale * sum_{n=0}^{p-1} x_n * w^(sign * n * k)
--
-- where @p = length xs@ and @w = omega p@.  A non-prime length is
-- rejected with an error.
--
-- How it works:
--
-- * X_0 is simply the sum of the inputs.
-- * The remaining indices are reordered by powers of a primitive root
--   @g@ mod p.  This turns the rest of the DFT into a length-(p-1)
--   cyclic convolution.
-- * That convolution is evaluated with the power-of-two 'fft' / 'ifft'.
raderFFT' :: Int -> Double -> VCD -> VCD
raderFFT' sign scale xs
  | not (isPrime p) = error "non-prime length in raderFFT"
  | otherwise       = map ((scale :+ 0) *) (generate p output)
  where
    p = length xs
    x0 = xs ! 0

    -- X_0 is the plain sum.  Every other X_k is x_0 plus one entry of
    -- the cyclic convolution, looked up by its output index.
    output 0 = sum xs
    output k = x0 + byIndex M.! k

    -- Cyclic convolution length, and that length rounded up to a power
    -- of two.  When padding is needed, the padded length must be at
    -- least 2 * len - 1, so that the zero-padded convolution below
    -- reproduces the cyclic one.
    len = p - 1
    padLen
      | len == 2 ^ log2 len = len
      | otherwise           = 2 ^ (1 + log2 (2 * len - 3))
    gap = padLen - len

    -- Primitive root of the multiplicative group mod p, and its inverse.
    g    = primitiveRoot p
    gInv = invModN p g

    -- The first len powers 1, r, r^2, ... (mod p) of a group element r.
    powersOf r = iterateN len (\e -> (r * e) `mod` p) 1

    -- Inputs reordered as x_{g^q}.  They are then zero-padded by
    -- inserting the gap between element 0 and element 1.
    permuted = backpermute xs (powersOf g)
    padded = generate padLen $ \j ->
      if j == 0 then permuted ! 0
      else if j > gap then permuted ! (j - gap)
      else 0

    -- Twiddle factors w^(sign * g^(-m)), repeated cyclically out to padLen.
    w = omega p
    outIdxs = powersOf gInv
    twiddles = backpermute (map ((w ^^) . (sign *)) (enumFromTo 0 len)) outIdxs
    twiddlesPadded = generate padLen (\j -> twiddles ! (j `mod` len))

    -- FFT-based cyclic convolution.  Entry m belongs to output index
    -- g^(-m) mod p, so key the results by that index.
    conv = ifft (zipWith (*) (fft padded) (fft twiddlesPadded))
    byIndex = M.fromList (toList (zip outIdxs conv))
-- | Smallest primitive root modulo a prime @p@.  This is the first
-- candidate in @1 .. p-1@ whose powers run through every non-zero
-- residue mod @p@.  A non-prime argument is rejected with an error.
--
-- Method (after https://en.wikipedia.org/wiki/Primitive_root_modulo_n):
-- no simple closed formula for primitive roots is known, but there is
-- a test that is much cheaper than enumerating all powers.
--
-- * @m@ is a primitive root modulo @n@ exactly when its multiplicative
--   order equals phi(n), Euler's totient, which is the order of the
--   group of units mod n.
-- * For a prime @p@, phi(p) = p - 1.
-- * Let q_1, ..., q_k be the distinct prime factors of p - 1.  Then a
--   candidate @m@ has order p - 1 iff
--
-- >   m^((p - 1) / q_i) mod p /= 1   for every i = 1 .. k
--
-- Each such power is computed with fast modular exponentiation
-- ('expt').  Candidates are tried in increasing order.
primitiveRoot :: Int -> Int
primitiveRoot p
  | not (isPrime p) = error "Attempt to take primitive root of non-prime value"
  | otherwise = head $ dropWhile (not . isGenerator) $ enumFromTo 1 (p - 1)
  where
    -- Euler's totient of a prime.
    phi = p - 1
    -- phi / q for each distinct prime factor q of phi.
    cofactors = map (phi `div`) $ fromList $ nub $ toList $ factors phi
    -- None of the cofactor powers may be 1.
    isGenerator m = all (\e -> expt p m e /= 1) cofactors
-- | Modular exponentiation by squaring: @expt n b e@ is @b^e mod n@.
--
-- Intermediate products are computed in 'Integer', so they cannot
-- overflow 'Int'.  The result is always reduced modulo @n@; in
-- particular @expt 1 b 0 == 0@.  A negative exponent raises an error.
-- Without that check, the squaring recursion would never reach 0.
expt :: Int -> Int -> Int -> Int
expt n b pow
  | pow < 0   = error "expt: negative exponent"
  | otherwise = fromIntegral $ go pow
  where nb = fromIntegral n :: Integer
        bb = fromIntegral b `mod` nb
        go :: Int -> Integer
        go e
          | e == 0    = 1 `mod` nb
          | odd e     = (bb * go (e - 1)) `mod` nb
          | otherwise = let h = go (e `div` 2) in (h * h) `mod` nb
-- | Inverse of @g@ in the multiplicative group of integers modulo @n@.
-- The result is the unique @x@ in @[1 .. n-1]@ with
-- @(g * x) `mod` n == 1@.
--
-- This uses the extended Euclidean algorithm, which takes O(log n)
-- steps.  An error is raised when @g@ has no inverse, i.e. when
-- @gcd g n /= 1@ or @n <= 1@.
invModN :: Int -> Int -> Int
invModN n g
  | n > 1, r == 1 = s `mod` n
  | otherwise     = error "invModN: element has no inverse modulo n"
  where
    (r, s) = egcd n (g `mod` n) 0 1
    -- Extended Euclid, tracking only the coefficient of g.  The
    -- invariant is r0 == s0 * g and r1 == s1 * g (mod n).  It ends with
    -- (gcd n g, s) where gcd == s * g (mod n).
    egcd r0 r1 s0 s1
      | r1 == 0   = (r0, s0)
      | otherwise = let q = r0 `div` r1
                    in egcd r1 (r0 - q * r1) s1 (s0 - q * s1)
-- | Infinite ascending list of primes (a trial-division sieve from the
-- Haskell wiki).
--
-- Odd candidates below the next prime square @sq@ are passed through
-- unchanged.  When the candidate reaches @sq@, the multiples of that
-- prime are filtered out of the remaining candidates.
primes :: Integral a => [a]
primes = 2 : oddPrimes
  where
    oddPrimes = go [3, 5 ..] 9 oddPrimes
    go (c:cs) sq ps@(~(q:qs))
      | c < sq    = c : go cs sq ps
      | otherwise = go [m | m <- cs, rem m q /= 0] (P.head qs ^ 2) qs
-- | Naive primality test.  It walks 'primes' up to the first prime that
-- is not below @n@ and checks whether that prime is @n@ itself.  Values
-- below 2 are never prime.
isPrime :: Integral a => a -> Bool
isPrime n = P.head (P.dropWhile (< n) primes) == n
-- | Naive prime factorisation by trial division against 'primes'.
--
-- The factors of @n@ are returned in non-decreasing order with
-- multiplicity, e.g. @factors 12 == fromList [2, 2, 3]@.
--
-- Values below 2 have no prime factors and give an empty vector.
-- 'primitiveRoot' relies on this: for @p = 2@ it needs @factors 1@.
-- Once no prime up to the square root of the remaining cofactor
-- divides it, that cofactor is itself prime.
factors :: Integral a => a -> Vector a
factors n = fromList $ go n primes
  where
    go cur pss@(p:ps)
      | cur < 2              = []
      -- p > cur / p  <=>  p * p > cur; this form cannot overflow.
      | p > cur `div` p      = [cur]
      | cur `mod` p == 0     = p : go (cur `div` p) pss
      | otherwise            = go cur ps
    go _ [] = error "factors: impossible, primes is infinite"
-- Testing and debugging code.

-- | Wrapper whose 'Arbitrary' instance generates only primes.
--
-- The generator draws an arbitrary value greater than 1 and rounds it
-- up to the next prime.  This is simple but inefficient for large
-- values.
newtype Prime a = Prime { getPrime :: a }
  deriving (Eq, Ord, Show, Read, Num, Integral, Real, Enum)

instance (Integral a, Ord a, Arbitrary a) => Arbitrary (Prime a) where
  arbitrary = do
    lower <- arbitrary `suchThat` (> 1)
    return $ Prime $ P.head $ P.dropWhile (< lower) primes
-- | QuickCheck property: the root computed by 'primitiveRoot' for a
-- prime really generates the multiplicative group.
prop_primitive_root :: Prime Int -> Bool
prop_primitive_root (Prime n) = primitiveRootTest n (primitiveRoot n)

-- | @primitiveRootTest p root@ holds when the powers @root^0 .. root^(p-2)@
-- (mod p), once sorted, are exactly @[1 .. p-1]@.  A non-prime @p@ is
-- rejected with an error.
primitiveRootTest :: Int -> Int -> Bool
primitiveRootTest p root
  | not (isPrime p) = error "Attempt to take primitive root of non-prime value"
  | otherwise       = sort (toList powers) == [1 .. p - 1]
  where powers = generate (p - 1) (expt p root)

-- Best way to use the primitive root calculation QuickCheck property:
--
--   verboseCheckWith (stdArgs { maxSize=25 }) prop_primitive_root
-- | Tidy vectors for display.  Any real or imaginary component whose
-- absolute value is below 1e-6 is replaced by exactly 0.
defuzz :: VCD -> VCD
defuzz = map tidy
  where
    tidy (re :+ im) = snap re :+ snap im
    snap x
      | abs x < 1.0E-6 = 0
      | otherwise      = x
-- | Compare 'raderFFT' against the direct 'dft'.  Returns the largest
-- magnitude of the (defuzzed) element-wise difference, together with
-- the difference vector itself.
check :: VCD -> (Double, VCD)
check v = (maximum (map magnitude diff), diff)
  where diff = defuzz (zipWith (-) (raderFFT v) (dft v))
-- | Check the forward/inverse round trip: @raderIFFT . raderFFT@ should
-- reproduce the input.  Returns the largest magnitude of the (defuzzed)
-- difference, together with the difference vector itself.
icheck :: VCD -> (Double, VCD)
icheck v = (maximum (map magnitude diff), diff)
  where
    roundTrip = raderIFFT (raderFFT v)
    diff = defuzz (zipWith (-) v roundTrip)
-- | QuickCheck property: the Rader FFT agrees with the direct DFT to
-- within 1e-4.  The measure is the largest magnitude of the difference,
-- after 'defuzz'.
prop_dft_vs_fft :: VCD -> Bool
prop_dft_vs_fft v = fst (check v) < 1.0E-4

-- | QuickCheck property: an inverse-FFT round trip reproduces the input
-- to within 1e-4.  Unlike 'icheck', the raw differences are used here,
-- without defuzzing.
prop_ifft :: VCD -> Bool
prop_ifft v = worst < 1.0E-4
  where
    roundTrip = raderIFFT (raderFFT v)
    worst = maximum (map magnitude (zipWith (-) v roundTrip))
-- | Arbitrary complex vectors whose length is a prime below 500, so
-- they are valid inputs for 'raderFFT'.  This is an orphan instance on
-- a type synonym; it needs TypeSynonymInstances and FlexibleInstances.
instance Arbitrary VCD where
  arbitrary = do
    n <- elements (P.takeWhile (< 500) primes)
    xs <- vectorOf n arbitrary
    return (fromList xs)
-- | Exhaustively run @checker@ over every 0/1 vector of length @n@.
-- Counter value k is mapped to the vector whose element j is 1 when
-- bit j of k is set.
--
-- The result is the first k whose reported error is not below 1e-6,
-- or 'Nothing' if every vector passes.  When @checker@ is 'check' or
-- 'icheck', @n@ must be prime, because 'raderFFT'' rejects other
-- lengths.
repcheck :: (VCD -> (Double, VCD)) -> Int -> Maybe Int
repcheck checker n =
  case toList failures of
    []           -> Nothing
    (_, bad) : _ -> Just bad
  where
    asBits k = generate n (\b -> if testBit k b then 1 :+ 0 else 0 :+ 0)
    bitVectors = map asBits (enumFromN (0 :: Int) (2 ^ n))
    errors = map (fst . checker) bitVectors
    failures = dropWhile ((< 1.0E-6) . fst) (zip errors (enumFromN 0 (2 ^ n)))
-- BASIC FFT CODE FOLLOWS...

-- | The imaginary unit, 0 + 1i.
i :: Complex Double
i = (:+) 0 1
-- | Principal n-th root of unity exp(+2*pi*i/n).  Note the positive
-- sign: with this convention the "forward" transforms in this module
-- use positive exponents.
omega :: Int -> Complex Double
omega n = mkPolar 1 angle
  where angle = 2 * pi / fromIntegral n
-- | Direct O(n^2) reference DFT and its inverse.  The inverse flips the
-- exponent sign and divides by the input length.
dft, idft :: VCD -> VCD
dft v = dft' 1 1 v
idft v = dft' (-1) (recip (fromIntegral (length v))) v
-- | Direct DFT by the defining sum:
--
-- > out_n = scale * sum_k h_k * w^(sign * n * k),   w = omega (length h)
--
-- This costs O(N^2) and serves as the reference for checking the fast
-- transforms.
dft' :: Int -> Double -> VCD -> VCD
dft' sign scale h = map ((scale :+ 0) *) (generate bigN coefficient)
  where
    bigN = length h
    w = omega bigN
    coefficient n =
      sum (zipWith (\x k -> x * w ^^ (sign * n * k)) h (enumFromN 0 bigN))
-- | Radix-2 FFT and its inverse, for power-of-two lengths.  The inverse
-- flips the exponent sign and divides by the input length.
fft, ifft :: VCD -> VCD
fft v = fft' 1 1 v
ifft v = fft' (-1) (recip (fromIntegral (length v))) v
-- | Iterative radix-2 decimation-in-time FFT.
--
-- @fft' sign scale h@ computes
-- @scale * sum_k h_k * w^(sign * n * k)@ with @w = omega (length h)@.
--
-- * Inputs of length 2 or less are handed to 'dft''.
-- * Longer inputs are bit-reverse permuted and then combined by
--   butterfly stages of size 2, 4, ..., n.
-- * Longer inputs must have a power-of-two length; any other length
--   raises an error.  The stage structure below would otherwise
--   silently produce a wrong result.
fft' :: Int -> Double -> VCD -> VCD
fft' sign scale h
  | n <= 2          = dft' sign scale h
  | n /= 2 ^ log2 n = error "fft': input length must be a power of two"
  | otherwise       = map ((scale :+ 0) *) $ recomb $ backpermute h (bitrev n)
  where
    n = length h
    -- Stages for block sizes n, n/2, ..., 2, composed so that the
    -- size-2 stage is applied first.
    recomb = foldr (.) id $ map stage $ iterateN (log2 n) (`div` 2) n
    -- One stage: in each block of size m, combine the two half-size
    -- transforms with twiddle factors w_m^(sign * k).
    stage m v =
      let w = omega m
          m2 = m `div` 2
          ds = map ((w ^^) . (sign *)) $ enumFromN 0 m2
          butterfly blk =
            let v0 = slice 0 m2 blk
                v1 = zipWith (*) ds $ slice m2 m2 blk
            in zipWith (+) v0 v1 ++ zipWith (-) v0 v1
      in concat $ P.map butterfly $ blocks m v
    -- Split v into consecutive chunks of length m.
    blocks m v = P.map (\j -> slice (j * m) m v) [0 .. n `div` m - 1]
-- | Bit-reversal permutation used by the radix-2 FFT.  Entry k is k
-- with its low @log2 n@ bits reversed.
bitrev :: Int -> Vector Int
bitrev n = generate n reversed
  where
    width = log2 n
    reversed k = foldl' (flipInto k) 0 (enumFromN 0 width)
    flipInto k acc b
      | testBit k b = setBit acc (width - 1 - b)
      | otherwise   = acc
-- | Floor of the base-2 logarithm of a positive integer.  For example,
-- @log2 8 == 3@ and @log2 9 == 3@.  A non-positive argument raises an
-- error; halving towards 1 would otherwise recurse forever.
log2 :: Int -> Int
log2 n
  | n < 1     = error "log2: argument must be positive"
  | n == 1    = 0
  | otherwise = 1 + log2 (n `div` 2)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment