Skip to content

Instantly share code, notes, and snippets.

@Um6ra1
Last active October 1, 2024 08:19
Show Gist options
  • Save Um6ra1/34feac1c351755cdeba1565c5b932c81 to your computer and use it in GitHub Desktop.
Save Um6ra1/34feac1c351755cdeba1565c5b932c81 to your computer and use it in GitHub Desktop.
NTT Multiply string
// solve https://leetcode.com/problems/multiply-strings/description/
#include "pch.h"
#include "CppUnitTest.h"
#include <algorithm>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
// Loop helper macros:
//   rep(i, n)    expands to  for (LL i = (0), limi = (n); i < limi; i++)
//   rep(i, a, b) expands to  for (LL i = (a), limi = (b); i < limi; i++)
// Both the index and the bound have type LL, and the bound expression is
// evaluated once into the variable lim<i> (built with ## token pasting).
// rep() forwards its argument list, the list again, a literal 0, and the list
// once more. REP_ then picks its 4th/5th parameters as (begin, end): for one
// argument n that is (0, n); for two arguments a, b it is (a, b).
// LL must be declared before the first place these macros are expanded.
#define REP_(i, a_, b_, a, b, ...) for (LL i = (a), lim##i = (b); i < lim##i; i++)
#define rep(i, ...) REP_(i, __VA_ARGS__, __VA_ARGS__, 0, __VA_ARGS__)
// Parameters of an NTT-friendly prime p = 1 + a * 2^m with primitive root r.
// A transform over such a prime supports lengths up to 2^m.
struct NTTConstruct {
    int p = 998244353;  // prime modulus
    int a = 119;        // odd cofactor in p = 1 + a * 2^m
    int m = 23;         // exponent of 2 in p - 1: largest supported log2(length)
    int r = 3;          // primitive root modulo p

    // Default: the 998244353 = 1 + 119 * 2^23 parameter set. Every member has an
    // initializer, so a default-constructed object is always usable.
    NTTConstruct() = default;

    // Selects one of the built-in parameter sets by its exponent m_:
    //   23 -> 998244353  = 1 + 119 * 2^23, r = 3
    //   26 -> 469762049  = 1 + 7   * 2^26, r = 3
    //   27 -> 2013265921 = 1 + 15  * 2^27, r = 31
    // Throws std::invalid_argument for any other m_, rather than silently
    // producing a parameter set that was never configured.
    // Left non-explicit so existing implicit int -> NTTConstruct uses still compile.
    NTTConstruct(int m_) {
        switch (m_) {
        case 23: p = 998244353;  a = 119; m = 23; r = 3;  break;
        case 26: p = 469762049;  a = 7;   m = 26; r = 3;  break;
        case 27: p = 2013265921; a = 15;  m = 27; r = 31; break;
        default:
            throw std::invalid_argument("NTTConstruct: unsupported m (expected 23, 26 or 27)");
        }
    }
};
// Reference parameter sets, kept for documentation only. NTTConstruct declares
// constructors, so these brace initializers would not compile as written.
//NTTConstruct ntt_23 = { 998244353, 119, 23, 3 };  // 998244353 = 1 + 119*2^23, root = 3
//NTTConstruct ntt_26 = { 469762049, 7, 26, 3 };    // 469762049 = 1 + 7*2^26, root = 3
//NTTConstruct ntt_27 = { 2013265921, 15, 27, 31 }; // 2013265921 = 1 + 15*2^27, root = 31
//NTTConstruct ntt_26 = { 469762049, 7, 26, 3 };    // duplicate of the 2^26 set above
// Integer type for all modular arithmetic. Every prime above is < 2^31, so the
// product of two residues, (p-1)^2 < 2^62, fits in a signed 64-bit long long.
using LL = long long;
class NTT {
LL p_; // prime
std::vector<LL> omegas_; // { omega^0, omega^1, ... omega^N-1 } like for omega_N of FFT
std::vector<LL> iomegas_; // { omega^0, omega^-1, ... omega^-(N-1) } like for omega_N of FFT
public:
NTT(NTTConstruct nttc) : p_(nttc.p) {
omegas_.resize(nttc.m);
iomegas_.resize(nttc.m);
// r^p % p = 1
// p = 1+A*2^m
// w_N = r^A
LL w = modpow(nttc.r, nttc.a, p_);
omegas_[0] = w;
rep(i, omegas_.size()) {
omegas_[omegas_.size() - 1 - i] = w;
iomegas_[omegas_.size() - 1 - i] = modpow(w, p_ - 2, p_); // x^(p-2)%p = 1/x
w = (w * w) % p_;
}
/*rep(i, omegas_.size()) {
omegas_[i] = w;
w = (w * w) % p_;
}
reverse(omegas_.begin(), omegas_.end());*/
}
// x^y % m
LL modpow(LL x, LL y, LL m) {
LL a = 1;
while (y) {
if (y & 1) a = (a * x) % m;
x = (x * x) % m;
y >>= 1;
}
return a;
}
// dst: Output buffer, tmp: Temporary buffer, log2n: log2 of buffer length, isInv: Inverse FTT if true
// return: Destination buffer address if log2n is odd number or otherwise for NULL
LL* FFT1D_Stockham(LL* dst, LL* tmp, int log2n, bool isInv = false) {
auto dst0 = dst, tmp0 = tmp;
int m = log2n;
int n = 1 << m;
//Complex u(-1, 0); // exp(-i2pi/n * [0, 1, ... m-1])
auto omegas = !isInv ? &omegas_[0] : &iomegas_[0]; // positive or negative rotation
rep(t, m) {
//double phase = -2.0 * M_PI / pow(2, t+1);
rep(j, 1 << (m - (t + 1))) {
//LL w = !isInv ? omegas_[t] : modpow(omegas_[t], p_ - 2, p_); // x^(p-2)%p = 1/x
LL w = omegas[t];
//LL wn = w;
LL wn = 1;
//Complex w(1, 0);
rep(k, 1 << t) {
//auto a = Complex(cos(k*phase), -sin(k*phase));
auto x1 = dst[j * (1 << t) + k];
auto x2 = dst[j * (1 << t) + k + n / 2];
tmp[j * (1 << (t + 1)) + k] = (x1 + (wn * x2) % p_) % p_;
//tmp[j * (1 << (t + 1)) + k + (1 << t)] = (x1 - (wn * x2)%p_)%p_;
tmp[j * (1 << (t + 1)) + k + (1 << t)] = (x1 + p_ - (wn * x2) % p_) % p_;
//w *= u;
wn = (wn * w) % p_;
}
}
//u.HalfAngOfUnitary();
//if (dir == TdBackward) u.y *= -1;
std::swap(dst, tmp);
}
LL* dstBufAddrIfOddDepth = NULL;
if (m & 1) {// m is odd
memcpy(dst0, tmp0, sizeof(dst[0]) * n);
dstBufAddrIfOddDepth = tmp0;
}
if (isInv) { // divide by N=2^m is required
//double div = pow(n, -1);
//REP(i, n) dst0[i] *= div;
auto ninv = modpow(n, p_ - 2, p_);
rep(i, n) dst0[i] = (dst0[i] * ninv) % p_;
}
return dstBufAddrIfOddDepth;
}
void Mul(std::vector<LL>& dst, std::vector<LL>& a, std::vector<LL>& b) {
rep(i, dst.size()) dst[i] = a[i] * b[i] % p_;
}
};
using namespace Microsoft::VisualStudio::CppUnitTestFramework;
using namespace std;
namespace UnitTest1 {
TEST_CLASS(UnitTest1) {
public:
    TEST_METHOD(TestMethod1) {
        // Multiplies two non-negative decimal integers given as strings
        // (https://leetcode.com/problems/multiply-strings/). It convolves their
        // digit sequences with an NTT modulo 998244353. The result is exact while
        // every convolution coefficient (at most 81 * min(len1, len2)) stays below
        // that prime.
        auto multiply = [](string num1, string num2) -> string {
            if (num1 == "0" || num2 == "0") return "0";
            // The product has at most num1.size() + num2.size() digits.
            const LL strsize = static_cast<LL>(num1.size() + num2.size());
            // Smallest power of two >= strsize, computed with integers instead of
            // floating-point ceil(log2(...)).
            int log2n = 0;
            while ((LL(1) << log2n) < strsize) ++log2n;
            const LL fftsize = LL(1) << log2n;
            NTT ntt(NTTConstruct(23));
            std::vector<LL> v1(fftsize);
            std::vector<LL> v2(fftsize);
            std::vector<LL> v3(fftsize);
            std::vector<LL> tmp(fftsize);
            // Digits are stored least-significant first. fftsize >= strsize, so
            // both inputs and the full convolution fit in the buffers.
            rep(i, num1.size()) v1[i] = num1[num1.size() - 1 - i] - '0';
            rep(i, num2.size()) v2[i] = num2[num2.size() - 1 - i] - '0';
            ntt.FFT1D_Stockham(&v1[0], &tmp[0], log2n);
            ntt.FFT1D_Stockham(&v2[0], &tmp[0], log2n);
            ntt.Mul(v3, v1, v2);
            ntt.FFT1D_Stockham(&v3[0], &tmp[0], log2n, true);
            // v3[i] is the pre-carry coefficient of 10^i. Propagate carries in LL,
            // matching v3's element type (no narrowing to int).
            string digits;
            LL carry = 0;
            rep(i, strsize) {
                const LL s = v3[i] + carry;
                digits.push_back(static_cast<char>('0' + s % 10));
                carry = s / 10;
            }
            for (; carry > 0; carry /= 10)
                digits.push_back(static_cast<char>('0' + carry % 10));
            // Leading zeros sit at the back of the reversed string; trim them,
            // keeping at least one digit.
            while (digits.size() > 1 && digits.back() == '0') digits.pop_back();
            std::reverse(digits.begin(), digits.end());
            return digits;
        };
        Assert::IsTrue(multiply("9133", "0") == "0");
        auto s5 = multiply("6", "501");
        Assert::IsTrue(s5 == "3006");
        auto s2 = multiply("123", "456");
        Assert::IsTrue(s2 == "56088");
        auto s4 = multiply("9", "99");
        Assert::IsTrue(s4 == "891");
        auto s3 = multiply("123456789", "987654321");
        Assert::IsTrue(s3 == "121932631112635269");
        auto s1 = multiply("2", "3");
        Assert::IsTrue(s1 == "6");
        Assert::IsTrue(multiply("0", "0") == "0");
        Assert::IsTrue(multiply("999", "999") == "998001");
    }
};
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment