Skip to content

Instantly share code, notes, and snippets.

@tjkendev
Last active January 3, 2017 16:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tjkendev/16a0c0fe5e5dca811ac0171f3491ff62 to your computer and use it in GitHub Desktop.
Save tjkendev/16a0c0fe5e5dca811ac0171f3491ff62 to your computer and use it in GitHub Desktop.
FMT (Fast Modulo Transform, 高速剰余変換)
# encoding: utf-8
# うまくFMTを計算するためのパラメタを決定する
# nを決めて、うまいPとωを決定する
# Pは素数であること、ωのn回の積で循環すること (ω^n ≡ 1 (mod P))が条件
from math import sqrt

# Transform length the parameters are searched for.
n = 2**18


def is_prime(p):
    """Return True if ``p`` is prime, by trial division up to sqrt(p)."""
    if p < 2:
        return False
    if p % 2 == 0:
        return p == 2
    # Only odd divisors remain to be checked once 2 has been ruled out.
    return all(p % d for d in range(3, int(sqrt(p)) + 1, 2))


def power_of_two_order(w, p, limit):
    """Return the multiplicative order of ``w`` mod ``p`` when it is a power
    of two not exceeding ``limit``; return 0 otherwise.

    Repeated squaring visits w^1, w^2, w^4, ...  The first exponent 2^k with
    w^(2^k) == 1 (mod p) is the order itself when the order is a power of
    two; when it is not, no power-of-two exponent ever yields 1.
    """
    order = 1
    cur = w % p
    while order <= limit:
        if cur == 1:
            return order
        cur = (cur * cur) % p
        order *= 2
    return 0


# For every prime P = x*n + 1, report each small base w whose multiplicative
# order mod P is a power of two <= n.  A trailing " *" marks bases whose order
# is exactly n, i.e. usable as an n-th root of unity for the transform.
for x in range(1000, 2000):
    P = x * n + 1
    if not is_prime(P):
        continue
    for w in range(2, 200):
        order = power_of_two_order(w, P, n)
        if order:
            # The last printed field is w^order mod P, which is 1 by construction.
            print("%d, %d*n+1: %d, %d, %d" % (w, x, order, n, 1) + " *" * (order == n))
# encoding: utf-8
# FMTは、複素数で演算を行うFFTを整数環上で行うものである。
# ωとしてω^n ≡ 1 (mod P)となるω、P、nを用いる。
#
# ここでは、
# f(x) = a_{n-1}*x^{n-1} + ... + a_1*x + a_0とする。
#
# - 順変換
# 0 ≦ k ≦ n-1について、
# f_k = f(ω^k) (mod P) = \sum_{i=0}^{n-1} a_i*ω^{ik} (mod P)
# を計算
#
# - 逆変換
# 0 ≦ i ≦ n-1について
# a_i = n^{-1} * \sum_{k=0}^{n-1} f_k*ω^{-ik} (mod P)
# を計算
# ========================================
# 計算に必要なパラメタ
# ω^n ≡ 1 (mod P) となるようなω, n, Pを選ぶ
# Pは素数, nは2^mが楽
# Transform length (a power of two).
n = 1 << 18
# Modulus of the form c*n + 1; all arithmetic below is done modulo P.
P = 1048 * n + 1
# Base the transforms use as the n-th root of unity
# (the notes above require omega^n == 1 (mod P)).
omega = 55
# omega^(P-2) mod P: the modular inverse of omega, given that P is prime
# as the notes above require (Fermat's little theorem).
rev = pow(omega, P - 2, P)
# ========================================
# 愚直なO(N^2)の整数環DFT
# L個の個数制限をつけてちょっと早くしてる
def naive_dft(f, l=None):
    """Compute the forward transform over Z/P by direct summation.

    Uses the module-level ``n``, ``omega`` and ``P``.

    Args:
        f: coefficient sequence; indices ``0 .. l-1`` are read.
        l: number of leading coefficients to include.  ``None`` (or 0)
            means all ``n`` of them.

    Returns:
        A new list ``F`` of length ``n`` with
        ``F[i] = sum_{j < l} f[j] * omega**(i*j) (mod P)``.
    """
    count = l or n
    result = [0] * n
    for i in range(n):
        root = pow(omega, i, P)  # omega^i, the evaluation point for F[i]
        power = 1                # root^j, updated incrementally
        acc = 0
        for j in range(count):
            acc = (acc + power * f[j]) % P
            power = (root * power) % P
        result[i] = acc
    return result
def naive_idft(F, l=None):
    """Compute the inverse transform over Z/P by direct summation.

    Uses the module-level ``n``, ``rev`` and ``P``.

    Args:
        F: transformed sequence; indices ``0 .. n-1`` are read.
        l: number of leading output coefficients to compute.  ``None``
            (or 0) means all ``n``.  Entries from ``l`` onward stay 0.

    Returns:
        A new list ``f`` of length ``n`` with
        ``f[i] = n^{-1} * sum_j F[j] * rev**(i*j) (mod P)`` for ``i < l``.
    """
    count = l or n
    # n^{-1} mod P does not depend on i, so compute it once.
    n_inv = pow(n, P - 2, P)
    result = [0] * n
    for i in range(count):
        root = pow(rev, i, P)  # omega^{-i}
        power = 1              # root^j, updated incrementally
        acc = 0
        for j in range(n):
            acc = (acc + power * F[j]) % P
            power = (root * power) % P
        result[i] = (acc * n_inv) % P
    return result
# ========================================
# O(NlogN)の整数FMT
# 再帰的に求める。LIMで処理を少し早くしてる
def fmt_dfs(A, s, N, st, base, half, lim):
    """Recursively transform the strided sub-sequence A[s], A[s+st], ...

    Arithmetic is done modulo the module-level ``P``.

    Args:
        A: full input sequence (not modified).
        s: index of the first element of this sub-sequence.
        N: number of elements in this sub-sequence.
        st: stride between consecutive elements of the sub-sequence.
        base: root of unity for this level; it is squared for the next level.
        half: factor applied to the odd part when filling the upper half of
            the output.  The callers in this file pass ``base**(n/2) mod P``
            computed for the top level.
        lim: sub-sequences starting at ``s >= lim`` are not read and yield
            all zeros (presumably because the caller guarantees
            ``A[i] == 0`` for ``i >= lim`` -- confirm at call sites).
            The ``N == 2`` base case does not apply this cut-off.

    Returns:
        A new list of length ``N`` holding the transform of the sub-sequence.
    """
    if N == 2:
        first = A[s]
        second = A[s + st]
        return [(first + second) % P, (first + second * base) % P]

    out = [0] * N
    if s >= lim:
        # Pruned branch: nothing is read, the result stays all zeros.
        return out

    half_len = N >> 1
    next_stride = st << 1
    next_base = pow(base, 2, P)
    even = fmt_dfs(A, s, half_len, next_stride, next_base, half, lim)
    odd = fmt_dfs(A, s + st, half_len, next_stride, next_base, half, lim)

    twiddle = 1  # base^k
    for k in range(half_len):
        u = even[k]
        v = odd[k] * twiddle
        out[k] = (u + v) % P
        out[k + half_len] = (u + v * half) % P
        twiddle = (twiddle * base) % P
    return out
def fmt(f, l):
    """Forward transform of ``f`` using the recursive ``fmt_dfs``.

    Uses the module-level ``n``, ``omega`` and ``P``.

    Args:
        f: input sequence; passed to ``fmt_dfs`` as a length-``n`` input.
        l: cut-off handed to ``fmt_dfs`` as ``lim``.  When ``l == 1`` the
            input is returned as-is, without being transformed or copied.

    Returns:
        The list produced by ``fmt_dfs`` (or ``f`` itself when ``l == 1``).

    NOTE(review): a later ``def fmt`` in this file rebinds this name.
    """
    if l == 1:
        return f
    # Floor division keeps the exponent an int on Python 3 as well;
    # pow() with a modulus rejects a float exponent.
    return fmt_dfs(f, 0, n, 1, omega, pow(omega, n // 2, P), l)
def ifmt(F, l):
    """Inverse transform of ``F`` using the recursive ``fmt_dfs``.

    Uses the module-level ``n``, ``rev`` and ``P``.  The recursion runs with
    ``rev`` as the root and ``lim = n`` (no pruning); every entry is then
    multiplied by ``n^{-1} mod P``.

    Args:
        F: transformed sequence; passed to ``fmt_dfs`` as a length-``n`` input.
        l: when ``l == 1`` the input is returned as-is, without being
            transformed or copied; otherwise unused.

    Returns:
        A new list of length ``n`` (or ``F`` itself when ``l == 1``).

    NOTE(review): a later ``def ifmt`` in this file rebinds this name.
    """
    if l == 1:
        return F
    # Floor division keeps the exponent an int on Python 3 as well.
    raw = fmt_dfs(F, 0, n, 1, rev, pow(rev, n // 2, P), n)
    n_inv = pow(n, P - 2, P)
    return [(value * n_inv) % P for value in raw]
# ========================================
# O(NlogN)の整数FMT
# bit反転を利用して、ボトムアップにループでFMTを行う
# ATC001 - C問題: 高速フーリエ変換 (AC): http://atc001.contest.atcoder.jp/submissions/1051678
# 配列要素のbit反転
def bit_reverse(d):
    """Reorder ``d`` in place into bit-reversed index order and return it.

    The length of ``d`` is expected to be a power of two (the original note
    reads: X & (X-1) == 0  -->  X = 2^M).

    Args:
        d: mutable sequence to permute.

    Returns:
        The same object ``d``, after the in-place permutation.
    """
    size = len(d)
    half = size >> 1
    quarter = half >> 1
    half_plus_one = half + 1

    i = 0  # bit-reversed counterpart of j
    for j in range(0, half, 2):
        if j < i:
            # Swap the pair (i, j) and its mirror with the top and bottom
            # bits both set; j < i ensures each pair is swapped once.
            d[i], d[j] = d[j], d[i]
            d[i + half_plus_one], d[j + half_plus_one] = (
                d[j + half_plus_one], d[i + half_plus_one])
        # j + half (top bit set) and i + 1 (bottom bit set) are always
        # bit-reversals of each other.
        d[i + 1], d[j + half] = d[j + half], d[i + 1]

        # Advance i as a reversed-bit counter: flip bits from the high end
        # downward until a flip turns a 0 into a 1.
        k = quarter
        i ^= k
        while k > i:
            k >>= 1
            i ^= k
    return d
# ボトムアップのFMTを行う
def fmt_bu(A, n, base, half, Q):
    """Iterative (bottom-up) butterfly passes over ``A``, in place, modulo ``Q``.

    The callers in this file run ``bit_reverse`` on ``A`` before calling this.

    Args:
        A: mutable sequence of length ``n``; overwritten with the result.
        n: transform length.
        base: root of unity; each stage uses ``base**(n / (2*m)) mod Q`` as
            its twiddle step, where ``m`` is the current half-block size.
        half: factor applied to the twiddled term for the upper output of
            each butterfly (callers pass ``base**(n/2) mod Q``).
        Q: modulus.

    Returns:
        The same object ``A``.
    """
    remaining = n  # halves every stage; also the exponent for the stage root
    m = 1          # half-size of the blocks combined in the current stage
    while remaining > 1:
        remaining >>= 1
        step = pow(base, remaining, Q)
        twiddle = 1  # step^j
        for j in range(m):
            # Every butterfly at offset j within its block shares one twiddle.
            for i in range(j, n, 2 * m):
                u = A[i]
                v = (A[i + m] * twiddle) % Q
                A[i] = (u + v) % Q
                A[i + m] = (u + v * half) % Q
            twiddle = (twiddle * step) % Q
        m <<= 1
    return A
def fmt(f, l, Q=P):
    """Forward transform of ``f`` using the bottom-up ``fmt_bu``.

    Works on a bit-reversed copy of ``f``; ``f`` itself is not modified.
    The transform length ``n`` and root ``omega`` are the module-level
    values regardless of ``Q``.

    Args:
        f: input list.
        l: when ``l == 1`` the input is returned as-is, without being
            transformed or copied; otherwise unused.
        Q: modulus (defaults to the module-level ``P``).

    Returns:
        A new transformed list (or ``f`` itself when ``l == 1``).
    """
    if l == 1:
        return f
    work = f[:]
    bit_reverse(work)
    # Floor division keeps the exponent an int on Python 3 as well;
    # pow() with a modulus rejects a float exponent.
    return fmt_bu(work, n, omega, pow(omega, n // 2, Q), Q)
def ifmt(F, l, Q=P):
    """Inverse transform of ``F`` using the bottom-up ``fmt_bu``.

    Works on a bit-reversed copy of ``F``, runs the butterflies with the
    module-level ``rev`` as root, then multiplies every entry by
    ``n^{-1} mod Q`` (computed as ``n**(Q-2) mod Q``).

    Args:
        F: transformed list.
        l: when ``l == 1`` the input is returned as-is, without being
            transformed or copied; otherwise unused.
        Q: modulus (defaults to the module-level ``P``).

    Returns:
        A new list (or ``F`` itself when ``l == 1``).
    """
    if l == 1:
        return F
    work = F[:]
    bit_reverse(work)
    # Floor division keeps the exponent an int on Python 3 as well.
    raw = fmt_bu(work, n, rev, pow(rev, n // 2, Q), Q)
    n_inv = pow(n, Q - 2, Q)
    return [(value * n_inv) % Q for value in raw]
def convolute(a, b, l, Q=P):
    """Multiply ``a`` and ``b`` in the transform domain and transform back.

    Both inputs go through ``fmt``, the results are multiplied element-wise
    modulo ``Q``, and the product goes through ``ifmt``.

    Args:
        a: first input list.
        b: second input list.
        l: forwarded unchanged to ``fmt`` and ``ifmt``.
        Q: modulus (defaults to the module-level ``P``).

    Returns:
        The list returned by ``ifmt``.
    """
    spectrum_a = fmt(a, l, Q)
    spectrum_b = fmt(b, l, Q)
    pointwise = [(x * y) % Q for x, y in zip(spectrum_a, spectrum_b)]
    return ifmt(pointwise, l, Q)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment