Skip to content

Instantly share code, notes, and snippets.

@bellbind
Last active December 19, 2021 12:19
Show Gist options
  • Save bellbind/36b7e71b3203e868c10405a61cdd2d09 to your computer and use it in GitHub Desktop.
Save bellbind/36b7e71b3203e868c10405a61cdd2d09 to your computer and use it in GitHub Desktop.
[WebAssembly] FFT
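// $ wat2wasm fft.wat
// Runs in a browser: serve this directory over HTTP (with .wasm served as
// application/wasm), open the HTML page, and read the results in the web console.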
(async function () {
  // The same fft.wasm is instantiated twice, each time with a different
  // module providing the "./math.js" imports; results are labelled by title.
  const variants = [
    ["fast", "./math.js"],
    ["slow", "./math-slow.js"],
  ];
  for (const [title, mathModule] of variants) {
    // A Response body can be consumed only once, so every
    // WebAssembly.instantiateStreaming call gets its own fetch.
    const res = fetch(new URL("./fft.wasm", import.meta.url).href);
    const imports = {
      "./math.js": await import(new URL(mathModule, import.meta.url)),
    };
    const {instance} = await WebAssembly.instantiateStreaming(res, imports);
    run(title, instance.exports);
  }
})().catch(console.error);
/**
 * Exercises one instantiated fft.wasm module.
 *
 * 1. A 16-point example whose spectrum and inverse are printed as [re, im] pairs.
 * 2. A 2^20-point benchmark timing fft + ifft with console.time.
 *
 * Data lives in the module's linear memory as interleaved complex f64 values
 * (re, im), in three equally sized buffers placed back to back: input at
 * `fofs`, spectrum at `Fofs`, round-trip result at `rofs`.
 *
 * @param {string} title - label prefix for the log output (e.g. "fast", "slow")
 * @param {{memory: WebAssembly.Memory, fft: Function, ifft: Function}} exports -
 *   module exports; fft/ifft are called as (n, inByteOffset, outByteOffset)
 */
function run(title, {memory, fft, ifft}) {
  // example
  {
    const N = 16, fofs = 0, Fofs = N * 2 * 8, rofs = N * 4 * 8;
    const f = new Float64Array(memory.buffer, fofs, N * 2);
    const F = new Float64Array(memory.buffer, Fofs, N * 2);
    const r = new Float64Array(memory.buffer, rofs, N * 2);
    const fr0 = [1,3,4,2, 5,6,2,4, 0,1,3,4, 5,62,2,3];
    fr0.forEach((v, i) => {
      f[i * 2] = v;
      f[i * 2 + 1] = 0.0;
    });
    fft(N, fofs, Fofs);
    ifft(N, Fofs, rofs);
    console.log(`[${title}-fft]`);
    for (let i = 0; i < N; i++) {
      console.log([F[i * 2], F[i * 2 + 1]]);
    }
    console.log(`[${title}-ifft]`);
    for (let i = 0; i < N; i++) {
      console.log([r[i * 2], r[i * 2 + 1]]);
    }
  }
  // example: benchmark
  {
    const N = 1024 * 1024;
    const BN = N * 2 * 8 * 3, fofs = 0, Fofs = N * 2 * 8, rofs = N * 4 * 8;
    // Grow in a single call (wasm pages are 64 KiB). Every grow replaces
    // memory.buffer, so the view is created only after growing.
    const missing = BN - memory.buffer.byteLength;
    if (missing > 0) memory.grow(Math.ceil(missing / 65536));
    const f = new Float64Array(memory.buffer, fofs, N * 2);
    // Write the signal straight into wasm memory (imaginary parts 0) instead of
    // materializing large intermediate JS arrays before the timed section.
    for (let i = 0; i < N; i++) {
      f[i * 2] = Math.sin(i) * i;
      f[i * 2 + 1] = 0.0;
    }
    console.time(`${title}-fft-ifft`);
    fft(N, fofs, Fofs);
    ifft(N, Fofs, rofs);
    console.timeEnd(`${title}-fft-ifft`);
  }
}
// $ wat2wasm fft.wat
// $ node --experimental-modules call-fft-dynamic.js
import {promises as fs} from "fs";
(async function () {
  // Read the wasm bytes once: WebAssembly.instantiate copies the bytes it is
  // given, so the same buffer can back both instantiations.
  const buf = await fs.readFile(new URL("./fft.wasm", import.meta.url));
  // Each run supplies a different module for the "./math.js" imports.
  const variants = [
    ["fast", "./math.js"],
    ["slow", "./math-slow.js"],
  ];
  for (const [title, mathModule] of variants) {
    const imports = {
      "./math.js": await import(new URL(mathModule, import.meta.url)),
    };
    const {instance} = await WebAssembly.instantiate(buf, imports);
    run(title, instance.exports);
  }
})().catch(console.error);
/**
 * Exercises one instantiated fft.wasm module.
 *
 * 1. A 16-point example whose spectrum and inverse are printed as [re, im] pairs.
 * 2. A 2^20-point benchmark timing fft + ifft with console.time.
 *
 * Data lives in the module's linear memory as interleaved complex f64 values
 * (re, im), in three equally sized buffers placed back to back: input at
 * `fofs`, spectrum at `Fofs`, round-trip result at `rofs`.
 *
 * @param {string} title - label prefix for the log output (e.g. "fast", "slow")
 * @param {{memory: WebAssembly.Memory, fft: Function, ifft: Function}} exports -
 *   module exports; fft/ifft are called as (n, inByteOffset, outByteOffset)
 */
function run(title, {memory, fft, ifft}) {
  // example
  {
    const N = 16, fofs = 0, Fofs = N * 2 * 8, rofs = N * 4 * 8;
    const f = new Float64Array(memory.buffer, fofs, N * 2);
    const F = new Float64Array(memory.buffer, Fofs, N * 2);
    const r = new Float64Array(memory.buffer, rofs, N * 2);
    const fr0 = [1,3,4,2, 5,6,2,4, 0,1,3,4, 5,62,2,3];
    fr0.forEach((v, i) => {
      f[i * 2] = v;
      f[i * 2 + 1] = 0.0;
    });
    fft(N, fofs, Fofs);
    ifft(N, Fofs, rofs);
    console.log(`[${title}-fft]`);
    for (let i = 0; i < N; i++) {
      console.log([F[i * 2], F[i * 2 + 1]]);
    }
    console.log(`[${title}-ifft]`);
    for (let i = 0; i < N; i++) {
      console.log([r[i * 2], r[i * 2 + 1]]);
    }
  }
  // example: benchmark
  {
    const N = 1024 * 1024;
    const BN = N * 2 * 8 * 3, fofs = 0, Fofs = N * 2 * 8, rofs = N * 4 * 8;
    // Grow in a single call (wasm pages are 64 KiB). Every grow replaces
    // memory.buffer, so the view is created only after growing.
    const missing = BN - memory.buffer.byteLength;
    if (missing > 0) memory.grow(Math.ceil(missing / 65536));
    const f = new Float64Array(memory.buffer, fofs, N * 2);
    // Write the signal straight into wasm memory (imaginary parts 0) instead of
    // materializing large intermediate JS arrays before the timed section.
    for (let i = 0; i < N; i++) {
      f[i * 2] = Math.sin(i) * i;
      f[i * 2 + 1] = 0.0;
    }
    console.time(`${title}-fft-ifft`);
    fft(N, fofs, Fofs);
    ifft(N, Fofs, rofs);
    console.timeEnd(`${title}-fft-ifft`);
  }
}
// $ wat2wasm fft.wat
// $ node --experimental-modules --experimental-wasm-modules call-fft-import.js
import {memory, fft, ifft} from "./fft.wasm";
// example: 16-point transform through the wasm exports, printed as [re, im] pairs
{
  const N = 16;
  const bufferBytes = N * 2 * 8; // N complex values of two f64 each
  const fofs = 0;
  const Fofs = bufferBytes;
  const rofs = bufferBytes * 2;
  const view = (offset) => new Float64Array(memory.buffer, offset, N * 2);
  const input = view(fofs);
  const signal = [1, 3, 4, 2, 5, 6, 2, 4, 0, 1, 3, 4, 5, 62, 2, 3];
  for (const [idx, value] of signal.entries()) {
    input[2 * idx] = value;
    input[2 * idx + 1] = 0.0;
  }
  fft(N, fofs, Fofs);
  ifft(N, Fofs, rofs);
  const dump = (label, data) => {
    console.log(label);
    for (let idx = 0; idx < N; idx++) {
      console.log([data[2 * idx], data[2 * idx + 1]]);
    }
  };
  dump("[fft]", view(Fofs));
  dump("[ifft]", view(rofs));
}
// example: benchmark (2^20-point forward + inverse transform in wasm memory)
{
  const N = 1024 * 1024;
  // three back-to-back buffers of N interleaved complex f64 values
  const BN = N * 2 * 8 * 3, fofs = 0, Fofs = N * 2 * 8, rofs = N * 4 * 8;
  // Grow in a single call (wasm pages are 64 KiB). Every grow replaces
  // memory.buffer, so the view is created only after growing.
  const missing = BN - memory.buffer.byteLength;
  if (missing > 0) memory.grow(Math.ceil(missing / 65536));
  const f = new Float64Array(memory.buffer, fofs, N * 2);
  // Write the signal straight into wasm memory (imaginary parts 0) instead of
  // materializing large intermediate JS arrays before the timed section.
  for (let i = 0; i < N; i++) {
    f[i * 2] = Math.sin(i) * i;
    f[i * 2 + 1] = 0.0;
  }
  console.time("fft-ifft");
  fft(N, fofs, Fofs);
  ifft(N, Fofs, rofs);
  console.timeEnd("fft-ifft");
}
// bit operations for FFT
/**
 * Reverses the order of the low `k` bits of `n0`.
 * All 32 bits are reversed by swapping adjacent 1-, 2-, 4-, 8- and 16-bit
 * groups; the unsigned shift then keeps the top `k` bits, moved to the bottom.
 * For k = 0 the shift count 32 is masked to 0 by JS, which is only correct
 * for n0 = 0 (the single index of a 1-point transform).
 * @param {number} k - number of bits to reverse
 * @param {number} n0 - index to reverse (treated as a 32-bit integer)
 * @returns {number} the bit-reversed index, as a non-negative integer
 */
function revBit(k, n0) {
  const s1 = ((n0 & 0xaaaaaaaa) >>> 1) | ((n0 & 0x55555555) << 1);
  const s2 = ((s1 & 0xcccccccc) >>> 2) | ((s1 & 0x33333333) << 2);
  const s3 = ((s2 & 0xf0f0f0f0) >>> 4) | ((s2 & 0x0f0f0f0f) << 4);
  const s4 = ((s3 & 0xff00ff00) >>> 8) | ((s3 & 0x00ff00ff) << 8);
  const s5 = ((s4 & 0xffff0000) >>> 16) | ((s4 & 0x0000ffff) << 16);
  return s5 >>> (32 - k);
}
// FFT: Cooley-Tukey FFT
/**
 * Iterative radix-2 Cooley-Tukey transform of interleaved complex data
 * ([re0, im0, re1, im1, ...]). The result is not scaled.
 * @param {number} N - number of complex samples; must be 0 or a power of two
 * @param {ArrayLike<number>} c - input, at least 2N numbers (left untouched)
 * @param {number} T - full-turn angle: -2π for the forward, +2π for the inverse direction
 * @returns {Float64Array} a new array of 2N numbers
 * @throws {RangeError} if N is not 0 or a power of two, or `c` is too short
 */
function fftc(N, c, T) {
  const k = N === 0 ? 0 : Math.log2(N);
  // Without this check revBit gets a fractional bit count and the butterflies
  // index past the buffer, silently producing NaN/garbage.
  if (!Number.isInteger(N) || !Number.isInteger(k)) {
    throw new RangeError(`fftc: N must be 0 or a power of two, got ${N}`);
  }
  if (c.length < N * 2) {
    throw new RangeError(`fftc: input has ${c.length} values, need ${N * 2}`);
  }
  const r = new Float64Array(N * 2);
  // reorder the input into bit-reversed index order
  for (let i = 0; i < N; i++) {
    const i2 = i * 2;
    const rbi2 = revBit(k, i) * 2;
    r[i2] = c[rbi2];
    r[i2 + 1] = c[rbi2 + 1];
  }
  // merge sub-transforms of size Nh into size 2*Nh, in place
  let theta = T;
  for (let Nh = 1; Nh < N; Nh *= 2) {
    theta /= 2; // twiddle step for this stage: T / (2 * Nh)
    for (let s = 0; s < N; s += Nh * 2) {
      for (let i = 0; i < Nh; i++) {
        const li2 = (s + i) * 2, ri2 = li2 + Nh * 2;
        const are = r[ri2], aim = r[ri2 + 1];
        const bre = Math.cos(theta * i), bim = Math.sin(theta * i);
        // (rre, rim) = a * b, complex multiply of the right input by the twiddle
        const rre = are * bre - aim * bim, rim = are * bim + aim * bre;
        const lre = r[li2], lim = r[li2 + 1];
        r[li2] = lre + rre;
        r[li2 + 1] = lim + rim;
        r[ri2] = lre - rre;
        r[ri2 + 1] = lim - rim;
      }
    }
  }
  return r;
}
/** Forward DFT (unscaled) of N interleaved complex samples; see fftc. */
function fft(N, f) {
  return fftc(N, f, -2 * Math.PI);
}
/** Inverse DFT scaled by 1/N, so ifft(N, fft(N, x)) ≈ x; see fftc. */
function ifft(N, F) {
  const r = fftc(N, F, 2 * Math.PI);
  for (let i = 0; i < r.length; i++) r[i] /= N;
  return r;
}
// example: 16-point transform and round trip, printed as [re, im] pairs
{
  const N = 16;
  const signal = [1, 3, 4, 2, 5, 6, 2, 4, 0, 1, 3, 4, 5, 62, 2, 3];
  const input = new Float64Array(N * 2);
  for (const [idx, value] of signal.entries()) {
    input[2 * idx] = value;
    input[2 * idx + 1] = 0;
  }
  const spectrum = fft(N, input);
  const restored = ifft(N, spectrum);
  const dump = (label, data) => {
    console.log(label);
    for (let idx = 0; idx < N; idx++) {
      console.log([data[2 * idx], data[2 * idx + 1]]);
    }
  };
  dump("[fft]", spectrum);
  dump("[ifft]", restored);
}
// example: benchmark (2^20-point forward + inverse transform in plain JS)
{
  const N = 1024 * 1024;
  // Float64Array starts zero-filled, so only the real parts need writing
  const input = new Float64Array(N * 2);
  for (let idx = 0; idx < N; idx++) {
    input[2 * idx] = Math.sin(idx) * idx;
  }
  console.time("fft-ifft");
  const spectrum = fft(N, input);
  ifft(N, spectrum);
  console.timeEnd("fft-ifft");
}
;; Radix-2 Cooley-Tukey FFT / inverse FFT over complex numbers in linear memory.
;; A complex value takes 16 bytes: real part (f64) at +0, imaginary part (f64) at +8.
;; Arrays are passed to the exported functions as byte offsets into "memory".
(module
;; sin, cos and PI are supplied by the host as import module "./math.js".
(import "./math.js" "sin" (func $sin (param $a f64) (result f64)))
(import "./math.js" "cos" (func $cos (param $a f64) (result f64)))
(import "./math.js" "PI" (global $pi f64))
;; Disabled debugging hook: the math.js modules export a `debug(...args)` that logs
;; its arguments; this (f64 i32 i32 i32) signature mirrors $fftc's parameters.
;;(import "./math.js" "debug" (func $dftc (param f64 i32 i32 i32)))
;; Initial size 1 page, maximum 1024 pages (64 KiB each, so at most 64 MiB);
;; the JS host grows it before large transforms.
(memory $mem (export "memory") 1 1024) ;; 1-page as 64KB
;; $revbit(k, n): the low k bits of n in reversed order.
;; Reverses all 32 bits by swapping adjacent 1-, 2-, 4-, 8- and 16-bit groups
;; ($s and $n alternate as scratch), then shifts right by 32 - k.
;; For k = 0 the shift count 32 is taken modulo 32 (no shift); for a power-of-two
;; n in $fftc that only happens when n = 1, whose sole index is 0.
(func $revbit (param $k i32) (param $n i32) (result i32)
(local $s i32)
;; swap adjacent bits
(i32.or
(i32.shr_u (i32.and (local.get $n) (i32.const 0xaaaaaaaa)) (i32.const 1))
(i32.shl (i32.and (local.get $n) (i32.const 0x55555555)) (i32.const 1)))
local.set $s
;; swap adjacent 2-bit groups
(i32.or
(i32.shr_u (i32.and (local.get $s) (i32.const 0xcccccccc)) (i32.const 2))
(i32.shl (i32.and (local.get $s) (i32.const 0x33333333)) (i32.const 2)))
local.set $n
;; swap adjacent nibbles
(i32.or
(i32.shr_u (i32.and (local.get $n) (i32.const 0xf0f0f0f0)) (i32.const 4))
(i32.shl (i32.and (local.get $n) (i32.const 0x0f0f0f0f)) (i32.const 4)))
local.set $s
;; swap adjacent bytes
(i32.or
(i32.shr_u (i32.and (local.get $s) (i32.const 0xff00ff00)) (i32.const 8))
(i32.shl (i32.and (local.get $s) (i32.const 0x00ff00ff)) (i32.const 8)))
local.set $n
;; swap the 16-bit halves
(i32.or
(i32.shr_u (i32.and (local.get $n) (i32.const 0xffff0000)) (i32.const 16))
(i32.shl (i32.and (local.get $n) (i32.const 0x0000ffff)) (i32.const 16)))
local.set $s
;; keep the top k bits of the reversed word, moved down to the low end
(i32.shr_u (local.get $s) (i32.sub (i32.const 32) (local.get $k)))
return
)
;; $fftc(t, n, in, out): unscaled radix-2 transform of n complex values.
;;   t   : -2*pi for the forward transform, +2*pi for the inverse (see the exports)
;;   n   : number of complex values; k = ctz(n) equals log2(n) only for a power
;;         of two, so n is expected to be one
;;   in  : byte offset of the input array (only read)
;;   out : byte offset of the output array. The in/out regions must not overlap:
;;         phase 1 reads `in` in bit-reversed order while writing `out`.
;; Phase 1 copies in[revbit(k, i)] to out[i]; phase 2 runs the butterfly
;; stages in place on `out`.
(func $fftc (param $t f64) (param $n i32) (param $in i32) (param $out i32)
(local $k i32)
(local $i i32)
(local $rbiin i32)
(local $nh i32)
(local $s i32)
(local $lrei i32) (local $limi i32)
(local $rrei i32) (local $rimi i32)
(local $rad f64)
(local $are f64) (local $aim f64)
(local $bre f64) (local $bim f64)
(local $rre f64) (local $rim f64)
(local $lre f64) (local $lim f64)
;; k = log2(n)
(i32.ctz (local.get $n))
local.set $k
;; Phase 1: bit-reversal copy from `in` to `out`
;; i = 0; while (true) {...}
(i32.const 0)
local.set $i
block $i-break loop $i-continue
;; if (i >= n) break
(i32.ge_u (local.get $i) (local.get $n))
br_if $i-break
;; rbiin = revbit(k, i) * 16 + in
(i32.add
(i32.mul
(call $revbit (local.get $k) (local.get $i))
(i32.const 16))
(local.get $in))
local.set $rbiin
;; mem[i * 16 + out] = mem[rbiin]  (real part)
(i32.add
(i32.mul (local.get $i) (i32.const 16))
(local.get $out))
(f64.load align=8 (local.get $rbiin))
f64.store align=8
;; mem[i * 16 + out + 8] = mem[rbiin + 8]  (imaginary part)
(i32.add
(i32.add
(i32.mul (local.get $i) (i32.const 16))
(local.get $out))
(i32.const 8))
(f64.load align=8 (i32.add (local.get $rbiin) (i32.const 8)))
f64.store align=8
;; i = i + 1
(i32.add (local.get $i) (i32.const 1))
local.set $i
br $i-continue
end end
;; Phase 2: for block half-size nh = 1, 2, 4, ..., merge pairs of size-nh
;; sub-transforms. t is halved each stage, so t = +-2*pi / (2*nh) and the
;; twiddle factor for offset i within a block is (cos(t*i), sin(t*i)).
;; nh = 1; while (true) {...}
(i32.const 1)
local.set $nh
block $nh-break loop $nh-continue
;; if (nh >= n) break
(i32.ge_u (local.get $nh) (local.get $n))
br_if $nh-break
;; t = t / 2
(f64.div (local.get $t) (f64.const 2.0))
local.set $t
;; s = start index of the current block of 2*nh values
;; s = 0; while (true) {...}
(i32.const 0)
local.set $s
block $s-break loop $s-continue
;; if (s >= n) break
(i32.ge_u (local.get $s) (local.get $n))
br_if $s-break
;; i = 0; while (true) {...}
(i32.const 0)
local.set $i
block $i-break loop $i-continue
;; if (i >= nh) break
(i32.ge_u (local.get $i) (local.get $nh))
br_if $i-break
;; lrei = (s + i) * 16 + out, limi = lrei + 8
(i32.add
(i32.mul (i32.add (local.get $s) (local.get $i)) (i32.const 16))
(local.get $out))
local.set $lrei
(i32.add (local.get $lrei) (i32.const 8))
local.set $limi
;; rrei = lrei + nh * 16, rimi = rrei + 8
(i32.add
(i32.mul (local.get $nh) (i32.const 16))
(local.get $lrei))
local.set $rrei
(i32.add (local.get $rrei) (i32.const 8))
local.set $rimi
;; are = mem[rrei], aim = mem[rimi]
(f64.load align=8 (local.get $rrei))
local.set $are
(f64.load align=8 (local.get $rimi))
local.set $aim
;; rad = t * i, bre = cos(rad), bim = sin(rad)
(f64.mul (local.get $t) (f64.convert_i32_u (local.get $i)))
local.set $rad
(call $cos (local.get $rad))
local.set $bre
(call $sin (local.get $rad))
local.set $bim
;; complex product (rre, rim) = (are, aim) * (bre, bim):
;; rre = are * bre - aim * bim, rim = are * bim + aim * bre
(f64.sub
(f64.mul (local.get $are) (local.get $bre))
(f64.mul (local.get $aim) (local.get $bim)))
local.set $rre
(f64.add
(f64.mul (local.get $are) (local.get $bim))
(f64.mul (local.get $aim) (local.get $bre)))
local.set $rim
;; lre = mem[lrei], lim = mem[limi]
(f64.load align=8 (local.get $lrei))
local.set $lre
(f64.load align=8 (local.get $limi))
local.set $lim
;; mem[lrei] = lre + rre, mem[limi] = lim + rim
(local.get $lrei)
(f64.add (local.get $lre) (local.get $rre))
f64.store align=8
(local.get $limi)
(f64.add (local.get $lim) (local.get $rim))
f64.store align=8
;; mem[rrei] = lre - rre, mem[rimi] = lim - rim
(local.get $rrei)
(f64.sub (local.get $lre) (local.get $rre))
f64.store align=8
(local.get $rimi)
(f64.sub (local.get $lim) (local.get $rim))
f64.store align=8
;; i = i + 1
(i32.add (local.get $i) (i32.const 1))
local.set $i
br $i-continue
end end
;; s = s + nh * 2
(i32.add (local.get $s) (i32.mul (local.get $nh) (i32.const 2)))
local.set $s
br $s-continue
end end
;; nh = nh * 2
(i32.mul (local.get $nh) (i32.const 2))
local.set $nh
br $nh-continue
end end
)
;; fft(n, in, out): forward transform ($fftc with angle -2*pi); the result is not scaled.
(func (export "fft") (param $n i32) (param $in i32) (param $out i32)
;; fftc(-2.0*pi, n, in, out)
(f64.mul (global.get $pi) (f64.const -2.0))
(local.get $n)
(local.get $in)
(local.get $out)
call $fftc
)
;; ifft(n, in, out): $fftc with angle +2*pi, then every f64 in out[0 .. n*16)
;; (both real and imaginary parts) is divided by n.
(func (export "ifft") (param $n i32) (param $in i32) (param $out i32)
(local $cur i32) (local $last i32)
(local $n64 f64)
;; fftc(2.0*pi, n, in, out)
(f64.mul (global.get $pi) (f64.const 2.0))
(local.get $n)
(local.get $in)
(local.get $out)
call $fftc
;; n64 = n
(f64.convert_i32_u (local.get $n))
local.set $n64
;; last = out + n * 16
(i32.add (local.get $out) (i32.mul (local.get $n) (i32.const 16)))
local.set $last
;; cur = out; while (true) {...}
(local.get $out)
local.set $cur
block $c-break loop $c-continue
;; if (cur >= last) break
(i32.ge_u (local.get $cur) (local.get $last))
br_if $c-break
;; mem[cur] = mem[cur] / n64
(local.get $cur)
(f64.div (f64.load align=8 (local.get $cur)) (local.get $n64))
f64.store align=8
;; cur = cur + 8 (next f64)
(i32.add (local.get $cur) (i32.const 8))
local.set $cur
br $c-continue
end end
)
)
<!doctype html>
<html>
<head>
<!--
  Demo page: the page shows only a hint, and the two module scripts below
  print their results to the browser's web console.
  Serve this directory over HTTP rather than opening it via file:// (browsers
  generally refuse to load module scripts from file:// URLs). WebAssembly's
  instantiateStreaming also requires .wasm to be served as application/wasm.
-->
<meta charset="utf-8">
<!-- empty icon, presumably to stop the browser from requesting /favicon.ico -->
<link rel="icon" href="data:">
<script type="module" src="./call-js-fft.js"></script>
<script type="module" src="./call-fft-browser.js"></script>
</head>
<body>
Check logs on webconsole
</body>
</html>
// ES module providing the "./math.js" imports of fft.wasm (PI, sin, cos).
// sin/cos are thin JS wrappers around the Math built-ins rather than the
// built-ins themselves.
export const PI = Math.PI;
export const sin = (rad) => Math.sin(rad);
export const cos = (rad) => Math.cos(rad);
// Helper that logs whatever arguments it receives, as one array;
// fft.wat only references it from a commented-out import.
export const debug = (...args) => {
  console.log(args);
};
// ES module providing the "./math.js" imports of fft.wasm (PI, sin, cos).
// sin/cos are re-exported as the very Math built-ins (no wrappers) so that
// the v8 option --wasm-math-intrinsics can recognise them.
const { PI, sin, cos } = Math;
export { PI, sin, cos };
// Helper that logs whatever arguments it receives, as one array;
// fft.wat only references it from a commented-out import.
export const debug = (...args) => {
  console.log(args);
};
{"type": "module"}
"""Static file server for this demo: serves the current directory on port 8000.

.wasm files are served as application/wasm, the Content-Type that
WebAssembly.instantiateStreaming requires.
"""
from http.server import SimpleHTTPRequestHandler
from socketserver import TCPServer

SimpleHTTPRequestHandler.extensions_map[".wasm"] = "application/wasm"
# Allow an immediate restart on the same port instead of failing with
# "Address already in use" while the previous socket lingers in TIME_WAIT.
TCPServer.allow_reuse_address = True

# The context manager closes the listening socket when serve_forever() exits
# (e.g. on Ctrl+C).
with TCPServer(("", 8000), SimpleHTTPRequestHandler) as httpd:
    httpd.serve_forever()
@bellbind
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment