# Source: GitHub gist by @Janmajayamall, created March 6, 2024
# (gist 7a92556cff58af81de4917f58d070110).
# Special FFT implementation (Algorithm 1 of https://eprint.iacr.org/2018/1043.pdf)
import copy
import math

import numpy as np
from numpy import complex128 as C
def get_bit_reverse_array(a: [C], n: int) -> [C]:
    """Return a new list with the first ``n`` entries of ``a`` permuted by bit reversal.

    Entry ``a[i]`` is stored at index ``bit_reverse_value(i, number_of_bits(n))``.
    The input sequence is not modified.

    Args:
        a: sequence with at least ``n`` elements (list or 1-D array).
        n: length of the permutation; must be a power of two.

    Returns:
        A new list of length ``n``. Any slot not written by the permutation
        keeps the placeholder ``C(0)``.

    Raises:
        AssertionError: if ``n`` is not a power of two.
    """
    assert is_power_of_two(n)
    # The bit width only depends on n, so compute it once instead of per element.
    # NOTE(review): presumably number_of_bits(n) == log2(n) for a power of two -- confirm.
    width = number_of_bits(n)
    # A single shared C(0) is fine here: complex scalars are immutable.
    r = [C(0)] * n
    for i in range(n):
        r[bit_reverse_value(i, width)] = a[i]
    return r
def find_mth_root_of_unity(m: int) -> C:
    """Return ``exp(2*pi*i / m)``, the principal primitive m-th root of unity.

    Args:
        m: order of the root of unity.

    Returns:
        The root as a NumPy complex128 scalar.
    """
    exponent = C(2 * math.pi * 1j) / C(m)
    return np.exp(exponent)
def get_psi_powers(m: int) -> [C]:
    """Return the powers ``psi**0 .. psi**(m-1)`` of ``psi = exp(2*pi*i / m)``, plus ``psi**0`` again.

    Each power is computed directly as ``exp(2*pi*i*k / m)``. Raising ``psi``
    to the k-th power instead lets NumPy fall back to a general complex power
    (via exp/log) for large exponents, where the rounding error grows roughly
    in proportion to ``k``.

    Args:
        m: order of the root of unity; must be positive.

    Returns:
        A list of ``m + 1`` complex128 values. The final entry repeats
        ``psi**0`` so that an index equal to ``m`` is still valid.
    """
    angles = 2 * np.pi * np.arange(m) / m
    psi_powers = list(np.exp(1j * angles))
    psi_powers.append(psi_powers[0])
    return psi_powers
def get_rot_group(N_half: int, M: int) -> [int]:
    """Return ``N_half`` successive powers of 5 reduced modulo ``M``.

    The list starts with 1. Each following entry is the previous one times 5,
    reduced mod ``M``. A non-positive ``N_half`` yields an empty list.
    """
    group = []
    power = 1
    while len(group) < N_half:
        group.append(power)
        power = (power * 5) % M
    return group
def specialFFT(a: [C], n: int, M: int) -> [C]:
    """Apply the "special FFT" (Algorithm 1 of eprint 2018/1043) to ``a``.

    The input is bit-reverse permuted with get_bit_reverse_array. It then
    passes through log2(n) unnormalised radix-2 butterfly stages. In the
    stage of length ``L`` the ``j``-th butterfly multiplies by
    ``psi ** ((rot_group[j] % (4L)) * (M // (4L)))``, where
    ``psi = exp(2*pi*i / M)`` and ``rot_group[j] = 5**j mod M``.

    Args:
        a: sequence of ``n`` values (list or 1-D array); it is not modified.
        n: transform length; must be a power of two.
        M: order of the root of unity used for the twiddles; must be a
           positive multiple of ``4 * n`` so every ``M // (4L)`` is exact.

    Returns:
        A new list of ``n`` transformed values.

    Raises:
        AssertionError: if ``len(a) != n``, ``n`` is not a power of two, or
            ``M`` is not a positive multiple of ``4 * n``.
    """
    assert len(a) == n
    assert is_power_of_two(n)
    # Without this, `M // lenq` truncates (e.g. to 0 when n == M / 2) and the
    # result is silently wrong instead of failing.
    assert M > 0 and M % (4 * n) == 0, "specialFFT requires M to be a positive multiple of 4 * n"
    # get_bit_reverse_array builds a fresh list, so the caller's `a` is untouched.
    a = get_bit_reverse_array(a, n)
    psi_powers = get_psi_powers(M)
    rot_group = get_rot_group(M >> 2, M)
    length_n = 2
    while length_n <= n:
        # Everything below depends only on the stage, not on the block offset i.
        lenh = length_n >> 1
        lenq = length_n << 2
        gap = M // lenq
        twiddles = [psi_powers[(rot_group[j] % lenq) * gap] for j in range(lenh)]
        for i in range(0, n, length_n):
            for j, w in enumerate(twiddles):
                u = a[i + j]
                v = a[i + j + lenh] * w
                a[i + j] = u + v
                a[i + j + lenh] = u - v
        length_n <<= 1
    return a
def specialIFFT(a: [C], n: int, M: int) -> [C]:
    """Invert ``specialFFT``: undo the butterfly stages, the permutation and the scaling.

    The stages run from length ``n`` down to 2. Each uses the twiddle
    ``psi ** ((lenq - rot_group[j] % lenq) * gap)``, which equals the
    conjugate of ``specialFFT``'s ``psi ** ((rot_group[j] % lenq) * gap)``
    because ``lenq * gap == M``. The result is then bit-reverse permuted and
    divided by ``n``.

    Args:
        a: sequence of ``n`` values (list or 1-D array); it is not modified.
        n: transform length; must be a power of two.
        M: order of the root of unity used for the twiddles; must be a
           positive multiple of ``4 * n`` so every ``M // lenq`` is exact.

    Returns:
        A new list of ``n`` complex128 values.

    Raises:
        AssertionError: if ``len(a) != n``, ``n`` is not a power of two, or
            ``M`` is not a positive multiple of ``4 * n``.
    """
    assert len(a) == n
    assert is_power_of_two(n)
    # Without this, `M // lenq` truncates (e.g. to 0 when n == M / 2) and the
    # result is silently wrong instead of failing.
    assert M > 0 and M % (4 * n) == 0, "specialIFFT requires M to be a positive multiple of 4 * n"
    # Make a private list of complex values. A shallow copy suffices for scalar
    # entries. The conversion matters for real-dtype ndarray input:
    # deep-copying such an array kept its float dtype, so storing complex
    # butterfly outputs back into it silently dropped the imaginary parts.
    a = [C(x) for x in a]
    psi_powers = get_psi_powers(M)
    rot_group = get_rot_group(M >> 2, M)
    length_n = n
    # A stage of length 1 has no butterflies (lenh == 0), so stop at 2.
    while length_n >= 2:
        # Everything below depends only on the stage, not on the block offset i.
        lenh = length_n >> 1
        lenq = length_n << 2
        gap = M // lenq
        twiddles = [psi_powers[(lenq - rot_group[j] % lenq) * gap] for j in range(lenh)]
        for i in range(0, n, length_n):
            for j, w in enumerate(twiddles):
                x = a[i + j]
                y = a[i + j + lenh]
                a[i + j] = x + y
                a[i + j + lenh] = (x - y) * w
        length_n >>= 1
    a = get_bit_reverse_array(a, n)
    # The butterflies are unnormalised, so the inverse carries a factor of n.
    return [x / n for x in a]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment