Last active
October 18, 2019 07:39
-
-
Save Fugoes/7541704089c931de71fe740eb5bc9362 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef COMBINATORICS_NTT_HPP | |
#define COMBINATORICS_NTT_HPP | |
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <utility>
namespace ntt { | |
// Residue class modulo the compile-time constant P.
//
// The stored representative is kept in [0, P) by every operation except
// from_small(), which trusts its caller. Products are formed in 64 bits so
// they cannot overflow; sums rely on P < UINT32_MAX / 2 to stay in 32 bits.
template<uint32_t P>
class Mod {
    static_assert(P > 0, "P must be positive");
    static_assert(P < (UINT32_MAX / 2), "P too large");

public:
    // Constructs the zero residue.
    Mod() = default;

    // Builds a residue from an arbitrary 32-bit value, reducing it modulo P.
    static Mod<P> from(uint32_t x) {
        return Mod<P>(x % P);
    }

    // Builds a residue without reducing. The caller must guarantee x < P;
    // otherwise the [0, P) invariant used by operator+ / operator- is broken.
    static Mod<P> from_small(uint32_t x) {
        return Mod<P>(x);
    }

    // Returns this residue raised to `exp` by binary exponentiation.
    // pow(0) yields the residue of 1, i.e. 1 % P (which is 0 when P == 1).
    Mod<P> pow(uint32_t exp) const {
        uint64_t result = 1 % P;  // reduced, so the invariant also holds for P == 1
        uint64_t base = _value;
        while (exp != 0) {
            if ((exp & 1u) != 0) result = (result * base) % P;
            base = (base * base) % P;
            exp >>= 1u;
        }
        return Mod<P>(static_cast<uint32_t>(result));
    }

    // Modular sum; both operands are < P, so the 32-bit sum cannot overflow.
    Mod<P> operator+(Mod<P> y) const {
        return Mod<P>((_value + y._value) % P);
    }

    // Modular difference; P is added first so the subtraction never goes negative.
    Mod<P> operator-(Mod<P> y) const {
        return Mod<P>((_value + P - y._value) % P);
    }

    // Modular product, computed through a 64-bit intermediate.
    Mod<P> operator*(Mod<P> y) const {
        const uint64_t a = _value;
        const uint64_t b = y._value;
        return Mod<P>(static_cast<uint32_t>((a * b) % P));
    }

    // Returns this residue raised to P - 2. By Fermat's little theorem that is
    // the multiplicative inverse when P is prime and the residue is non-zero;
    // for any other input the result is not an inverse.
    Mod<P> inv() const {
        return pow(P - 2);
    }

    // Two residues are equal when their stored representatives are equal.
    bool operator==(Mod<P> y) const {
        return _value == y._value;
    }

    // Negation of operator==.
    bool operator!=(Mod<P> y) const {
        return _value != y._value;
    }

    // Returns the stored representative.
    uint32_t val() const {
        return _value;
    }

private:
    uint32_t _value{0};

    // Wraps a raw representative; callers inside the class pass values < P
    // (except from_small, which forwards its argument untouched).
    explicit Mod(uint32_t x) : _value(x) {}
};
// P = K * 2**R + 1 | |
template<uint32_t P, uint32_t K, uint32_t R, uint32_t G> | |
class NTT { | |
private: | |
static const Mod<P> *W_pow_shift; | |
static inline uint32_t reverse_next(uint32_t n_shift, uint32_t i) { | |
for (uint32_t h = (1u << (n_shift - 1)); h > 0; h >>= 1u) | |
if ((i & h) == 0) return i | h; else i ^= h; | |
return i; | |
} | |
static inline void reverse(uint32_t n_shift, Mod<P> *xs) { | |
for (uint32_t i = 0, j = 0; i < (1u << n_shift); i++) { | |
if (i > j) std::swap(xs[i], xs[j]); | |
j = reverse_next(n_shift, j); | |
} | |
} | |
static Mod<P> *gen_W_pow_shift() noexcept { | |
static Mod<P> rs[R + 1]; | |
rs[0] = Mod<P>::from_small(G).pow(K); | |
for (uint32_t i = 1; i <= R; i++) { | |
rs[i] = rs[i - 1] * rs[i - 1]; | |
} | |
return rs; | |
} | |
public: | |
static void apply(uint32_t n_shift, Mod<P> *xs) { | |
reverse(n_shift, xs); | |
for (uint32_t s = 0; s < n_shift; s++) { | |
uint32_t b = (1u << s); | |
Mod<P> wn = W_pow_shift[R - 1 - s]; | |
for (uint32_t i = 0; i < (1u << (n_shift - s - 1)); i++) { | |
Mod<P> w = Mod<P>::from_small(1); | |
for (uint32_t j = 0; j < b; j++) { | |
Mod<P> x = xs[2 * b * i + j]; | |
Mod<P> y = w * xs[2 * b * i + j + b]; | |
xs[2 * b * i + j] = x + y; | |
xs[2 * b * i + j + b] = x - y; | |
w = w * wn; | |
} | |
} | |
} | |
} | |
static void unapply(uint32_t n_shift, Mod<P> *xs) { | |
uint32_t n = (1u << n_shift); | |
Mod<P> n_inv = Mod<P>::from_small(n).inv(); | |
Mod<P> w = Mod<P>::from_small(1); | |
Mod<P> wn = W_pow_shift[R - n_shift]; | |
for (uint32_t i = 0; i < n; i++) { | |
xs[i] = xs[i] * w * n_inv; | |
w = w * wn; | |
} | |
apply(n_shift, xs); | |
for (uint32_t i = 0; i < n / 2; i++) { | |
std::swap(xs[i], xs[n - 1 - i]); | |
} | |
} | |
}; | |
template<uint32_t P, uint32_t K, uint32_t R, uint32_t G> | |
const Mod<P> *NTT<P, K, R, G>::W_pow_shift = NTT<P, K, R, G>::gen_W_pow_shift(); | |
} | |
#endif //COMBINATORICS_NTT_HPP |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment