Skip to content

Instantly share code, notes, and snippets.

@yosupo06
Last active April 22, 2024 21:37
Show Gist options
  • Save yosupo06/e616e2dd1fd59bfa356d26135ca981a8 to your computer and use it in GitHub Desktop.
Save yosupo06/e616e2dd1fd59bfa356d26135ca981a8 to your computer and use it in GitHub Desktop.
div2by1 テストコード / div2by1改造版 python版
#pragma GCC target("avx2")
#include <cassert>
#include <iostream>
#include <chrono>
#include <array>
#include <immintrin.h>
using namespace std;
// Short fixed-width style names used by every routine below. The arithmetic in
// this file shifts by 32 and works modulo 2^32, so it treats int / unsigned int
// as 32-bit and long long / unsigned long long as 64-bit types.
using i32 = int;
using u32 = unsigned int;
using i64 = long long;
using u64 = unsigned long long;
// Baseline: returns (8n)! reduced modulo `mod`, using one widening multiply
// and one hardware remainder per factor.
u32 fact_simple(u32 n, u32 mod) {
    // Same 32-bit bound the loop condition would compute (wraps modulo 2^32).
    const u32 last = 8 * n;
    u32 acc = 1;
    u32 k = 1;
    while (k <= last) {
        const u64 wide = static_cast<u64>(acc) * k;
        acc = static_cast<u32>(wide % mod);
        ++k;
    }
    return acc;
}
// Same computation as the plain loop, but with the modulus fixed to the
// literal 998244353 so the compiler sees a compile-time constant divisor.
// Returns (8n)! mod 998244353.
u32 fact_const_998244353(u32 n) {
    const u32 last = 8 * n;  // 32-bit bound, wraps modulo 2^32
    u32 acc = 1;
    u32 k = 1;
    while (k <= last) {
        const u64 wide = static_cast<u64>(acc) * k;
        acc = static_cast<u32>(wide % 998244353);
        ++k;
    }
    return acc;
}
// Returns the multiplicative inverse of x modulo 2^32, i.e. the value y with
// x * y == 1 (mod 2^32). x must be odd; both the precondition and the result
// are checked with assert.
u32 inv_u32(const u32 x) {
    assert((x & 1u) != 0);
    // Five rounds of y <- y * (2 - y * x), all in wrapping 32-bit arithmetic.
    u32 y = 1;
    int rounds_left = 5;
    while (rounds_left-- > 0) {
        y = y * (2u - y * x);
    }
    assert(u32(x * y) == u32(1));
    return y;
}
// Returns (8n)! mod `mod` using scalar Montgomery multiplication with R = 2^32.
// Both the running product and the current factor are kept in Montgomery form
// and only lazily reduced (the factor is kept below 2 * mod).
// `mod` must be odd: inv_u32 asserts it.
u32 fact_montgomery(u32 n, u32 mod) {
    const u32 neg_mod_inv = -inv_u32(mod);
    // Montgomery representation of 1, i.e. 2^32 mod `mod`.
    const u32 mont_one = static_cast<u32>((u64(1) << 32) % mod);
    const u32 two_mod = 2 * mod;

    // Montgomery product: (a * b + m * mod) / 2^32, where m is chosen so that
    // the low 32 bits of the sum are zero.
    auto redc_mul = [mod, neg_mod_inv](u32 a, u32 b) -> u32 {
        const u64 prod = u64(1) * a * b;
        const u32 m = u32(prod) * neg_mod_inv;
        return u32((prod + u64(m) * mod) >> 32);
    };

    u32 acc = mont_one;
    u32 factor = mont_one;
    for (u32 step = 0; step != 8 * n; ++step) {
        acc = redc_mul(acc, factor);
        // Advance the factor by one (in Montgomery form) and fold it back
        // below 2 * mod.
        factor += mont_one;
        if (factor >= two_mod) {
            factor -= two_mod;
        }
    }
    // Multiplying by the plain integer 1 leaves Montgomery form; the final
    // remainder produces the canonical residue.
    return redc_mul(acc, 1) % mod;
}
// Eight-lane AVX2 version of fact_montgomery: returns (8n)! mod `mod`.
// Lane j (0-based) of `x` accumulates, in Montgomery form, the product of the
// n numbers j+1, j+9, j+17, ...; the eight lane results are multiplied
// together with ordinary modular arithmetic at the end.
// `mod` must be odd, because inv_u32 asserts it.
// NOTE(review): the lazy reduction of y into [0, 2*mod) presumably needs
// 3*mod to fit in 32 bits - confirm the intended range of `mod`.
u32 fact_simd_montgomery(u32 n, u32 mod) {
const __m256i_u mod_x = _mm256_set1_epi32(mod);
const __m256i_u mod_2_x = _mm256_set1_epi32(mod * 2);
const u32 n_inv = -inv_u32(mod);
const __m256i_u n_inv_x = _mm256_set1_epi32(n_inv);
// Lane-wise Montgomery product of eight u32 pairs. _mm256_mul_epu32 only
// multiplies the low 32 bits of each 64-bit lane, so the even-indexed and the
// odd-indexed 32-bit elements are processed as two separate 4x64-bit vectors.
auto mul_reduce = [&](__m256i_u x, __m256i_u y) {
__m256i_u z_even = _mm256_mul_epu32(x, y);
__m256i_u z_odd = _mm256_mul_epu32(_mm256_srli_epi64(x, 32),
_mm256_srli_epi64(y, 32));
// z += (low32(z) * n_inv mod 2^32) * mod; afterwards the low 32 bits of each
// 64-bit lane are zero and the high 32 bits hold the reduced product.
z_even += _mm256_mul_epu32(_mm256_mul_epu32(z_even, n_inv_x), mod_x);
z_odd += _mm256_mul_epu32(_mm256_mul_epu32(z_odd, n_inv_x), mod_x);
// Even results are moved down into the even 32-bit slots; the odd results
// already sit in the high halves, i.e. in the odd 32-bit slots.
z_even = _mm256_srli_epi64(z_even, 32);
return _mm256_blend_epi32(z_even, z_odd, 0b10101010);
};
// Montgomery representation of 1 (2^32 mod `mod`) and of 8.
const u32 one = (u32)((u64(1) << 32) % mod);
const __m256i_u eight_x = _mm256_set1_epi32((u32)(u64(8) * one % mod));
// x = [1, 1, 1, 1, 1, 1, 1, 1] (Montgomery form)
__m256i_u x = _mm256_set1_epi32(one);
// y = [1, 2, 3, 4, 5, 6, 7, 8] (Montgomery form)
__m256i_u y = [&]() {
std::array<u32, 8> _y;
for (u32 i = 0; i < 8; i++) {
_y[i] = u32(u64(i + 1) * one % mod);
};
return _mm256_loadu_si256((__m256i_u*)_y.data());
}();
for (u32 i = 1; i <= n; i++) {
// x *= y
x = mul_reduce(x, y);
// y += [8, 8, 8, 8, 8, 8, 8, 8]
y = _mm256_add_epi32(y, eight_x);
// Branch-free "if (y >= 2 * mod) y -= 2 * mod": when y < 2 * mod the
// subtraction wraps around to a larger unsigned value, so the unsigned
// minimum keeps y unchanged.
y = _mm256_min_epu32(y, _mm256_sub_epi32(y, mod_2_x));
}
// Multiply by the plain integer 1 to leave Montgomery form.
x = mul_reduce(x, _mm256_set1_epi32(1));
// Combine the eight per-lane partial products into the final result.
std::array<u32, 8> _x;
_mm256_storeu_si256((__m256i_u*)_x.data(), x);
u32 z = 1;
for (int i = 0; i < 8; i++) {
z = u32(u64(1) * z * _x[i] % mod);
}
return z;
}
// Returns (8n)! mod `mod`, replacing the hardware remainder with a
// reciprocal-based "2-by-1" division: the divisor is normalized to
// d = mod * 2^k with its top bit set, the running product is kept reduced
// modulo d, and a single `% mod` at the end yields the answer.
// `mod` must be non-zero, otherwise the normalization loop never terminates.
u32 fact_div2by1(u32 n, u32 mod) {
    // Shift the divisor left until bit 31 is set.
    u32 d = mod;
    while ((d >> 31) == 0) {
        d <<= 1;
    }
    const u64 d_once = d;
    const u64 d_twice = u64(2) * d;
    // Reciprocal of d: floor((2^64 - 1) / d) - 2^32.
    const u32 v = static_cast<u32>(~u64(0) / d - (u64(1) << 32));

    u32 acc = 1;
    for (u32 k = 1; k <= 8 * n; ++k) {
        u64 rem = u64(1) * acc * k;
        const u32 hi = u32(rem >> 32);
        // Quotient estimate computed from the high word only.
        const u32 q = u32((u64(1) * v * hi) >> 32) + hi;
        rem -= u64(1) * q * d;
        // Correct the estimate: at most one subtraction of 2d and one of d.
        if (rem >= d_twice) {
            rem -= d_twice;
        }
        if (rem >= d_once) {
            rem -= d_once;
        }
        acc = u32(rem);
    }
    return acc % mod;
}
// Eight-lane AVX2 version of fact_div2by1: returns (8n)! mod `mod`.
// Lane j (0-based) of `x` holds the product of j+1, j+9, j+17, ... reduced
// modulo the normalized divisor d = mod * 2^k (top bit set); the eight lanes
// are combined with ordinary `% mod` arithmetic at the end.
// `mod` must be non-zero, otherwise the normalization loop never terminates.
u32 fact_simd_div2by1(u32 n, u32 mod) {
// Normalize the divisor so that bit 31 is set.
const u32 d = [&]() {
u32 _d = mod;
while (_d < (u32(1) << 31)) _d *= 2;
return _d;
}();
// d as a 32-bit broadcast (multiplier operand) and d, d-1, 2d, 2d-1 as 64-bit
// broadcasts (operands of the 64-bit compare/subtract correction steps).
const __m256i d_x = _mm256_set1_epi32(d);
const __m256i d_4x = _mm256_set1_epi64x(d);
const __m256i d_n1_4x = _mm256_set1_epi64x(d - 1);
const __m256i d_2_4x = _mm256_set1_epi64x(u64(2) * d);
const __m256i d_2_n1_4x = _mm256_set1_epi64x(u64(2) * d - 1);
// Reciprocal of d: floor((2^64 - 1) / d) - 2^32.
const u32 v = (u32)((u64(-1) / d) - (u64(1) << 32));
const __m256i v_x = _mm256_set1_epi32(v);
const __m256i_u eight_x = _mm256_set1_epi32(8);
// x = [1, 1, 1, 1, 1, 1, 1, 1]
__m256i_u x = _mm256_set1_epi32(1);
// y = [1, 2, 3, 4, 5, 6, 7, 8]
__m256i_u y = [&]() {
std::array<u32, 8> _y;
for (u32 i = 0; i < 8; i++) {
_y[i] = i + 1;
};
return _mm256_loadu_si256((__m256i_u*)_y.data());
}();
for (u32 i = 1; i <= n; i++) {
// _mm256_mul_epu32 multiplies only the low 32 bits of each 64-bit lane, so
// the even and the odd 32-bit elements are handled as two 4x64-bit vectors.
// u64 u = u64(1) * x * y;
__m256i_u u_even = _mm256_mul_epu32(x, y);
__m256i_u u_odd = _mm256_mul_epu32(_mm256_srli_epi64(x, 32),
_mm256_srli_epi64(y, 32));
// u32 u1 = u32(u >> 32);
__m256i_u u1_even = _mm256_srli_epi64(u_even, 32);
__m256i_u u1_odd = _mm256_srli_epi64(u_odd, 32);
// u32 q = u32((u64(1) * v * u1) >> 32) + u1;
// (the sum is formed in 64-bit lanes; the following multiply reads only its
// low 32 bits)
__m256i_u q_even = _mm256_add_epi64(
_mm256_srli_epi64(_mm256_mul_epu32(v_x, u1_even), 32), u1_even);
__m256i_u q_odd = _mm256_add_epi64(
_mm256_srli_epi64(_mm256_mul_epu32(v_x, u1_odd), 32), u1_odd);
// u -= u64(1) * q * d;
u_even = _mm256_sub_epi64(u_even, _mm256_mul_epu32(q_even, d_x));
u_odd = _mm256_sub_epi64(u_odd, _mm256_mul_epu32(q_odd, d_x));
// if (u >= u64(2) * d) u -= u64(2) * d;
// (branch-free: the compare yields an all-ones mask where u > 2d - 1, which
// selects 2d as the value to subtract; _mm256_cmpgt_epi64 is a signed
// compare - NOTE(review): this presumably relies on u being a small
// non-negative value here, confirm)
u_even = _mm256_sub_epi64(
u_even,
_mm256_and_si256(_mm256_cmpgt_epi64(u_even, d_2_n1_4x), d_2_4x));
u_odd = _mm256_sub_epi64(
u_odd,
_mm256_and_si256(_mm256_cmpgt_epi64(u_odd, d_2_n1_4x), d_2_4x));
// if (u >= d) u -= d;
u_even = _mm256_sub_epi64(
u_even,
_mm256_and_si256(_mm256_cmpgt_epi64(u_even, d_n1_4x), d_4x));
u_odd = _mm256_sub_epi64(
u_odd,
_mm256_and_si256(_mm256_cmpgt_epi64(u_odd, d_n1_4x), d_4x));
// x = u
// (re-interleave: even results stay in the even 32-bit slots, odd results are
// shifted up into the odd 32-bit slots)
x = _mm256_blend_epi32(u_even, _mm256_slli_epi64(u_odd, 32), 0b10101010);
// y += [8, 8, 8, 8, 8, 8, 8, 8]
y = _mm256_add_epi32(y, eight_x);
}
// Combine the eight per-lane partial products into the final result.
std::array<u32, 8> _x;
_mm256_storeu_si256((__m256i_u*)_x.data(), x);
u32 z = 1;
for (int i = 0; i < 8; i++) {
z = u32(u64(1) * z * _x[i] % mod);
}
return z;
}
// Returns (8n)! mod `mod` using the modified "2-by-1" reduction: the divisor
// is kept in [2^29, 2^30), the quotient is estimated from the bits of the
// product above bit 30, and each step needs only one conditional subtraction,
// leaving the running product in [0, 2 * d).
//
// The reduction itself is only valid for a divisor in [2^29, 2^30). Smaller
// moduli are therefore scaled to d = mod * 2^k inside that range (the same
// normalization fact_div2by1 performs); because d is a multiple of `mod`, the
// single `% mod` at the end still yields the correct residue.
//
// Preconditions: 0 < mod < 2^30 (asserted). The product x * i must stay below
// d * 2^32, which holds while 8 * n < 2^31.
u32 fact_my_div2by1(u32 n, u32 mod) {
    assert(mod != 0 && mod < (u32(1) << 30));
    // Scale the divisor into [2^29, 2^30).
    u32 d = mod;
    while (d < (u32(1) << 29)) {
        d *= 2;
    }
    const u32 d_twice = 2 * d;
    // Reciprocal of d: floor((2^62 - 1) / d) - 2^32, which fits in 32 bits for
    // every d in [2^29, 2^30).
    const u32 v = (u32)((((u64(1) << 62) - 1) / d) - (u64(1) << 32));
    u32 x = 1;
    for (u32 i = 1; i <= 8 * n; i++) {
        const u64 u = u64(1) * x * i;
        const u32 u1 = (u32)(u >> 30);
        // Quotient estimate; it never exceeds floor(u / d).
        const u32 q = u32((u64(1) * v * u1) >> 32) + u1;
        // u - q * d lies in [0, 4 * d), so the low 32 bits are sufficient.
        x = (u32)(u) - (u32)(u64(1) * q * d);
        if (x >= d_twice) x -= d_twice;
    }
    // x is unsigned, so a single remainder gives the canonical residue.
    return x % mod;
}
// Eight-lane AVX2 version of fact_my_div2by1: returns (8n)! mod `mod`.
// Lane j (0-based) of `x` holds the product of j+1, j+9, j+17, ..., lazily
// reduced after every step; the eight lanes are combined with ordinary
// `% mod` arithmetic at the end.
// The Python model of this reduction in the same gist asserts
// 2^29 <= m < 2^30 for the modulus; no range check is performed here.
u32 fact_simd_my_div2by1(u32 n, u32 mod) {
const __m256i_u mod_x = _mm256_set1_epi32(mod);
const __m256i_u mod_2_x = _mm256_set1_epi32(mod * 2);
// Reciprocal of mod: floor((2^62 - 1) / mod) - 2^32.
const u32 v = (u32)((((u64(1) << 62) - 1) / mod) - (u64(1) << 32));
const __m256i v_x = _mm256_set1_epi32(v);
const __m256i_u eight_x = _mm256_set1_epi32(8);
// x = [1, 1, 1, 1, 1, 1, 1, 1]
__m256i_u x = _mm256_set1_epi32(1);
// y = [1, 2, 3, 4, 5, 6, 7, 8]
__m256i_u y = [&]() {
std::array<u32, 8> _y;
for (u32 i = 0; i < 8; i++) {
_y[i] = i + 1;
};
return _mm256_loadu_si256((__m256i_u*)_y.data());
}();
for (u32 i = 1; i <= n; i++) {
// u = x * y as 64-bit products. _mm256_mul_epu32 multiplies only the low 32
// bits of each 64-bit lane, so even and odd 32-bit elements are handled as
// two separate 4x64-bit vectors.
__m256i_u u_even = _mm256_mul_epu32(x, y);
__m256i_u u_odd = _mm256_mul_epu32(_mm256_srli_epi64(x, 32),
_mm256_srli_epi64(y, 32));
// u1 = u >> 30
__m256i_u u1_even = _mm256_srli_epi64(u_even, 30);
__m256i_u u1_odd = _mm256_srli_epi64(u_odd, 30);
// q = ((v * u1) >> 32) + u1 (quotient estimate)
__m256i_u q_even = _mm256_add_epi64(
_mm256_srli_epi64(_mm256_mul_epu32(v_x, u1_even), 32), u1_even);
__m256i_u q_odd = _mm256_add_epi64(
_mm256_srli_epi64(_mm256_mul_epu32(v_x, u1_odd), 32), u1_odd);
// q * mod; only the low 32 bits of each product are used below.
q_even = _mm256_mul_epu32(q_even, mod_x);
q_odd = _mm256_mul_epu32(q_odd, mod_x);
// Re-interleave the low 32 bits of q * mod and of u into eight u32 lanes.
__m256i_u q = _mm256_blend_epi32(q_even, _mm256_slli_epi64(q_odd, 32),
0b10101010);
__m256i_u u0 = _mm256_blend_epi32(u_even, _mm256_slli_epi64(u_odd, 32),
0b10101010);
// x = low32(u) - low32(q * mod), in wrapping 32-bit arithmetic
x = _mm256_sub_epi32(u0, q);
// Branch-free "if (x >= 2 * mod) x -= 2 * mod": when x < 2 * mod the
// subtraction wraps around to a larger unsigned value, so the unsigned
// minimum keeps x unchanged.
x = _mm256_min_epu32(x, _mm256_sub_epi32(x, mod_2_x));
// y += [8, 8, 8, 8, 8, 8, 8, 8]
y = _mm256_add_epi32(y, eight_x);
}
// Combine the eight per-lane partial products into the final result.
std::array<u32, 8> _x;
_mm256_storeu_si256((__m256i_u*)_x.data(), x);
u32 z = 1;
for (int i = 0; i < 8; i++) {
z = u32(u64(1) * z * _x[i] % mod);
}
return z;
}
// Benchmark driver. Reads `n` and `mod` from standard input, then runs every
// (8n)! implementation in turn and prints one line per implementation:
//   "<label>: <elapsed>ms <result>"
// Returns 1 without running anything when the input is missing, malformed or
// out of range. The individual implementations have stricter requirements of
// their own (for example, the Montgomery variants assert that `mod` is odd).
int main() {
    i32 n, mod;
    // A failed extraction would leave both values at 0 and negative values
    // would wrap to huge unsigned numbers, so reject them up front.
    if (!(cin >> n >> mod) || n < 0 || mod <= 0) {
        cerr << "expected two integers on stdin: n >= 0 and mod > 0" << endl;
        return 1;
    }
    const u32 count = static_cast<u32>(n);
    const u32 modulus = static_cast<u32>(mod);

    // Times a single implementation and prints its label, wall-clock time in
    // milliseconds and result. endl flushes so each line appears as soon as
    // the corresponding run finishes.
    auto bench = [](const char* label, auto&& compute) {
        const auto begin = chrono::steady_clock::now();
        const u32 r = compute();
        const auto now = chrono::steady_clock::now();
        const int msecs = int(
            chrono::duration_cast<chrono::milliseconds>(now - begin).count());
        cout << label << ": " << msecs << "ms " << r << endl;
    };

    bench("simple", [&] { return fact_simple(count, modulus); });
    bench("const 998244353", [&] { return fact_const_998244353(count); });
    bench("montgomery", [&] { return fact_montgomery(count, modulus); });
    bench("simd montgomery", [&] { return fact_simd_montgomery(count, modulus); });
    bench("div2by1", [&] { return fact_div2by1(count, modulus); });
    bench("simd div2by1", [&] { return fact_simd_div2by1(count, modulus); });
    bench("my_div2by1", [&] { return fact_my_div2by1(count, modulus); });
    bench("simd my_div2by1", [&] { return fact_simd_my_div2by1(count, modulus); });
    return 0;
}
from random import randint
# Reduces u modulo m without a division by m in the hot path.
# Returns either (u % m) or (u % m) + m.
#
# The asserts double as a machine-checked proof sketch of every intermediate
# bound; they also enforce the input constraints:
#   2**29 <= m < 2**30   and   0 <= u < m * 2**32
def my_div2by1(u, m):
    # Input constraints
    assert 2**29 <= m < 2**30
    assert 0 <= u < m * 2**32 <= 2**62

    # Reciprocal of m. For m == 2**29 the quotient 2**62 // m is exactly 2**33,
    # so it is lowered by one to keep v inside 32 bits.
    v = 2**62 // m - 2**32
    if m == 2**29:
        v -= 1
    assert 0 <= v < 2**32

    u1 = u // 2**30
    assert 0 <= u < 2**62
    assert 0 <= u1 < 2**32

    # Quotient estimate, and the exact value it approximates.
    q = (v * u1) // 2**32 + u1
    exact = (2**30 * u1) // m
    if m != 2**29:
        assert q == ((v + 2**32) * u1) // 2**32 == 2**62 // m * u1 // 2**32
    assert q <= exact
    # q >= floor(((2**62 / m - 1) * u1) / 2**32)
    #    = floor((2**30 * u1) / m - u1 / 2**32)
    #    > (2**30 * u1) / m - 1   (because u1 < 2**32)
    # note: m == 2**29 -> casework
    assert q == exact - 1 or q == exact
    assert q <= exact <= u // m < 2**32

    w = u - q * m
    assert 0 <= u - 2**30 * u1 < 2**30 <= 2 * m
    assert 0 <= 2**30 * u1 - q * m < 2 * m
    assert 0 <= w < 4 * m < 2**32

    # One conditional subtraction brings w into [0, 2 * m).
    if w >= 2 * m:
        w -= 2 * m
    return w
# Randomized check of my_div2by1 for three modulus choices, 10000 trials each:
# a random modulus in [2**29, 2**30 - 1], then the smallest admissible modulus
# (2**29), then the largest (2**30 - 1).
for pick_modulus in (
    lambda: randint(2**29, 2**30 - 1),
    lambda: 2**29,
    lambda: 2**30 - 1,
):
    for _ in range(10000):
        m = pick_modulus()
        x = randint(0, m * 2**32 - 1)
        y = my_div2by1(x, m)
        assert y == x % m or y == x % m + m
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment