Last active
April 2, 2023 10:07
-
-
Save cgiosy/a7848b8bd714a9bab668b717ece1ea51 to your computer and use it in GitHub Desktop.
AVX2 Barrett Reduction
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#pragma GCC optimize("O3") | |
#pragma GCC target("avx,avx2,fma") | |
#include <cstdio> | |
#include <algorithm> | |
#include <x86intrin.h> | |
// Reference: https://arxiv.org/ftp/arxiv/papers/1407/1407.3383.pdf
// Precomputed constants for reducing 64-bit products modulo a fixed 32-bit
// modulus with a multiply-and-shift (Barrett-style) quotient estimate.
//
//   m : the modulus.
//   s : pre-shift applied to the product, floor(log2(max(m, 4) - 1)) - 1.
//   n : scaled reciprocal, floor(2^(s + 33) / m).
//
// The members are unsigned because n routinely exceeds INT_MAX
// (e.g. n == 2334666046 for m == 987654321).
//
// NOTE(review): for MOD == 1 or MOD == 2 the value 2^(s + 33) / m needs more
// than 32 bits and is truncated, and MOD <= 0 is not meaningful; callers are
// assumed to pass MOD >= 3.
struct Mod {
    unsigned m, s, n;

    constexpr Mod(int const MOD)
        : m(static_cast<unsigned>(MOD)),
          s(floor_log2((m > 4u ? m : 4u) - 1u) - 1u),
          n(static_cast<unsigned>((1ULL << (s + 33u)) / m)) {}

private:
    // Index of the highest set bit of v (returns 0 for v == 0 or v == 1).
    static constexpr unsigned floor_log2(unsigned v) {
        unsigned bits = 0;
        while (v >>= 1) ++bits;
        return bits;
    }
};
inline __m256i _mm256_mulmod_epu32(__m256i const x, __m256i const y, Mod const md) { | |
auto const p=_mm256_set1_epi32(md.m); | |
auto const q=_mm256_set1_epi64x(md.n); | |
auto const al=_mm256_mul_epu32(x, y); | |
auto const bl=_mm256_srli_epi64(al, md.s); | |
auto const cl=_mm256_srli_epi64(_mm256_mul_epu32(bl, q), 33); | |
auto const ah=_mm256_mul_epu32(_mm256_srli_si256(x, 4), _mm256_srli_si256(y, 4)); | |
auto const bh=_mm256_srli_epi64(ah, md.s); | |
auto const ch=_mm256_srli_epi64(_mm256_mul_epu32(bh, q), 33); | |
auto const dl=_mm256_sub_epi64(al, _mm256_mul_epu32(cl, p)); | |
auto const dh=_mm256_sub_epi64(ah, _mm256_mul_epu32(ch, p)); | |
auto const d=_mm256_or_si256(dl, _mm256_slli_si256(dh, 4)); | |
return _mm256_min_epu32(d, _mm256_sub_epi32(d, p)); | |
} | |
int main() { | |
// 나누는 수가 고정이면 constexpr, 아니면 const를 사용한다. | |
constexpr Mod md(987654321); | |
const int arr[8] = {1, 2, 3, 4, 5, 6, 7, 8}; | |
const int brr[8] = {1000000000, 1000000001, 1000000002, 1000000003, 1000000004, 1000000005, 1000000006, 1000000007}; | |
auto a=_mm256_loadu_si256((__m256i*)arr); | |
auto b=_mm256_loadu_si256((__m256i*)brr); | |
auto c=_mm256_mulmod_epu32(a, b, md); | |
int out[8]; | |
_mm256_storeu_si256((__m256i*)out, c); | |
printf("Result: "); | |
for(int i=0; i<8; i++) | |
printf("%d%c", out[i], " \n"[i==7]); | |
printf("Answer: "); | |
for(int i=0; i<8; i++) | |
printf("%d%c", 1LL*arr[i]*brr[i] % 987654321, " \n"[i==7]); | |
} | |
/*
Result: 12345679 24691360 37037043 49382728 61728415 74074104 86419795 98765488
Answer: 12345679 24691360 37037043 49382728 61728415 74074104 86419795 98765488
*/
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment