-
-
Save yosupo06/e616e2dd1fd59bfa356d26135ca981a8 to your computer and use it in GitHub Desktop.
div2by1 テストコード / div2by1改造版 python版 (div2by1 test code in C++ / modified div2by1, Python version)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma GCC target("avx2") | |
#include <cassert> | |
#include <iostream> | |
#include <chrono> | |
#include <array> | |
#include <immintrin.h> | |
using namespace std; | |
using i32 = int; | |
using u32 = unsigned int; | |
using i64 = long long; | |
using u64 = unsigned long long; | |
// Baseline: returns (8n)! reduced modulo `mod`, using one 64-bit `%` per
// factor. The factor count 8 * n is computed in u32 arithmetic.
u32 fact_simple(u32 n, u32 mod) {
  const u32 last = 8 * n;
  u32 acc = 1;
  for (u32 k = 1; k <= last; ++k) {
    // Widen before multiplying so the product cannot overflow 32 bits.
    acc = static_cast<u32>(static_cast<u64>(acc) * k % mod);
  }
  return acc;
}
// Same computation as the plain `%` loop, but with the modulus fixed to the
// literal 998244353 so the compiler sees a compile-time constant divisor.
// Returns (8n)! mod 998244353.
u32 fact_const_998244353(u32 n) {
  const u32 last = 8 * n;
  u32 acc = 1;
  for (u32 k = 1; k <= last; ++k) {
    const u64 wide = static_cast<u64>(acc) * k;
    acc = static_cast<u32>(wide % 998244353);
  }
  return acc;
}
// Multiplicative inverse of an odd x modulo 2^32:
//   x * inv_u32(x) == 1 (mod 2^32).
// Asserts that x is odd and that the computed value really is an inverse.
u32 inv_u32(const u32 x) {
  assert(x % 2);
  // r = 1 is already correct modulo 2 for odd x. Each step
  // r <- r * (2 - r * x) doubles the number of correct low bits, so five
  // steps reach 32 bits (1 -> 2 -> 4 -> 8 -> 16 -> 32).
  u32 r = 1;
  int rounds = 5;
  while (rounds-- > 0) {
    r *= 2u - r * x;
  }
  assert(x * r == u32(1));
  return r;
}
// Returns (8n)! mod `mod` using scalar Montgomery multiplication with
// R = 2^32. `mod` must be odd (inv_u32 asserts this).
// NOTE(review): intermediate values are kept in [0, 2*mod) rather than fully
// reduced; this appears to rely on mod < 2^30 so that nothing overflows --
// confirm before using larger moduli.
u32 fact_montgomery(u32 n, u32 mod) {
  // neg_inv * mod == -1 (mod 2^32)
  const u32 neg_inv = -inv_u32(mod);

  // Montgomery product: a * b / 2^32 modulo `mod`. The correction term makes
  // the low 32 bits of `t` vanish, so the shift is an exact division.
  const auto redc_mul = [&](u32 a, u32 b) {
    u64 t = static_cast<u64>(a) * b;
    const u32 m = static_cast<u32>(t) * neg_inv;
    t += static_cast<u64>(m) * mod;
    return static_cast<u32>(t >> 32);
  };

  // Montgomery representation of 1, i.e. 2^32 mod `mod`.
  const u32 r_mod = static_cast<u32>((u64(1) << 32) % mod);

  u32 prod = r_mod;    // running factorial, Montgomery form
  u32 factor = r_mod;  // current factor k, Montgomery form
  for (u32 k = 1; k <= 8 * n; ++k) {
    prod = redc_mul(prod, factor);
    // Step the factor to k + 1 and fold it back below 2 * mod.
    factor += r_mod;
    if (factor >= 2 * mod) factor -= 2 * mod;
  }
  // Multiplying by plain 1 leaves Montgomery form; `%` finishes the reduction.
  return redc_mul(prod, 1) % mod;
}
// Returns (8n)! mod `mod` using AVX2 Montgomery multiplication (R = 2^32) on
// eight 32-bit lanes. Lane j (0-based) accumulates the product of
// j+1, j+9, j+17, ... over n steps; the eight partial products are multiplied
// together at the end. `mod` must be odd (inv_u32 asserts this).
// NOTE(review): like the scalar version this appears to assume mod < 2^30.
u32 fact_simd_montgomery(u32 n, u32 mod) {
  const u32 neg_inv = -inv_u32(mod);  // neg_inv * mod == -1 (mod 2^32)
  const __m256i_u v_mod = _mm256_set1_epi32(mod);
  const __m256i_u v_mod2 = _mm256_set1_epi32(mod * 2);
  const __m256i_u v_neg_inv = _mm256_set1_epi32(neg_inv);

  // For four 64-bit products t: add m * mod, where m = low32(t) * neg_inv
  // (mod 2^32), which clears the low 32 bits of every 64-bit lane.
  const auto add_correction = [&](__m256i_u t) {
    const __m256i_u m = _mm256_mul_epu32(t, v_neg_inv);
    return _mm256_add_epi64(t, _mm256_mul_epu32(m, v_mod));
  };

  // Lane-wise Montgomery product of eight 32-bit values.
  const auto redc_mul = [&](__m256i_u a, __m256i_u b) {
    // _mm256_mul_epu32 multiplies only the even 32-bit lanes, so the odd
    // lanes are shifted down and handled in a second pass.
    const __m256i_u even = add_correction(_mm256_mul_epu32(a, b));
    const __m256i_u odd = add_correction(_mm256_mul_epu32(
        _mm256_srli_epi64(a, 32), _mm256_srli_epi64(b, 32)));
    // The result is the high half of each 64-bit value. For `odd` that half
    // already sits in an odd 32-bit lane; `even` is shifted down into place.
    return _mm256_blend_epi32(_mm256_srli_epi64(even, 32), odd, 0b10101010);
  };

  // Montgomery forms of 1 and of the per-step increment 8.
  const u32 r_mod = static_cast<u32>((u64(1) << 32) % mod);
  const __m256i_u v_step =
      _mm256_set1_epi32(static_cast<u32>(u64(8) * r_mod % mod));

  // acc = [1, 1, 1, 1, 1, 1, 1, 1]
  __m256i_u acc = _mm256_set1_epi32(r_mod);

  // factors = [1, 2, 3, 4, 5, 6, 7, 8]
  std::array<u32, 8> first;
  for (u32 lane = 0; lane < 8; ++lane) {
    first[lane] = static_cast<u32>(u64(lane + 1) * r_mod % mod);
  }
  __m256i_u factors =
      _mm256_loadu_si256(reinterpret_cast<const __m256i_u*>(first.data()));

  for (u32 step = 1; step <= n; ++step) {
    // acc *= factors
    acc = redc_mul(acc, factors);
    // factors += [8, 8, 8, 8, 8, 8, 8, 8]
    factors = _mm256_add_epi32(factors, v_step);
    // Conditional subtract of 2 * mod: when factors < 2 * mod the difference
    // wraps to a large unsigned value, so the unsigned min keeps `factors`.
    factors = _mm256_min_epu32(factors, _mm256_sub_epi32(factors, v_mod2));
  }

  // Leave Montgomery form by multiplying with plain 1.
  acc = redc_mul(acc, _mm256_set1_epi32(1));

  // Fold the eight per-lane partial products together with ordinary `%`.
  std::array<u32, 8> lanes;
  _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(lanes.data()), acc);
  u32 result = 1;
  for (const u32 part : lanes) {
    result = static_cast<u32>(static_cast<u64>(result) * part % mod);
  }
  return result;
}
// Returns (8n)! mod `mod` using a reciprocal-based "2-by-1" reduction
// (64-bit product reduced by a 32-bit divisor without a hardware divide).
// The divisor actually used is d = mod * 2^k with the top bit set; the final
// `% mod` converts the residue modulo d into the residue modulo `mod`.
// NOTE(review): the normalisation loop never terminates for mod == 0.
u32 fact_div2by1(u32 n, u32 mod) {
  // Normalise: double the divisor until bit 31 is set.
  u32 shifted = mod;
  while (shifted < (u32(1) << 31)) shifted *= 2;
  const u32 d = shifted;

  // Reciprocal: v = floor((2^64 - 1) / d) - 2^32, truncated to 32 bits.
  const u32 v = static_cast<u32>((u64(-1) / d) - (u64(1) << 32));

  u32 acc = 1;
  for (u32 k = 1; k <= 8 * n; ++k) {
    u64 rem = static_cast<u64>(acc) * k;
    const u32 hi = static_cast<u32>(rem >> 32);
    // Quotient estimate from the high word only: floor(v * hi / 2^32) + hi.
    const u32 q = static_cast<u32>((static_cast<u64>(v) * hi) >> 32) + hi;
    rem -= static_cast<u64>(q) * d;
    // Two conditional corrections (by 2d, then by d) for an estimate that
    // may fall short of the true quotient.
    if (rem >= u64(2) * d) rem -= u64(2) * d;
    if (rem >= d) rem -= d;
    acc = static_cast<u32>(rem);
  }
  return acc % mod;
}
// AVX2 version of the reciprocal-based "2-by-1" factorial: returns
// (8n)! mod `mod`. Eight 32-bit lanes each accumulate every eighth factor
// (lane j takes j+1, j+9, ...); products are reduced modulo the normalised
// divisor d = mod * 2^k (top bit set) and the lanes are combined with `%`
// at the end.
// NOTE(review): the normalisation loop never terminates for mod == 0.
u32 fact_simd_div2by1(u32 n, u32 mod) {
  // Normalise: double the divisor until bit 31 is set.
  u32 shifted = mod;
  while (shifted < (u32(1) << 31)) shifted *= 2;
  const u32 d = shifted;

  const __m256i d_x = _mm256_set1_epi32(d);
  // 64-bit constants for the conditional subtractions. AVX2 only offers a
  // signed "greater than", so `u >= c` is expressed as `u > c - 1`.
  const __m256i d_4x = _mm256_set1_epi64x(d);
  const __m256i d_n1_4x = _mm256_set1_epi64x(d - 1);
  const __m256i d_2_4x = _mm256_set1_epi64x(u64(2) * d);
  const __m256i d_2_n1_4x = _mm256_set1_epi64x(u64(2) * d - 1);

  // Reciprocal: v = floor((2^64 - 1) / d) - 2^32, truncated to 32 bits.
  const u32 v = static_cast<u32>((u64(-1) / d) - (u64(1) << 32));
  const __m256i v_x = _mm256_set1_epi32(v);

  // Reduces four 64-bit products modulo d; mirrors the scalar steps.
  // NOTE(review): the compares are signed, so this presumes the partial
  // remainders never go negative -- confirm against the scalar bounds.
  const auto reduce = [&](__m256i_u u) {
    // u32 u1 = u32(u >> 32);
    const __m256i_u hi = _mm256_srli_epi64(u, 32);
    // u32 q = u32((u64(1) * v * u1) >> 32) + u1;
    // (only the low 32 bits of q are used by the multiply below)
    const __m256i_u q = _mm256_add_epi64(
        _mm256_srli_epi64(_mm256_mul_epu32(v_x, hi), 32), hi);
    // u -= u64(1) * q * d;
    u = _mm256_sub_epi64(u, _mm256_mul_epu32(q, d_x));
    // if (u >= u64(2) * d) u -= u64(2) * d;
    u = _mm256_sub_epi64(
        u, _mm256_and_si256(_mm256_cmpgt_epi64(u, d_2_n1_4x), d_2_4x));
    // if (u >= d) u -= d;
    u = _mm256_sub_epi64(
        u, _mm256_and_si256(_mm256_cmpgt_epi64(u, d_n1_4x), d_4x));
    return u;
  };

  const __m256i_u step = _mm256_set1_epi32(8);
  __m256i_u acc = _mm256_set1_epi32(1);

  // factors = [1, 2, 3, 4, 5, 6, 7, 8]
  std::array<u32, 8> first;
  for (u32 lane = 0; lane < 8; ++lane) {
    first[lane] = lane + 1;
  }
  __m256i_u factors =
      _mm256_loadu_si256(reinterpret_cast<const __m256i_u*>(first.data()));

  for (u32 iter = 1; iter <= n; ++iter) {
    // _mm256_mul_epu32 multiplies only the even 32-bit lanes; the odd lanes
    // are shifted down and processed as a second batch.
    const __m256i_u even = reduce(_mm256_mul_epu32(acc, factors));
    const __m256i_u odd = reduce(_mm256_mul_epu32(
        _mm256_srli_epi64(acc, 32), _mm256_srli_epi64(factors, 32)));
    // Re-interleave: odd results move back up into the odd 32-bit lanes.
    acc = _mm256_blend_epi32(even, _mm256_slli_epi64(odd, 32), 0b10101010);
    // factors += [8, 8, 8, 8, 8, 8, 8, 8]
    factors = _mm256_add_epi32(factors, step);
  }

  // Combine the eight partial products; `%` also brings the residues
  // modulo d down to residues modulo `mod`.
  std::array<u32, 8> lanes;
  _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(lanes.data()), acc);
  u32 result = 1;
  for (const u32 part : lanes) {
    result = static_cast<u32>(static_cast<u64>(result) * part % mod);
  }
  return result;
}
// Returns (8n)! mod `mod` with a modified reciprocal reduction that skips the
// divisor normalisation: the quotient is estimated from bits 30 and up of the
// product, and the remainder is formed in 32-bit wrapping arithmetic.
// NOTE(review): presumably requires 2^29 <= mod < 2^30 and 8 * n < 2^31 so
// that the 32-bit truncations are lossless -- confirm.
u32 fact_my_div2by1(u32 n, u32 mod) {
  // Reciprocal: v = floor((2^62 - 1) / mod) - 2^32, truncated to 32 bits.
  const u32 v = static_cast<u32>(((u64(1) << 62) - 1) / mod - (u64(1) << 32));

  u32 acc = 1;
  for (u32 k = 1; k <= 8 * n; ++k) {
    const u64 prod = static_cast<u64>(acc) * k;
    const u32 hi = static_cast<u32>(prod >> 30);
    // Quotient estimate: floor(v * hi / 2^32) + hi.
    const u32 q = static_cast<u32>((static_cast<u64>(v) * hi) >> 32) + hi;
    // Only the low 32 bits of prod - q * mod are needed, so both terms are
    // truncated before subtracting.
    acc = static_cast<u32>(prod) - static_cast<u32>(static_cast<u64>(q) * mod);
    // A single conditional subtraction; acc is left partially reduced.
    if (acc >= 2 * mod) acc -= 2 * mod;
  }
  return (acc % mod + mod) % mod;
}
// AVX2 version of the normalisation-free reciprocal reduction: returns
// (8n)! mod `mod`. Eight 32-bit lanes each accumulate every eighth factor
// (lane j takes j+1, j+9, ...) and are multiplied together at the end.
// NOTE(review): presumably requires 2^29 <= mod < 2^30 and 8 * n < 2^31 --
// confirm.
u32 fact_simd_my_div2by1(u32 n, u32 mod) {
  const __m256i_u mod_x = _mm256_set1_epi32(mod);
  const __m256i_u mod_2_x = _mm256_set1_epi32(mod * 2);
  // Reciprocal: v = floor((2^62 - 1) / mod) - 2^32, truncated to 32 bits.
  const u32 v = static_cast<u32>(((u64(1) << 62) - 1) / mod - (u64(1) << 32));
  const __m256i v_x = _mm256_set1_epi32(v);

  // For four 64-bit products: estimate the quotient from bits 30 and up and
  // return q * mod (the multiply reads only the low 32 bits of q).
  const auto quotient_times_mod = [&](__m256i_u prod) {
    const __m256i_u hi = _mm256_srli_epi64(prod, 30);
    const __m256i_u q = _mm256_add_epi64(
        _mm256_srli_epi64(_mm256_mul_epu32(v_x, hi), 32), hi);
    return _mm256_mul_epu32(q, mod_x);
  };

  // Packs the low 32 bits of the four even and four odd 64-bit values back
  // into eight 32-bit lanes.
  const auto interleave_low = [](__m256i_u even, __m256i_u odd) {
    return _mm256_blend_epi32(even, _mm256_slli_epi64(odd, 32), 0b10101010);
  };

  const __m256i_u step = _mm256_set1_epi32(8);
  __m256i_u acc = _mm256_set1_epi32(1);

  // factors = [1, 2, 3, 4, 5, 6, 7, 8]
  std::array<u32, 8> first;
  for (u32 lane = 0; lane < 8; ++lane) {
    first[lane] = lane + 1;
  }
  __m256i_u factors =
      _mm256_loadu_si256(reinterpret_cast<const __m256i_u*>(first.data()));

  for (u32 iter = 1; iter <= n; ++iter) {
    // _mm256_mul_epu32 multiplies only the even 32-bit lanes; the odd lanes
    // are shifted down and processed as a second batch.
    const __m256i_u prod_even = _mm256_mul_epu32(acc, factors);
    const __m256i_u prod_odd = _mm256_mul_epu32(
        _mm256_srli_epi64(acc, 32), _mm256_srli_epi64(factors, 32));

    const __m256i_u qm = interleave_low(quotient_times_mod(prod_even),
                                        quotient_times_mod(prod_odd));
    const __m256i_u low = interleave_low(prod_even, prod_odd);

    // 32-bit wrapping difference, then a conditional subtract of 2 * mod:
    // when acc < 2 * mod the difference wraps to a large unsigned value, so
    // the unsigned min keeps acc.
    acc = _mm256_sub_epi32(low, qm);
    acc = _mm256_min_epu32(acc, _mm256_sub_epi32(acc, mod_2_x));

    // factors += [8, 8, 8, 8, 8, 8, 8, 8]
    factors = _mm256_add_epi32(factors, step);
  }

  // Combine the eight partial products with ordinary `%`.
  std::array<u32, 8> lanes;
  _mm256_storeu_si256(reinterpret_cast<__m256i_u*>(lanes.data()), acc);
  u32 result = 1;
  for (const u32 part : lanes) {
    result = static_cast<u32>(static_cast<u64>(result) * part % mod);
  }
  return result;
}
int main() { | |
i32 n, mod; | |
cin >> n >> mod; | |
{ | |
auto begin = chrono::steady_clock::now(); | |
i32 r = fact_simple(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = int( | |
chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "simple: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
i32 r = fact_const_998244353(n); | |
auto now = chrono::steady_clock::now(); | |
int msecs = | |
int(chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "const 998244353: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
int r = fact_montgomery(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = | |
int(chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "montgomery: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
int r = fact_simd_montgomery(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = int( | |
chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "simd montgomery: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
int r = fact_div2by1(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = int( | |
chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "div2by1: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
int r = fact_simd_div2by1(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = int( | |
chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "simd div2by1: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
int r = fact_my_div2by1(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = int( | |
chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "my_div2by1: " << msecs << "ms " << r << endl; | |
} | |
{ | |
auto begin = chrono::steady_clock::now(); | |
int r = fact_simd_my_div2by1(n, mod); | |
auto now = chrono::steady_clock::now(); | |
int msecs = int( | |
chrono::duration_cast<chrono::milliseconds>(now - begin).count()); | |
cout << "simd my_div2by1: " << msecs << "ms " << r << endl; | |
} | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from random import randint | |
def my_div2by1(u, m):
    """Reduce u modulo m without normalising the divisor.

    Returns either ``u % m`` or ``u % m + m`` (the result is only partially
    reduced). Requires ``2**29 <= m < 2**30`` and ``0 <= u < m * 2**32``.
    The assertions double as a step-by-step check of the bounds the
    algorithm relies on; any violated bound raises AssertionError.
    """
    # Input constraints
    assert 2**29 <= m < 2**30
    assert 0 <= u < m * 2**32 <= 2**62

    # 32-bit reciprocal. For m == 2**29 the raw value would be exactly 2**32,
    # so it is lowered by one to fit.
    recip = (1 << 62) // m - (1 << 32) - (1 if m == 2**29 else 0)
    assert 0 <= recip < 2**32

    # Quotient estimate is taken from bits 30 and up of u.
    hi = u >> 30
    assert 0 <= u < 2**62
    assert 0 <= hi < 2**32

    quot = ((recip * hi) >> 32) + hi
    if m != 2**29:
        assert quot == ((recip + 2**32) * hi) // 2**32 == 2**62 // m * hi // 2**32
    assert quot <= (2**30 * hi) // m

    # quot >= floor(((2**62 / m - 1) * hi) / 2**32)
    #       = floor((2**30 * hi) / m - hi / 2**32)
    #       > (2**30 * hi) / m - 1   (because hi < 2**32)
    # note: m == 2**29 -> casework
    assert quot == ((2**30 * hi) // m - 1) or quot == (2**30 * hi) // m
    assert quot <= (2**30 * hi) // m <= u // m < 2**32

    rem = u - quot * m
    assert 0 <= u - 2**30 * hi < 2**30 <= 2 * m
    assert 0 <= 2**30 * hi - quot * m < 2 * m
    assert 0 <= rem < 4 * m < 2**32

    # One conditional subtraction leaves the result in [0, 2 * m).
    return rem - 2 * m if rem >= 2 * m else rem
# Randomized self-check of my_div2by1: 10000 trials each for a random modulus
# in [2**29, 2**30), for the smallest allowed modulus 2**29, and for the
# largest allowed modulus 2**30 - 1. Every result must be u % m or u % m + m.
for choose_m in (
    lambda: randint(2**29, 2**30 - 1),
    lambda: 2**29,
    lambda: 2**30 - 1,
):
    for _ in range(10000):
        m = choose_m()
        x = randint(0, m * 2**32 - 1)
        y = my_div2by1(x, m)
        assert y in (x % m, x % m + m)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment