Last active
October 1, 2024 08:19
-
-
Save Um6ra1/34feac1c351755cdeba1565c5b932c81 to your computer and use it in GitHub Desktop.
NTT Multiply string
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// solve https://leetcode.com/problems/multiply-strings/description/ | |
#include "pch.h"
#include "CppUnitTest.h"
#include <algorithm>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// Counting-loop helper:
//   rep(i, n)    -> for (long long i = 0; i < n; i++)
//   rep(i, a, b) -> for (long long i = a; i < b; i++)
// The end bound is evaluated once and stored in lim<i>.
// Argument-count dispatch: with one bound, REP_'s (a, b) slots receive (0, n);
// with two bounds they receive (a, b).
// The trailing empty argument guarantees REP_'s "..." always receives at least one
// argument, which C++17 requires for variadic macros (relaxed only in C++20).
#define REP_(i, a_, b_, a, b, ...) for (long long i = (a), lim##i = (b); i < lim##i; i++)
#define rep(i, ...) REP_(i, __VA_ARGS__, __VA_ARGS__, 0, __VA_ARGS__, )
// Parameters of an NTT-friendly prime p = 1 + a * 2^m together with a primitive root r.
// Supported m: 23 (p = 998244353), 26 (p = 469762049), 27 (p = 2013265921).
struct NTTConstruct {
    int p = 0;  // prime modulus
    int a = 0;  // cofactor: p = 1 + a * 2^m
    int m = 0;  // exponent: transforms of length up to 2^m are possible
    int r = 0;  // primitive root modulo p
    // All-zero parameters; not usable for a transform until assigned.
    NTTConstruct() = default;
    // Selects one of the built-in primes by its power-of-two exponent m_.
    // Throws std::invalid_argument for any m_ other than 23, 26 or 27.
    NTTConstruct(int m_) {
        switch (m_) {
        case 23: p = 998244353;  a = 119; m = 23; r = 3;  break;  // 1 + 119 * 2^23
        case 26: p = 469762049;  a = 7;   m = 26; r = 3;  break;  // 1 + 7 * 2^26
        case 27: p = 2013265921; a = 15;  m = 27; r = 31; break;  // 1 + 15 * 2^27
        default:
            throw std::invalid_argument("NTTConstruct: unsupported m (expected 23, 26 or 27)");
        }
    }
};
//NTTConstruct ntt_23 = { 998244353, 119, 23, 3 };  // 998244353 = 1+119*2^23, root=3
//NTTConstruct ntt_26 = { 469762049, 7, 26, 3 };    // 469762049 = 1+7*2^26, root=3
//NTTConstruct ntt_27 = { 2013265921, 15, 27, 31 }; // 2013265921 = 1+15*2^27, root=31
//NTTConstruct ntt_26 = { 469762049, 7, 26, 3 };    // 469762049 = 1+7*2^26, root=3
using LL = long long; | |
class NTT { | |
LL p_; // prime | |
std::vector<LL> omegas_; // { omega^0, omega^1, ... omega^N-1 } like for omega_N of FFT | |
std::vector<LL> iomegas_; // { omega^0, omega^-1, ... omega^-(N-1) } like for omega_N of FFT | |
public: | |
NTT(NTTConstruct nttc) : p_(nttc.p) { | |
omegas_.resize(nttc.m); | |
iomegas_.resize(nttc.m); | |
// r^p % p = 1 | |
// p = 1+A*2^m | |
// w_N = r^A | |
LL w = modpow(nttc.r, nttc.a, p_); | |
omegas_[0] = w; | |
rep(i, omegas_.size()) { | |
omegas_[omegas_.size() - 1 - i] = w; | |
iomegas_[omegas_.size() - 1 - i] = modpow(w, p_ - 2, p_); // x^(p-2)%p = 1/x | |
w = (w * w) % p_; | |
} | |
/*rep(i, omegas_.size()) { | |
omegas_[i] = w; | |
w = (w * w) % p_; | |
} | |
reverse(omegas_.begin(), omegas_.end());*/ | |
} | |
// x^y % m | |
LL modpow(LL x, LL y, LL m) { | |
LL a = 1; | |
while (y) { | |
if (y & 1) a = (a * x) % m; | |
x = (x * x) % m; | |
y >>= 1; | |
} | |
return a; | |
} | |
// dst: Output buffer, tmp: Temporary buffer, log2n: log2 of buffer length, isInv: Inverse FTT if true | |
// return: Destination buffer address if log2n is odd number or otherwise for NULL | |
LL* FFT1D_Stockham(LL* dst, LL* tmp, int log2n, bool isInv = false) { | |
auto dst0 = dst, tmp0 = tmp; | |
int m = log2n; | |
int n = 1 << m; | |
//Complex u(-1, 0); // exp(-i2pi/n * [0, 1, ... m-1]) | |
auto omegas = !isInv ? &omegas_[0] : &iomegas_[0]; // positive or negative rotation | |
rep(t, m) { | |
//double phase = -2.0 * M_PI / pow(2, t+1); | |
rep(j, 1 << (m - (t + 1))) { | |
//LL w = !isInv ? omegas_[t] : modpow(omegas_[t], p_ - 2, p_); // x^(p-2)%p = 1/x | |
LL w = omegas[t]; | |
//LL wn = w; | |
LL wn = 1; | |
//Complex w(1, 0); | |
rep(k, 1 << t) { | |
//auto a = Complex(cos(k*phase), -sin(k*phase)); | |
auto x1 = dst[j * (1 << t) + k]; | |
auto x2 = dst[j * (1 << t) + k + n / 2]; | |
tmp[j * (1 << (t + 1)) + k] = (x1 + (wn * x2) % p_) % p_; | |
//tmp[j * (1 << (t + 1)) + k + (1 << t)] = (x1 - (wn * x2)%p_)%p_; | |
tmp[j * (1 << (t + 1)) + k + (1 << t)] = (x1 + p_ - (wn * x2) % p_) % p_; | |
//w *= u; | |
wn = (wn * w) % p_; | |
} | |
} | |
//u.HalfAngOfUnitary(); | |
//if (dir == TdBackward) u.y *= -1; | |
std::swap(dst, tmp); | |
} | |
LL* dstBufAddrIfOddDepth = NULL; | |
if (m & 1) {// m is odd | |
memcpy(dst0, tmp0, sizeof(dst[0]) * n); | |
dstBufAddrIfOddDepth = tmp0; | |
} | |
if (isInv) { // divide by N=2^m is required | |
//double div = pow(n, -1); | |
//REP(i, n) dst0[i] *= div; | |
auto ninv = modpow(n, p_ - 2, p_); | |
rep(i, n) dst0[i] = (dst0[i] * ninv) % p_; | |
} | |
return dstBufAddrIfOddDepth; | |
} | |
void Mul(std::vector<LL>& dst, std::vector<LL>& a, std::vector<LL>& b) { | |
rep(i, dst.size()) dst[i] = a[i] * b[i] % p_; | |
} | |
}; | |
using namespace Microsoft::VisualStudio::CppUnitTestFramework; | |
using namespace std; | |
namespace UnitTest1 { | |
TEST_CLASS(UnitTest1) { | |
public: | |
TEST_METHOD(TestMethod1) { | |
auto multiply = [](string num1, string num2) -> string { | |
if(num1=="0" || num2 == "0") return "0"; | |
int strsize = num1.size() + num2.size(); | |
int log2n=ceil(log2(strsize)); | |
int fftsize = 1<<log2n; | |
NTT ntt(NTTConstruct(23)); | |
std::vector<LL> v1(fftsize); | |
std::vector<LL> v2(fftsize); | |
std::vector<LL> v3(fftsize); | |
std::vector<LL> tmp(fftsize); | |
rep(i, num1.size()) v1[i] = num1[num1.size() - 1 - i] - '0'; | |
rep(i, num2.size()) v2[i] = num2[num2.size() - 1 - i] - '0'; // buf overflow!? | |
ntt.FFT1D_Stockham(&v1[0], &tmp[0], log2n); | |
ntt.FFT1D_Stockham(&v2[0], &tmp[0], log2n); | |
//rep(i, fftsize) v3[i]=v1[i]*v2[i]%ntt_23.p; | |
ntt.Mul(v3, v1,v2); | |
ntt.FFT1D_Stockham(&v3[0], &tmp[0], log2n, true); | |
//int digitsNum = floor(1 + log10(num1.size()) + log10(num1.size())); | |
int carry=0; | |
string digits; | |
rep(i, strsize) { | |
int s = v3[i] + carry; | |
if(i>=strsize-1&&s==0)break; // avoid top zero | |
digits.push_back('0'+s % 10); | |
carry = s/10; | |
} | |
if(carry) | |
digits.push_back('0' + carry); | |
reverse(digits.begin(), digits.end()); | |
return digits; | |
}; | |
Assert::IsTrue(multiply("9133", "0") == "0"); | |
auto s5 = multiply("6", "501"); | |
Assert::IsTrue(s5 == "3006"); | |
auto s2 = multiply("123", "456"); | |
Assert::IsTrue(s2 == "56088"); | |
auto s4 = multiply("9", "99"); | |
Assert::IsTrue(s4 == "891"); | |
auto s3 = multiply("123456789", "987654321"); | |
Assert::IsTrue(s3 == "121932631112635269"); | |
auto s1 = multiply("2", "3"); | |
Assert::IsTrue(s1 == "6"); | |
Assert::IsTrue(multiply("0", "0") == "0"); | |
} | |
}; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment