Created
March 6, 2024 16:37
-
-
Save Janmajayamall/7a92556cff58af81de4917f58d070110 to your computer and use it in GitHub Desktop.
Special FFT implementation (Algorithm 1 of https://eprint.iacr.org/2018/1043.pdf)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import copy
import math

import numpy as np
from numpy import complex128 as C
def get_bit_reverse_array(a: [C], n: int) -> [C]:
    """Return the first ``n`` entries of ``a`` permuted into bit-reversed order.

    Element ``a[i]`` is stored at the index obtained by reversing the low
    ``log2(n)`` bits of ``i``. Bit reversal is its own inverse, so applying
    this function twice restores the original order.

    :param a: sequence with at least ``n`` elements; it is not modified.
    :param n: permutation size; must be a positive power of two.
    :return: a new list of length ``n``.
    """
    assert n > 0 and (n & (n - 1)) == 0, "n must be a positive power of two"
    # log2(n) bits index an n-element array; computed once, not per element.
    width = n.bit_length() - 1
    r = [C(0)] * n  # every slot is overwritten below (the map is a permutation)
    for i in range(n):
        # Zero-pad i to `width` binary digits, reverse them, and parse back.
        rev_i = int(format(i, "b").zfill(width)[::-1], 2)
        r[rev_i] = a[i]
    return r
def find_mth_root_of_unity(m: int) -> C:
    """Return exp(2*pi*i / m), the principal primitive m-th root of unity.

    The value is computed with numpy and returned as a ``complex128`` scalar.
    """
    full_turn = C(2 * math.pi * 1j)  # 2*pi*i as a numpy complex scalar
    return np.exp(full_turn / C(m))
def get_psi_powers(m: int) -> [C]:
    """Tabulate psi**k for k = 0 .. m, where psi is the primitive m-th root of unity.

    The table has m + 1 entries. The last entry repeats psi**0, so index m
    (one full turn) is a valid lookup.
    """
    root = find_mth_root_of_unity(m)
    table = [root ** k for k in range(m)]
    table.append(table[0])  # wrap-around entry: index m maps back to psi**0
    return table
def get_rot_group(N_half: int, M: int) -> [int]:
    """Return the first ``N_half`` powers of 5, reduced modulo ``M``.

    The first entry is 1 as-is. Each later entry is ``(previous * 5) % M``.
    """
    group = []
    power_of_five = 1
    for _ in range(N_half):
        group.append(power_of_five)
        power_of_five = (power_of_five * 5) % M
    return group
def specialFFT(a: [C], n: int, M: int) -> [C]:
    """Apply the special FFT butterfly network to ``a``.

    This follows Algorithm 1 of https://eprint.iacr.org/2018/1043.pdf, as cited
    in this gist's title. The input is first permuted into bit-reversed order.
    Then, for each stage length ``len = 2, 4, ..., n``, every pair
    ``(a[i+j], a[i+j+len/2])`` is replaced by ``u + v*w`` and ``u - v*w``, where
    ``w = psi**idx``, ``psi = exp(2*pi*i/M)`` and
    ``idx = (rot_group[j] % (4*len)) * (M // (4*len))`` with
    ``rot_group[j] = 5**j mod M``.

    :param a: sequence of exactly ``n`` complex values; it is not modified.
    :param n: transform size; must be a positive power of two.
    :param M: order of the root of unity ``psi``. The twiddle indexing assumes
        ``4 * n`` divides ``M``. Otherwise ``M // (4*len)`` truncates (to 0 when
        ``4*len > M``) and the twiddles are wrong.
        NOTE(review): this precondition is not checked here.
    :return: a new list of ``n`` transformed values.
    """
    assert len(a) == n
    assert n > 0 and (n & (n - 1)) == 0, "n must be a positive power of two"
    a = get_bit_reverse_array(a, n)  # fresh list; the caller's input is untouched
    psi_powers = get_psi_powers(M)
    rot_group = get_rot_group(M >> 2, M)
    length_n = 2
    while length_n <= n:
        # These depend only on the stage length, not on the block offset i.
        lenh = length_n >> 1
        lenq = length_n << 2
        gap = M // lenq
        for i in range(0, n, length_n):
            for j in range(lenh):
                idx = (rot_group[j] % lenq) * gap
                u = a[i + j]
                v = a[i + j + lenh] * psi_powers[idx]
                a[i + j] = u + v
                a[i + j + lenh] = u - v
        length_n <<= 1
    return a
def specialIFFT(a: [C], n: int, M: int) -> [C]:
    """Apply the reverse butterfly network of ``specialFFT``, then scale by 1/n.

    For each stage length ``len = n, n/2, ..., 2``, every pair
    ``(x, y) = (a[i+j], a[i+j+len/2])`` is replaced by ``x + y`` and
    ``(x - y) * psi**idx``, where ``psi = exp(2*pi*i/M)`` and
    ``idx = (4*len - rot_group[j] % (4*len)) * (M // (4*len))``.
    The result is then permuted into bit-reversed order and every entry is
    divided by ``n``. This is intended as the inverse of ``specialFFT`` with
    the same ``n`` and ``M``.

    :param a: sequence of exactly ``n`` complex values; it is not modified.
    :param n: transform size; must be a positive power of two.
    :param M: order of the root of unity ``psi``. The same ``4 * n``-divides-``M``
        assumption as in ``specialFFT`` applies.
        NOTE(review): this precondition is not checked here.
    :return: a new list of ``n`` values.
    """
    assert len(a) == n
    assert n > 0 and (n & (n - 1)) == 0, "n must be a positive power of two"
    # Copy into a plain list. The elements are immutable scalars, so a shallow
    # copy is enough. A list also cannot truncate the complex values written
    # below, which a real-dtype numpy array (kept by deepcopy) would.
    a = list(a)
    psi_powers = get_psi_powers(M)
    rot_group = get_rot_group(M >> 2, M)
    length_n = n
    # Stop at 2: a stage with length_n == 1 has lenh == 0 and does no work.
    while length_n >= 2:
        # These depend only on the stage length, not on the block offset i.
        lenh = length_n >> 1
        lenq = length_n << 2
        gap = M // lenq
        for i in range(0, n, length_n):
            for j in range(lenh):
                # Negated rotation index, i.e. the conjugate of the forward twiddle.
                idx = (lenq - rot_group[j] % lenq) * gap
                x = a[i + j]
                y = a[i + j + lenh]
                a[i + j] = x + y
                a[i + j + lenh] = (x - y) * psi_powers[idx]
        length_n >>= 1
    a = get_bit_reverse_array(a, n)
    # Multiply by 1/n and return.
    return [value / n for value in a]
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment