Skip to content

Instantly share code, notes, and snippets.

@Fugoes
Last active October 18, 2019 07:39
Show Gist options
  • Save Fugoes/7541704089c931de71fe740eb5bc9362 to your computer and use it in GitHub Desktop.
Save Fugoes/7541704089c931de71fe740eb5bc9362 to your computer and use it in GitHub Desktop.
#ifndef COMBINATORICS_NTT_HPP
#define COMBINATORICS_NTT_HPP
#include <cassert>
#include <cstdint>
#include <utility>
namespace ntt {

/// An element of Z/PZ, always stored as its canonical residue in [0, P).
///
/// Elements are created only through the named factories `from` (reduces any
/// 32-bit value) and `from_small` (trusts the caller that x < P); the raw
/// constructor is private so there is no unchecked implicit conversion.
/// `inv()` additionally requires P to be prime.
template<uint32_t P>
class Mod {
public:
    /// The zero element.
    Mod() = default;

    /// Returns x reduced modulo P.
    static Mod<P> from(uint32_t x) {
        return Mod<P>(x % P);
    }

    /// Wraps x without reducing it. Precondition: x < P.
    static Mod<P> from_small(uint32_t x) {
        return Mod<P>(x);
    }

    /// Returns this**exp by square-and-multiply; pow(0) is 1.
    Mod<P> pow(uint32_t exp) const {
        // 64-bit accumulators: the product of two residues < 2**31 fits.
        uint64_t r = 1;
        uint64_t a = _value;
        while (exp != 0) {
            if ((exp & 1u) != 0) r = (r * a) % P;
            a = (a * a) % P;
            exp >>= 1u;
        }
        return Mod<P>(static_cast<uint32_t>(r));
    }

    Mod<P> operator+(Mod<P> y) const {
        return Mod<P>((_value + y._value) % P);
    }

    Mod<P> operator-(Mod<P> y) const {
        // Adding P first keeps the unsigned subtraction from wrapping.
        return Mod<P>((_value + P - y._value) % P);
    }

    Mod<P> operator*(Mod<P> y) const {
        uint64_t a = _value;
        uint64_t b = y._value;
        return Mod<P>(static_cast<uint32_t>((a * b) % P));
    }

    Mod<P> &operator+=(Mod<P> y) { return *this = *this + y; }
    Mod<P> &operator-=(Mod<P> y) { return *this = *this - y; }
    Mod<P> &operator*=(Mod<P> y) { return *this = *this * y; }

    /// Multiplicative inverse via Fermat's little theorem (this**(P-2)).
    /// Requires P prime; for P > 2 the "inverse" of zero comes out as zero.
    Mod<P> inv() const {
        return pow(P - 2);
    }

    bool operator==(Mod<P> y) const {
        return _value == y._value;
    }

    bool operator!=(Mod<P> y) const {
        return !(*this == y);
    }

    /// The canonical residue in [0, P).
    uint32_t val() const {
        return _value;
    }

private:
    uint32_t _value{0};

    explicit Mod(uint32_t x) : _value(x) {}

    // Keeps `_value + y._value` and `_value + P` from overflowing uint32_t.
    static_assert(P < (UINT32_MAX / 2), "P too large");
};

/// Radix-2 number-theoretic transform over Mod<P>.
///
/// The parameters must satisfy P = K * 2**R + 1 (checked at compile time),
/// with P prime and G a primitive root modulo P, so that G**K is a primitive
/// 2**R-th root of unity. Transform lengths 2**n_shift with
/// 0 <= n_shift <= R are supported.
template<uint32_t P, uint32_t K, uint32_t R, uint32_t G>
class NTT {
    static_assert(R < 32, "R too large for 32-bit transform lengths");
    static_assert(static_cast<uint64_t>(P) == (static_cast<uint64_t>(K) << R) + 1u,
                  "P must equal K * 2**R + 1");

private:
    /// Returns the index following i in bit-reversed counting over n_shift
    /// bits (i.e. "increment" applied to the reversed bit pattern).
    /// Requires n_shift >= 1.
    static uint32_t reverse_next(uint32_t n_shift, uint32_t i) {
        for (uint32_t h = (1u << (n_shift - 1)); h > 0; h >>= 1u) {
            if ((i & h) == 0) return i | h;
            i ^= h;
        }
        return i;
    }

    /// Permutes xs[0, 2**n_shift) into bit-reversed index order.
    static void reverse(uint32_t n_shift, Mod<P> *xs) {
        // A single element is already in order; returning early also keeps
        // reverse_next from evaluating 1u << (0 - 1), which is undefined.
        if (n_shift == 0) return;
        const uint32_t n = 1u << n_shift;
        for (uint32_t i = 0, j = 0; i < n; i++) {
            if (i > j) std::swap(xs[i], xs[j]);
            j = reverse_next(n_shift, j);
        }
    }

    /// Fills and returns a table rs with rs[i] = (G**K)**(2**i), 0 <= i <= R;
    /// when G is a primitive root, rs[i] is a primitive 2**(R - i)-th root of
    /// unity.
    static Mod<P> *gen_W_pow_shift() noexcept {
        static Mod<P> rs[R + 1];
        rs[0] = Mod<P>::from_small(G).pow(K);
        for (uint32_t i = 1; i <= R; i++) {
            rs[i] = rs[i - 1] * rs[i - 1];
        }
        // G**K has order exactly 2**R iff its 2**(R-1)-th power is -1;
        // this catches a G that is not a suitable generator (debug builds).
        assert(R == 0 || rs[R - 1] == Mod<P>::from_small(P - 1));
        return rs;
    }

    /// Root table, built on first use. A function-local static is initialized
    /// thread-safely the first time control reaches it, whereas a dynamically
    /// initialized static data member of a class template has unordered
    /// initialization and could still be null if apply() runs from another
    /// static initializer.
    static const Mod<P> *W_pow_shift() noexcept {
        static const Mod<P> *const table = gen_W_pow_shift();
        return table;
    }

public:
    /// In-place forward transform of xs[0, 2**n_shift), natural-order input
    /// and output: afterwards xs[k] = sum_j x_j * w**(j*k), where
    /// w = (G**K)**(2**(R - n_shift)) is a primitive 2**n_shift-th root of unity.
    /// Precondition: n_shift <= R and xs has at least 2**n_shift elements.
    static void apply(uint32_t n_shift, Mod<P> *xs) {
        assert(n_shift <= R);
        const Mod<P> *const w_pow = W_pow_shift();
        const uint32_t n = 1u << n_shift;
        reverse(n_shift, xs);
        // Iterative radix-2 Cooley-Tukey: stage s merges blocks of length
        // 2*half using a primitive (2*half)-th root of unity.
        for (uint32_t s = 0; s < n_shift; s++) {
            const uint32_t half = 1u << s;
            const Mod<P> wn = w_pow[R - 1 - s];
            for (uint32_t base = 0; base < n; base += 2 * half) {
                Mod<P> w = Mod<P>::from_small(1);
                for (uint32_t j = 0; j < half; j++) {
                    const Mod<P> x = xs[base + j];
                    const Mod<P> y = w * xs[base + j + half];
                    xs[base + j] = x + y;
                    xs[base + j + half] = x - y;
                    w *= wn;
                }
            }
        }
    }

    /// In-place inverse of apply(): unapply(n, xs) after apply(n, xs) restores
    /// the original contents. Precondition: n_shift <= R.
    static void unapply(uint32_t n_shift, Mod<P> *xs) {
        assert(n_shift <= R);
        const uint32_t n = 1u << n_shift;
        const Mod<P> wn = W_pow_shift()[R - n_shift];
        // Scale xs[i] by wn**i / n; the running factor starts at 1/n so each
        // element costs a single multiplication.
        Mod<P> factor = Mod<P>::from_small(n).inv();
        for (uint32_t i = 0; i < n; i++) {
            xs[i] *= factor;
            factor *= wn;
        }
        // The forward transform of X_i * w**i / n yields, at position j,
        // (1/n) * sum_i X_i * w**(i*(j+1)) = x_{n-1-j}; reversing restores x.
        apply(n_shift, xs);
        for (uint32_t i = 0; i < n / 2; i++) {
            std::swap(xs[i], xs[n - 1 - i]);
        }
    }
};

}  // namespace ntt
#endif //COMBINATORICS_NTT_HPP
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment