Skip to content

Instantly share code, notes, and snippets.

@malb

malb/fft.py

Last active Mar 14, 2016
Embed
What would you like to do?
A simple FFT in Sage for testing
"""
Simple FFT for testing
"""
from sage.all import vector, srange, copy, log, floor
from sage.misc.misc import cputime
from sage.rings.all import ZZ, next_prime, GF, PolynomialRing
def fft2(f, w, n):
    """Iterative radix-2 FFT of the first ``n`` coefficients of ``f``.

    Returns the polynomial whose ``i``-th coefficient is
    ``sum(f[j] * w**(i*j) for j in range(n))``, i.e. ``f`` truncated to its
    first ``n`` coefficients and evaluated at ``w**i`` -- the same values
    ``slowft`` obtains by direct evaluation when ``deg f < n``.

    The input is loaded in bit-reversed order and ``log2(n)`` stages of
    butterflies are applied; every stage reads the pairs ``(a[2j], a[2j+1])``
    and writes to positions ``j`` and ``j + n/2``.

    :param f: polynomial whose coefficients ``f[0], ..., f[n-1]`` are
        transformed (coefficients beyond the degree read as zero)
    :param w: evaluation root, normally a primitive ``n``-th root of unity
    :param n: transform length; must be a power of two
    :returns: the transformed values packed as coefficients in ``f.parent()``
    :rtype: element of ``f.parent()``
    :raises ValueError: if ``n`` is not a positive power of two
    """
    P = f.parent()
    n = ZZ(n)
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    if n == 1:
        # A length-1 transform is the identity on the constant term; the
        # butterfly loop below would run zero times and return 0 instead.
        return P([f[0]])
    k = log(n, ZZ(2))
    # Integer division: n/2 would be a Rational (Sage) or float (Python 3)
    # and cannot be used as a list index.
    half = n // 2
    # Load the input in bit-reversed index order.
    a = [f[ZZ(i.digits(2, padto=k)[::-1], 2)] for i in srange(n)]
    for i in range(k):
        # In stage i the twiddle exponent for butterfly j is j rounded down
        # to a multiple of 2**(k-1-i).
        step = 2**(k - 1 - i)
        A = [0] * n
        for j in srange(half):
            t = a[2*j + 1] * w**((j // step) * step)
            A[j] = a[2*j] + t
            A[j + half] = a[2*j] - t
        a = A
    return P(a)
def fft(f, w, n):
    """Radix-2 FFT: one decimation-in-time split on top of ``fft2``.

    The first ``n`` coefficients of ``f`` are split into even- and
    odd-indexed halves, each half is transformed with ``fft2`` using the
    root ``w**2``, and the two half-length transforms are combined with
    butterflies.  The result should equal ``fft2(f, w, n)``.

    :param f: polynomial whose first ``n`` coefficients are transformed
    :param w: evaluation root, normally a primitive ``n``-th root of unity
    :param n: transform length; must be a power of two
    :returns: the transformed values packed as coefficients in ``f.parent()``
    :rtype: element of ``f.parent()``
    :raises ValueError: if ``n`` is not a positive power of two
    """
    P = f.parent()
    n = ZZ(n)
    if n < 1 or n & (n - 1):
        raise ValueError("n must be a power of two")
    if n == 1:
        # Length-1 transform: keep only the constant term, consistent with
        # the truncation to the first n coefficients used for larger n.
        return P([f[0]])
    # Integer division keeps indices and range bounds integral on both
    # Python 2 and Python 3 and for Sage Integers.
    half = n // 2
    a0 = P([f[i] for i in range(0, n, 2)])
    a1 = P([f[i] for i in range(1, n, 2)])
    y0 = fft2(a0, w**2, half)
    y1 = fft2(a1, w**2, half)
    y = [0] * n
    wacc = 1  # running twiddle factor w**j
    for j in range(half):
        t = wacc * y1[j]
        y[j] = y0[j] + t
        y[j + half] = y0[j] - t
        wacc = wacc * w
    return P(y)
def mul_fft(f, g, w, n, fft=fft2):
    """Multiply f and g modulo ``x**n + 1`` using a negacyclic FFT.

    Both inputs are twisted coefficient-wise by powers of ``phi = sqrt(w)``,
    so that a length-``n`` cyclic transform yields the negacyclic product.
    The product is formed pointwise in the transform domain, inverted, and
    untwisted.

    :param f: polynomial, assumed to have degree < n
    :param g: polynomial, assumed to have degree < n
    :param w: ``n``-th root of unity having a square root in the base ring.
        ``phi = w.sqrt()`` must be a primitive ``2n``-th root of unity.  This
        holds when ``w`` is a primitive ``n``-th root and ``n`` is a power of
        two.
    :param n: transform length
    :param fft: transform with signature ``fft(f, w, n)``; defaults to ``fft2``
    :returns: ``f*g`` reduced modulo ``x**n + 1``
    :rtype: element of ``f.parent()``
    """
    P = f.parent()
    phi = w.sqrt()
    # Diagnostic output; same text as the original Python 2 print statement.
    print("w %s phi %s" % (w, phi))
    f = P([c*phi**i for i, c in enumerate(f.list())])
    g = P([c*phi**i for i, c in enumerate(g.list())])
    F = fft(f, w, n)
    G = fft(g, w, n)
    # padded_list(n): .list() drops trailing zero coefficients, so a
    # transform that vanishes at the last points would give vectors of
    # different lengths and break the pointwise product.
    H = P([a*b for a, b in zip(F.padded_list(n), G.padded_list(n))])
    # Inverse transform = forward transform with w**-1, scaled by 1/n.
    h = fft(H, w**-1, n)//P(n)
    h = P([c*phi**-i for i, c in enumerate(h.list())])
    return h
def slowft(f, w, n):
    """Naive (quadratic-time) discrete Fourier transform.

    Evaluates ``f`` directly at ``w**0, w**1, ..., w**(n-1)``.

    :param f: polynomial to transform
    :param w: evaluation base, normally a primitive ``n``-th root of unity
    :param n: number of evaluation points
    :returns: polynomial whose ``e``-th coefficient is ``f(w**e)``
    :rtype: element of ``f.parent()``
    """
    ring = f.parent()
    values = []
    for e in range(n):
        values.append(f(w**e))
    return ring(values)
def islowft(f, w, n):
    """Naive (quadratic-time) inverse discrete Fourier transform.

    Evaluates ``f`` at ``w**0, w**-1, ..., w**-(n-1)`` and divides every
    value by ``n`` taken in the base ring.

    :param f: transformed polynomial (values stored as coefficients)
    :param w: the root used for the forward transform
    :param n: transform length; must be invertible in the base ring
    :returns: polynomial whose ``e``-th coefficient is ``f(w**-e) / n``
    :rtype: element of ``f.parent()``
    """
    ring = f.parent()
    scale = ring.base_ring()(n)
    values = []
    for e in range(n):
        values.append(f(w**-e) / scale)
    return ring(values)
def mul_ft(f, g, w, n):
    """Multiply f and g via the naive Fourier transform.

    Computes the cyclic product ``f*g mod x**n - 1``, assuming ``w`` is a
    primitive ``n``-th root of unity.  This equals ``f*g`` itself whenever
    ``deg f + deg g < n``.

    :param f: first factor
    :param g: second factor
    :param w: primitive ``n``-th root of unity in the base ring
    :param n: transform length; must be invertible in the base ring
    :returns: the cyclic product of ``f`` and ``g``
    :rtype: element of ``f.parent()``
    """
    P = f.parent()
    F = slowft(f, w, n)
    G = slowft(g, w, n)
    # padded_list(n): .list() drops trailing zero coefficients, so an
    # evaluation that is zero at the last points would misalign the
    # pointwise product (and make vector.pairwise_product raise).
    H = P([a*b for a, b in zip(F.padded_list(n), G.padded_list(n))])
    return islowft(H, w, n)
def test_it(n):
    """Cross-check the transform-based multiplications against Sage.

    Finds the first prime q after 2 with q = 1 (mod 2n), draws two random
    polynomials of degree at most n - 1 over GF(q), and asserts that

    * ``mul_ft`` with a length-2n transform equals ``f*g``,
    * ``fft`` and ``fft2`` agree, and
    * ``mul_fft`` with a length-n transform equals ``f*g mod x**n + 1``.

    The CPU time of each multiplication is printed.

    :param n: FFT length (a power of two)
    """
    m = 2*n
    # GF(q)* is cyclic of order q - 1, so it has primitive m-th roots of
    # unity exactly when m divides q - 1.
    q = 2
    while True:
        q = next_prime(q)
        if q % m == 1:
            break
    K = GF(q)
    # NOTE(review): relies on nth_root of 1 returning a *primitive* root;
    # a trivial root would make the transforms singular.
    w = K(1).nth_root(m)
    P = PolynomialRing(K, 'x')
    f = P.random_element(degree=n-1)
    g = P.random_element(degree=n-1)
    t = cputime()
    r = mul_ft(f, g, w, m)
    # Single-argument print: same output on Python 2 and Python 3.
    print("Fourier time %s" % cputime(t))
    assert r == f*g
    w = K(1).nth_root(n)
    t = cputime()
    r = mul_fft(f, g, w, n, fft=fft2)
    print("FFT time %s" % cputime(t))
    assert fft(f, w, n) == fft2(f, w, n)
    assert r == ((f*g) % (P.gen()**n + 1))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.