Skip to content

Instantly share code, notes, and snippets.

@ian-ross
Created December 31, 2013 14:13
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Star You must be signed in to star a gist
Save ian-ross/8197357 to your computer and use it in GitHub Desktop.
{-# LANGUAGE ScopedTypeVariables, GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module PrimeFFT where
import Prelude hiding (length, sum, map, zipWith, (++), foldr, foldr1, or, and,
concat, concatMap, replicate, scanl, scanl1, scanr, null,
init, last, tail, head, filter, reverse, product,
maximum, zip, dropWhile, enumFromTo, all, take)
import Data.List (sort, nub)
import qualified Prelude as P
import qualified Data.Map as M
import Data.Complex
import Data.Vector
-- Packages used mainly for testing (Data.Bits is also needed by bitrev).
import Data.Bits
import Control.Applicative ((<$>))
import Test.QuickCheck
-- Typing Vector (Complex Double) or Vector (Vector (Complex Double))
-- gets old quickly, so the aliases below abbreviate them.

-- | A double-precision complex number.
type CD = Complex Double
-- | A vector of complex numbers.
type VCD = Vector CD
-- | A vector of complex vectors.
type VVCD = Vector (Vector CD)
-- | Forward and inverse prime-length transforms built on 'raderFFT''.
--
-- 'raderFFT' uses exponent sign @1@ and no scaling; 'raderIFFT' uses
-- exponent sign @-1@ and scales every output by @1 / length v@.
raderFFT, raderIFFT :: VCD -> VCD
raderFFT v = raderFFT' 1 1 v
raderIFFT v = raderFFT' (-1) norm v
  where
    norm = 1.0 / fromIntegral (length v)
-- | Rader's algorithm for prime-length complex-to-complex Fast
-- Fourier Transform.
--
-- @sign@ multiplies the exponents applied to the root of unity
-- (@1@ in 'raderFFT', @-1@ in 'raderIFFT') and every output value is
-- multiplied by @scale@.  Calls 'error' if the input length is not
-- prime.
raderFFT' :: Int -> Double -> VCD -> VCD
raderFFT' sign scale xs =
  if isPrime p
    then map ((scale :+ 0) *) (generate p output)
    else error "non-prime length in raderFFT"
  where
    p = length xs

    -- Output 0 is the plain sum of the inputs; every other output is
    -- the first input plus the convolution value stored for its index.
    output 0 = sum xs
    output k = xs ! 0 + convByIndex M.! k

    -- Convolution length.
    convLen = p - 1

    -- Convolution length padded to the next greater power of two
    -- (left unchanged when it already is a power of two).
    paddedLen
      | convLen == 2 ^ log2 convLen = convLen
      | otherwise = 2 ^ (1 + log2 (2 * convLen - 3))

    -- Group generator and its inverse modulo p.
    gen = primitiveRoot p
    genInv = invModN p gen

    -- One step of "multiply by m modulo p", used to walk the group.
    stepBy m k = (m * k) `mod` p

    -- Index vectors in generator order and in inverse-generator order.
    fwdIdxs = iterateN convLen (stepBy gen) 1
    invIdxs = iterateN convLen (stepBy genInv) 1

    -- Input values permuted according to group generator indexing.
    permuted = backpermute xs fwdIdxs

    -- Padding size.
    zeros = paddedLen - convLen

    -- Permuted input padded to the power-of-two size: element 0 stays
    -- first, zeros follow, and the remaining elements fill the tail.
    paddedInput = generate paddedLen pick
      where
        pick j
          | j == 0 = permuted ! 0
          | j > zeros = permuted ! (j - zeros)
          | otherwise = 0

    -- Root of unity powers in inverse-generator order.
    w = omega p
    twiddles =
      backpermute (map (\k -> w ^^ (sign * k)) (enumFromTo 0 convLen)) invIdxs

    -- Root of unity powers cyclically repeated up to the padded length.
    paddedTwiddles = generate paddedLen (\j -> twiddles ! (j `mod` convLen))

    -- FFT-based convolution calculation.
    conv = ifft (zipWith (*) (fft paddedInput) (fft paddedTwiddles))

    -- Map from output index to convolution value, undoing the
    -- inverse-generator ordering.
    convByIndex = M.fromList (P.zip (toList invIdxs) (toList conv))
-- | Determine the smallest primitive root modulo a prime @p@.
--
-- From Wikipedia (https://en.wikipedia.org/wiki/Primitive_root_modulo_n):
--
-- No simple general formula to compute primitive roots modulo n is
-- known. There are however methods to locate a primitive root that
-- are faster than simply trying out all candidates. If the
-- multiplicative order of a number m modulo n is equal to phi(n) (the
-- order of Z_n^x), then it is a primitive root. In fact the converse
-- is true: if m is a primitive root modulo n, then the multiplicative
-- order of m is phi(n). We can use this to test for primitive roots.
-- [Here, phi(n) is Euler's totient function, and Z_n^x is the
-- multiplicative group of integers modulo n.]
--
-- First, compute phi(n). Then determine the different prime factors
-- of phi(n), say p1, ..., pk. Now, for every element m of Z_n^x,
-- compute
--
--     m^(phi(n) / pi) mod n    for i = 1, ..., k
--
-- using a fast algorithm for modular exponentiation such as
-- exponentiation by squaring. A number m for which these k results
-- are all different from 1 is a primitive root.
--
-- [In our case, n is restricted to being prime, and phi(p) = p - 1
-- for prime p.]
--
-- Calls 'error' when @p@ is not prime.
primitiveRoot :: Int -> Int
primitiveRoot p
  | not (isPrime p) = error "Attempt to take primitive root of non-prime value"
    -- For p = 2 the totient is 1, which has no prime factors to test
    -- against; the only unit modulo 2 is 1, so answer directly instead
    -- of asking 'factors' to factorise 1.
  | p == 2 = 1
    -- Candidates are tried lazily in increasing order, so no vector of
    -- all p - 1 candidates has to be built.
  | otherwise = P.head $ P.dropWhile (not . isGenerator) [1 .. p - 1]
  where
    -- Euler's totient function for prime values.
    tot = p - 1
    -- Exponents to check: tot divided by each distinct prime factor.
    totpows = map (tot `div`) $ fromList $ nub $ toList $ factors tot
    -- All powers are different from 1 => primitive root.
    isGenerator n = all (/= 1) $ map (expt p n) totpows
-- | Fast exponentiation modulo n by squaring.
--
-- @expt n b pow@ computes @b ^ pow `mod` n@.  Intermediate products
-- are carried out in 'Integer' so that they cannot overflow 'Int'.
--
-- The exponent must be non-negative; a negative exponent is rejected
-- with 'error' (the halving recursion would otherwise never reach
-- zero).  The result is always reduced modulo @n@, including for
-- @pow == 0@, so @expt 1 b 0@ is @0@.  As with any use of 'mod', a
-- modulus of zero raises a divide-by-zero exception.
expt :: Int -> Int -> Int -> Int
expt n b pow
  | pow < 0 = error "expt: negative exponent"
  | otherwise = fromIntegral (go pow)
  where
    base, modulus :: Integer
    base = fromIntegral b
    modulus = fromIntegral n

    go :: Int -> Integer
    go e
      | e == 0 = 1 `mod` modulus
      | odd e = (base * go (e - 1)) `mod` modulus
      | otherwise = let h = go (e `div` 2) in (h * h) `mod` modulus
-- | Multiplicative inverse of @g@ modulo @n@, found by exhaustive
-- search: the first @k@ in @1 .. n - 1@ with @g * k `mod` n == 1@.
--
invModN :: Int -> Int -> Int
invModN n g = head (filter isInverse candidates)
  where
    candidates = enumFromTo 1 (n - 1)
    isInverse k = (g * k) `mod` n == 1
-- | Prime sieve from Haskell wiki.
--
-- An infinite, ascending list of primes.  Candidates are the odd
-- numbers from 3; multiples of a prime @p@ are only filtered out once
-- the candidates reach @q@, the square of that prime.
--
primes :: Integral a => [a]
primes = 2 : oddPrimes
  where
    oddPrimes = sieve [3, 5 ..] 9 oddPrimes

    -- The prime list is matched lazily because it is the very list
    -- being defined.  The as-pattern is written without whitespace
    -- around '@' so that it is accepted by GHC 9.0 and later.
    sieve (x:xs) q ps@(~(p:t))
      | x < q = x : sieve xs q ps
      | otherwise = sieve [n | n <- xs, rem n p /= 0] (P.head t ^ 2) t
    -- Not reached for the infinite candidate list; keeps the match total.
    sieve [] _ _ = []
-- | Naive primality testing: @n@ is prime exactly when the first
-- element of 'primes' that is not smaller than @n@ is @n@ itself.
--
isPrime :: Integral a => a -> Bool
isPrime n = case P.dropWhile (< n) primes of
  (q : _) -> q == n
  []      -> False
-- | Naive prime factorisation by trial division against 'primes'.
--
-- Returns the prime factors of @n@ in ascending order, repeated
-- according to multiplicity.  Values with no prime factors (@n <= 1@)
-- yield an empty vector instead of looping forever.
--
factors :: Integral a => a -> Vector a
factors n = fromList $ go n primes
  where
    go cur pss@(p:ps)
      -- Nothing left to factorise.
      | cur <= 1 = []
      -- No prime up to sqrt cur divides cur, so cur itself is prime.
      -- Written with `div` rather than p * p to avoid overflow.
      | p > cur `div` p = [cur]
      | cur `mod` p == 0 = p : go (cur `div` p) pss
      | otherwise = go cur ps
    -- Not reached for the infinite prime list; keeps the match total.
    go _ [] = []
-- Testing and debugging code.
-- | Wrapper whose 'Arbitrary' instance only produces prime values.
-- Inefficient: each value is found by scanning 'primes'.
--
newtype Prime a = Prime { getPrime :: a }
  deriving (Eq, Ord, Show, Read, Num, Integral, Real, Enum)

-- Draw an arbitrary value greater than 1 and round it up to the first
-- prime that is not smaller than it.
instance (Integral a, Ord a, Arbitrary a) => Arbitrary (Prime a) where
  arbitrary = do
    lowerBound <- arbitrary `suchThat` (> 1)
    let candidates = P.dropWhile (< lowerBound) primes
    return (Prime (P.head candidates))
-- Test code for primitive root determination.

-- QuickCheck property: the value returned by primitiveRoot for a
-- generated prime passes primitiveRootTest.
prop_primitive_root :: Prime Int -> Bool
prop_primitive_root (Prime n) = primitiveRootTest n (primitiveRoot n)
-- Check that the powers root^0 .. root^(p-2), taken modulo p, are
-- exactly the values 1 .. p-1.  Calls error when p is not prime.
primitiveRootTest :: Int -> Int -> Bool
primitiveRootTest p root =
  if isPrime p
    then sort (toList powers) == [1 .. p - 1]
    else error "Attempt to take primitive root of non-prime value"
  where
    powers = generate (p - 1) (\e -> expt p root e)
-- Best way to use the primitive root calculation QuickCheck property:
--
-- verboseCheckWith (stdArgs { maxSize=25 }) prop_primitive_root
-- Clean up number display: real and imaginary parts whose absolute
-- value is below 1.0E-6 are replaced by zero.
defuzz :: VCD -> VCD
defuzz = map clean
  where
    clean z = snap (realPart z) :+ snap (imagPart z)
    snap x
      | abs x < 1.0E-6 = 0
      | otherwise = x
-- Check FFT against DFT: returns the largest magnitude of the
-- defuzzed element-wise difference, together with that difference.
check :: VCD -> (Double, VCD)
check v = (worst, delta)
  where
    delta = defuzz (zipWith (-) (raderFFT v) (dft v))
    worst = maximum (map magnitude delta)
-- Check FFT-inverse FFT round-trip: returns the largest magnitude of
-- the defuzzed difference between the input and its round-trip,
-- together with that difference.
icheck :: VCD -> (Double, VCD)
icheck v = (worst, delta)
  where
    roundTrip = raderIFFT (raderFFT v)
    delta = defuzz (zipWith (-) v roundTrip)
    worst = maximum (map magnitude delta)
-- QuickCheck property for FFT vs. DFT testing: the worst deviation
-- reported by check must stay below 1.0E-4.
prop_dft_vs_fft :: VCD -> Bool
prop_dft_vs_fft v = worst < 1.0E-4
  where
    (worst, _) = check v
-- QuickCheck property for inverse FFT round-trip testing: the largest
-- magnitude of the (un-defuzzed) round-trip error must stay below
-- 1.0E-4.
prop_ifft :: VCD -> Bool
prop_ifft v = maximum (map magnitude residual) < 1.0E-4
  where
    residual = zipWith (-) v (raderIFFT (raderFFT v))
-- Prime length arbitrary vectors: the length is drawn from the primes
-- below 500 and the elements are arbitrary complex values.
instance Arbitrary VCD where
  arbitrary =
    elements primeLengths >>= \len -> fmap fromList (vectorOf len arbitrary)
    where
      primeLengths = P.takeWhile (< 500) primes
-- Check all 0/1 vectors of a given length.
--
-- Integer k in 0 .. 2^n - 1 is turned into the length-n vector whose
-- j-th entry is 1 when bit j of k is set and 0 otherwise.  The result
-- is the first k whose checker error is not below 1.0E-6, or Nothing
-- when every vector passes.
repcheck :: (VCD -> (Double, VCD)) -> Int -> Maybe Int
repcheck checker n = findIndex (not . (< 1.0E-6)) errs
  where
    errs = map (fst . checker . bitVector) (enumFromN (0 :: Int) (2 ^ n))
    bitVector k = generate n (\j -> if testBit k j then 1 :+ 0 else 0 :+ 0)
-- BASIC FFT CODE FOLLOWS...
-- | The imaginary unit: real part 0, imaginary part 1.
i :: Complex Double
i = 0 :+ 1
-- | The unit-magnitude complex number with phase @2 * pi / n@, used as
-- the base root of unity for a length-@n@ transform.
omega :: Int -> Complex Double
omega n = cis theta
  where
    theta = 2 * pi / fromIntegral n
-- | Direct (non-fast) forward and inverse discrete Fourier transforms.
--
-- 'dft' uses exponent sign @1@ and no scaling; 'idft' uses exponent
-- sign @-1@ and scales every output by @1 / length v@.
dft, idft :: VCD -> VCD
dft v = dft' 1 1 v
idft v = dft' (-1) norm v
  where
    norm = 1.0 / fromIntegral (length v)
-- | Direct evaluation of the discrete Fourier transform.
--
-- Output @n@ is @scale@ times the sum over @k@ of
-- @h_k * w ^^ (sign * n * k)@, where @w = omega (length h)@.
dft' :: Int -> Double -> VCD -> VCD
dft' sign scale h = generate len scaledBin
  where
    len = length h
    root = omega len
    scaledBin n = (scale :+ 0) * bin n
    -- One output value: every input weighted by its power of the root.
    bin n = sum (imap (\k x -> x * root ^^ (sign * n * k)) h)
-- | Forward and inverse fast Fourier transforms built on 'fft''.
--
-- 'fft' uses exponent sign @1@ and no scaling; 'ifft' uses exponent
-- sign @-1@ and scales every output by @1 / length v@.
fft, ifft :: VCD -> VCD
fft v = fft' 1 1 v
ifft v = fft' (-1) norm v
  where
    norm = 1.0 / fromIntegral (length v)
-- | Fast Fourier transform for power-of-two input lengths.
--
-- @sign@ multiplies the exponents applied to the roots of unity and
-- every output value is multiplied by @scale@.  Inputs of length 0, 1
-- or 2 are handed to 'dft''.  Any other length that is not a power of
-- two is rejected with 'error': the bit-reversal and slicing below are
-- only meaningful for powers of two and would otherwise produce a
-- result of the wrong length or an out-of-range slice.
fft' :: Int -> Double -> VCD -> VCD
fft' sign scale h
  | n <= 2 = dft' sign scale h
  | 2 ^ log2 n /= n = error "fft': input length must be a power of two"
  | otherwise = map ((scale :+ 0) *) $ recomb $ backpermute h (bitrev n)
  where
    n = length h

    -- Composition of the combining passes for block sizes
    -- n, n/2, ..., 2; the rightmost (size 2) pass is applied first.
    recomb = foldr (.) id $ map dl $ iterateN (log2 n) (`div` 2) n

    -- One pass: split the vector into blocks of size m and combine the
    -- two halves of every block.
    dl m v = concat $ P.map butterfly $ slicevecs m v
      where
        w = omega m
        m2 = m `div` 2
        -- Powers of the root of unity applied to the second half.
        ds = map ((w ^^) . (sign *)) $ enumFromN 0 m2
        butterfly blk =
          let v0 = slice 0 m2 blk
              v1 = zipWith (*) ds $ slice m2 m2 blk
          in zipWith (+) v0 v1 ++ zipWith (-) v0 v1

    -- Consecutive length-m slices covering the whole vector.
    slicevecs m v = P.map (\j -> slice (j * m) m v) [0 .. n `div` m - 1]
-- | Bit-reversal index vector of length @n@: entry @k@ is @k@ with its
-- lowest @log2 n@ bits written in reverse order.
--
-- The fold uses the vector 'foldr' rather than @foldl'@: an
-- unqualified @foldl'@ is ambiguous on compilers whose Prelude exports
-- it.  Setting distinct bits is order-independent, so the result is
-- the same.
bitrev :: Int -> Vector Int
bitrev n = generate n reverseBits
  where
    nbits = log2 n
    positions = generate nbits id
    reverseBits k = foldr (mirror k) 0 positions
    -- Copy bit b of k to the mirrored position of the accumulator.
    mirror k b acc
      | testBit k b = setBit acc (nbits - 1 - b)
      | otherwise = acc
-- | Integer base-2 logarithm, rounded down: the number of times @n@
-- can be halved (with integer division) before reaching 1.
--
-- Only defined for positive arguments; zero or a negative value is
-- rejected with 'error' instead of recursing forever.
log2 :: Int -> Int
log2 n
  | n < 1 = error "log2: argument must be positive"
  | n == 1 = 0
  | otherwise = 1 + log2 (n `div` 2)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment