Last active: August 9, 2017 19:37
-
-
Save mrange/da57f972b3dfdfb44f28fd340841586c to your computer and use it in GitHub Desktop.
FFT
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// A straightforward, list-based radix-2 Cooley-Tukey FFT (recursive decimation in time).
/// Readable reference implementation; see FourierTransform for the fast one.
module Simple =
  open System
  open System.Numerics

  let pi = Math.PI
  let tau = 2.*pi

  /// Unit-magnitude complex number with phase `a` (radians).
  let twiddle a = Complex.FromPolarCoordinates(1., a)

  /// Forward, unnormalized discrete Fourier transform of `x`
  /// (bin k is the sum over j of x_j * twiddle(-tau * j * k / n)).
  /// Lists of length 0 and 1 are returned unchanged.
  /// Raises ArgumentException when the length is greater than 1 but not a power of 2.
  let rec fft (x : Complex list) : Complex list =
    match x with
    | [] -> []
    | [_] -> x
    | _ ->
      // List length is O(n): compute it once rather than once per butterfly.
      let n = List.length x
      // Fail early with a clear message instead of a List.mapi2 length mismatch deeper down.
      if n &&& (n - 1) <> 0 then
        invalidArg "x" "The length of x must be a power of 2"
      // Split into even- and odd-indexed samples.
      let even, odd =
        x
        |> List.mapi (fun i c -> i % 2 = 0, c)
        |> List.partition fst
      let evenFft = fft (List.map snd even)
      let oddFft = fft (List.map snd odd)
      // Butterflies: first half gets e + w*o, second half gets e - w*o.
      (evenFft, oddFft)
      ||> List.mapi2 (fun i e o ->
        let btf = o * twiddle (-tau * (float i / float n))
        e + btf, e - btf)
      |> List.unzip
      ||> List.append
// For examples and tests, see: https://gist.github.com/mrange/da57f972b3dfdfb44f28fd340841586c
// Inspired by: http://fssnip.net/dC/title/fast-Fourier-transforms-FFT-
// and: https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm
// This implementation sacrifices idiomatic style for performance.
/// Array-based FFT with precomputed twiddle tables, plus a naive DFT and FsCheck tests.
module FourierTransform =
  open System
  open System.Numerics

  /// Largest input length `fft` supports; twiddle tables are precomputed up to this size.
  let maxSize = 4096

  let pi = Math.PI
  let tau = 2.*pi

  module internal Details =
    /// True when n has at most one bit set (note: also true for 0 and Int32.MinValue).
    let isPowerOf2 n = (n &&& n - 1) = 0

    /// log2 n for a power of 2 n >= 2.
    /// Fails (System.Exception) when n < 2 or n is not a power of 2.
    let ilog2 n =
      if n < 2 then failwith "n must be greater than 1"
      if not (isPowerOf2 n) then failwith "n must be a power of 2"
      // Binary search over the exponent: start at 16 and halve the step each round.
      let rec loop n c s =
        let t = 1 <<< c
        if t = n then
          c
        elif t > n then
          loop n (c - s) (s >>> 1)
        else
          loop n (c + s) (s >>> 1)
      loop n 16 8

    /// Unit-magnitude complex number with phase `a` (radians).
    let twiddle a = Complex.FromPolarCoordinates(1., a)

    /// twiddles.[k] holds the 2^(k-1) factors twiddle(-tau * j / 2^k), j = 0 .. 2^(k-1) - 1,
    /// for transforms of size 2^k. k runs from 0 (empty table) up to log2 maxSize.
    let twiddles =
      let unfolder c =
        if c < 2*maxSize then
          let vs = Array.init (c / 2) (fun i -> twiddle (-tau * float i / float c))
          Some (vs, c*2)
        else
          None
      Array.unfold unfolder 1

    /// One level of the FFT, ping-ponging between buffers `f` (read) and `t` (written).
    /// n2: offset between the e + a and e - a writes (fft passes half the total length).
    /// tws: twiddle table, indexed at s*j.
    /// s: stride, doubled on each recursive call. c: current length, halved on each call.
    /// Returns struct (buffer read, buffer holding this level's result).
    let rec loop n2 tws s c f t =
      if c > 2 then
        let c2 = c >>> 1
        // The recursion returns (read, written); swapping the names makes its
        // output this level's source `f`.
        let struct (t, f) = loop n2 tws (s <<< 1) c2 f t
        if s > 1 then
          for j = 0 to c2 - 1 do
            let off = s*j
            let off2 = off <<< 1
            let w = Array.get tws off
            for i = 0 to s - 1 do
              let e = Array.get f (i + off2 + 0)
              let o = Array.get f (i + off2 + s)
              let a = w*o
              Array.set t (i + off + 0) (e + a)
              Array.set t (i + off + n2) (e - a)
        else
          // s = 1 specialisation of the branch above (its inner loop runs once).
          for j = 0 to c2 - 1 do
            let w = Array.get tws j
            let e = Array.get f (2*j + 0)
            let o = Array.get f (2*j + s)
            let a = w*o
            Array.set t (j + 0) (e + a)
            Array.set t (j + n2) (e - a)
        struct (f, t)
      elif c = 2 then
        // Innermost level: the only twiddle used would be tws.[0] = 1, so skip the multiply.
        for i = 0 to s - 1 do
          let e = Array.get f (i + 0)
          let o = Array.get f (i + s)
          let a = o
          Array.set t (i + 0) (e + a)
          Array.set t (i + n2) (e - a)
        struct (f, t)
      else
        // Nothing to transform: report the untouched source as the result.
        struct (t, f)

  open Details

  /// Naive O(n^2) forward, unnormalized DFT of `vs`; accepts any length (0 and 1 included).
  let dft (vs : Complex []) =
    let l = vs.Length
    let am = tau / float l
    // For output bin i, accumulate sum over j of vs.[j] * twiddle(-i * j * tau / l) into s.
    let rec loop s j i =
      if j < l then
        let v = vs.[j]
        let n = v*twiddle (-float i * float j * am)
        loop (s + n) (j + 1) i
      else
        s
    Array.init l (loop Complex.Zero 0)

  /// Fast forward, unnormalized DFT of `vs` (same convention as `dft`); `vs` is not modified.
  /// Lengths 0 and 1 return a copy of the input.
  /// Otherwise the length must be a power of 2 (else System.Exception from ilog2)
  /// and at most maxSize (else ArgumentException).
  let fft (vs : Complex []) : Complex [] =
    let n = vs.Length
    if n < 2 then
      // The DFT of a 0- or 1-element sequence is the sequence itself.
      Array.copy vs
    else
      let ln = ilog2 n
      // twiddles has no table beyond maxSize; indexing it would give an opaque IndexOutOfRange.
      if n > maxSize then
        invalidArg "vs" (sprintf "Input length %d exceeds maxSize (%d)" n maxSize)
      // The algorithm writes into both buffers, so work on a copy of the input.
      let vs0 = Array.copy vs
      let vs1 = Array.zeroCreate n
      let struct (_, t) = Details.loop (n >>> 1) twiddles.[ln] 1 n vs0 vs1
      t

  module Tests =
    open FsCheck

    type Samples = Samples of Complex []

    type FftGenerators =
      /// Finite complex numbers: NaN/infinities mapped to 0/1/-1, magnitudes kept below 1000.
      static member Complex () : Arbitrary<Complex> =
        let f =
          Arb.generate<float>
          |> Gen.map (fun v ->
            if Double.IsNaN v then
              0.
            elif Double.IsPositiveInfinity v then
              1.
            elif Double.IsNegativeInfinity v then
              -1.
            else
              v % 1000.
            )
        let c =
          Gen.constant (fun r i -> Complex (r, i))
          <*> f
          <*> f
        Arb.fromGen c

      /// Sample arrays of length 1, 2, 4 or 8 (length 1 exercises fft's trivial-input path).
      static member Samples () : Arbitrary<Samples> =
        let s =
          gen {
            let! n = Gen.choose (0, 3) |> Gen.map (fun i -> 1 <<< i)
            let! c = Gen.arrayOfLength n Arb.generate<Complex>
            return Samples c
          }
        Arb.fromGen s

    type FftTests =
      static member ``Test ilog2`` (i : int) =
        // Mask the sign bit rather than using abs: abs Int32.MinValue overflows.
        let e = (i &&& Int32.MaxValue) % 29 + 2
        let v = 1 <<< e
        let a = Details.ilog2 v
        e = a

      static member ``Simple test`` (s : Samples) =
        let (Samples vs) = s
        let e = dft vs
        let a = fft vs
        if e.Length = a.Length then
          // Total squared error sum |e_i - a_i|^2. Summing the complex squares
          // (e_i - a_i)^2 instead would let errors of different phase cancel out.
          let rec loop s i =
            if i < e.Length then
              let m = (e.[i] - a.[i]).Magnitude
              loop (s + m*m) (i + 1)
            else
              s
          let s = loop 0. 0
          let r = s / float e.Length < 1E-10
          if not r then printfn "E:%A\nA:%A" e a
          r
        else
          false

    let run () =
      let config =
        { Config.Quick with
            Arbitrary = typeof<FftGenerators> :: Config.Quick.Arbitrary
            MaxTest = 1000
            MaxFail = 1000
        }
      Check.All<FftTests> config
// now () gives the milliseconds elapsed since this value was first initialized
let now : unit -> int64 =
  let clock = System.Diagnostics.Stopwatch.StartNew ()
  fun () -> clock.ElapsedMilliseconds
// time runs 'action' once (keeping its result), forces a full blocking GC, then
// times 'repeat' further runs. It returns the first result, the elapsed milliseconds,
// and the number of gen-0, gen-1 and gen-2 collections seen during the timed runs.
let time repeat action =
  let result = action ()
  System.GC.Collect (2, System.GCCollectionMode.Forced, true)
  let collections () =
    System.GC.CollectionCount 0, System.GC.CollectionCount 1, System.GC.CollectionCount 2
  let before0, before1, before2 = collections ()
  let start = now ()
  let mutable remaining = repeat
  while remaining > 0 do
    action () |> ignore
    remaining <- remaining - 1
  let stop = now ()
  let after0, after1, after2 = collections ()
  result, stop - start, after0 - before0, after1 - before1, after2 - before2
open System | |
open System.Numerics | |
[<EntryPoint>]
let main argv =
  // Snap components below 1E-14 in magnitude to exactly zero so printed output stays readable.
  let trim (c : Complex) =
    let snap (x : float) = if abs x < 1E-14 then 0. else x
    Complex (snap c.Real, snap c.Imaginary)

  FourierTransform.Tests.run ()

  let repeat = 1000
  let l = 1024

  // One period of cos a + cos 2a sampled at l points, divided by l / 2.
  let input =
    Array.init l (fun x ->
      let a = FourierTransform.tau * float x / float l
      let v = cos a + cos (2. * a)
      trim (Complex (v / (float l / 2.), 0.)))
  let inputList = List.ofArray input

  let output1 = List.map trim (Simple.fft inputList)
  let output2 = Array.map trim (FourierTransform.dft input)
  let output3 = Array.map trim (FourierTransform.fft input)

  // Only dump the transforms when they are small enough to read.
  if l < 64 then
    printfn "Input : %A" input
    printfn "Simple FFT : %A" output1
    printfn "Faster DFT : %A" output2
    printfn "Faster FFT : %A" output3

  let testCases =
    [|
      "Simple FFT" , (fun () -> Simple.fft inputList |> ignore)
      // "Faster DFT" , fun () -> FourierTransform.dft input |> ignore
      "Faster FFT" , (fun () -> FourierTransform.fft input |> ignore)
    |]

  testCases
  |> Array.iter (fun (name, action) ->
    printfn "Running %s %d times ..." name repeat
    let _, ms, cc0, cc1, cc2 = time repeat action
    printfn " it took %d ms and (%d, %d, %d) CC" ms cc0 cc1 cc2)

  0
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.