Skip to content

Instantly share code, notes, and snippets.

@fredrik-johansson
Created April 5, 2021 12:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save fredrik-johansson/2b8d6db8a2ffaf50c9ed97854abacd62 to your computer and use it in GitHub Desktop.
Save fredrik-johansson/2b8d6db8a2ffaf50c9ed97854abacd62 to your computer and use it in GitHub Desktop.
Some new nmod_poly multiplication code
#include "flint/nmod_poly.h"
#include "flint/profiler.h"
/*
Multiplication/squaring using Kronecker substitution at 2^b and -2^b.

Computes the product of op1 (n1 coefficients) and op2 (n2 coefficients) and
writes the n3 = n1 + n2 - 1 output coefficients to res. The even-indexed
coefficients are produced first (through res), then the odd-indexed ones
(through res + 1); the argument 2 passed to _nmod_poly_KS2_reduce is
presumably the output stride -- confirm against its declaration.

When op1 == op2 and n1 == n2 the squaring path (mpn_sqr) is used.
n2 == 1 is handled separately by a scalar multiplication.

NOTE(review): mpn_mul is called with operand sizes (k1, k2) in that order,
so callers presumably have to pass n1 >= n2 >= 1 -- confirm against callers.
*/
void
_nmod_poly_mul_KS2B(mp_ptr res, mp_srcptr op1, slong n1,
mp_srcptr op2, slong n2, nmod_t mod)
{
int sqr, v3m_neg;
ulong bits, b, w;
slong n1o, n1e, n2o, n2e, n3o, n3e, n3, k1, k2, k3;
mp_ptr v1_buf0, v2_buf0, v1_buf1, v2_buf1, v1_buf2, v2_buf2;
mp_ptr v1o, v1e, v1p, v1m, v2o, v2e, v2p, v2m, v3o, v3e, v3p, v3m;
/* NOTE: tmp is declared but never used in this function */
mp_ptr z, tmp;
TMP_INIT;
if (n2 == 1)
{
/* code below needs n2 > 1, so fall back on scalar multiplication */
_nmod_vec_scalar_mul_nmod(res, op1, n1, op2[0], mod);
return;
}
TMP_START;
/* same pointer and same length: the two operands are one polynomial */
sqr = (op1 == op2 && n1 == n2);
/* bits in each output coefficient */
bits = 2 * (FLINT_BITS - mod.norm) + FLINT_CLOG2(n2);
/* we're evaluating at x = B and -B, where B = 2^b, and b = ceil(bits / 2) */
b = (bits + 1) / 2;
/* number of ulongs required to store each output coefficient */
w = (2*b - 1)/FLINT_BITS + 1;
/*
Write f1(x) = f1e(x^2) + x * f1o(x^2)
f2(x) = f2e(x^2) + x * f2o(x^2)
h(x) = he(x^2) + x * ho(x^2)
"e" = even, "o" = odd
*/
/* number of odd-indexed / even-indexed coefficients of each polynomial */
n1o = n1 / 2;
n1e = n1 - n1o;
n2o = n2 / 2;
n2e = n2 - n2o;
n3 = n1 + n2 - 1; /* length of h */
n3o = n3 / 2;
n3e = n3 - n3o;
/*
f1(B) and |f1(-B)| are at most ((n1 - 1) * b + mod->bits) bits long.
However, when evaluating f1e(B^2) and B * f1o(B^2) the bitpacking
routine needs room for the last chunk of 2b bits. Therefore we need to
allow room for (n1 + 1) * b bits. Ditto for f2.
*/
k1 = ((n1 + 1)*b - 1)/FLINT_BITS + 1;
k2 = ((n2 + 1)*b - 1)/FLINT_BITS + 1;
/* limbs in a full product of a k1-limb and a k2-limb integer */
k3 = k1 + k2;
/* allocate space: three consecutive regions of k3 = k1 + k2 limbs each */
v1_buf0 = TMP_ALLOC(sizeof(mp_limb_t) * 3 * k3); /* k1 limbs */
v2_buf0 = v1_buf0 + k1; /* k2 limbs */
v1_buf1 = v2_buf0 + k2; /* k1 limbs */
v2_buf1 = v1_buf1 + k1; /* k2 limbs */
v1_buf2 = v2_buf1 + k2; /* k1 limbs */
v2_buf2 = v1_buf2 + k1; /* k2 limbs */
/*
arrange overlapping buffers to minimise memory use
"p" = plus, "m" = minus
Several names below alias the same region (e.g. v1e, v1m, v3p and v3o all
start at v1_buf0), so the order of the operations that follow matters.
*/
v1e = v1_buf0;
v2e = v2_buf0;
v1o = v1_buf1;
v2o = v2_buf1;
v1p = v1_buf2;
v2p = v2_buf2;
v1m = v1_buf0;
v2m = v2_buf0;
v3m = v1_buf1;
v3p = v1_buf0;
v3e = v1_buf2;
v3o = v1_buf0;
/* scratch for unpacked coefficients: w limbs for each of up to n3e (>= n3o) coefficients */
z = TMP_ALLOC(sizeof(mp_limb_t) * w * n3e);
if (!sqr)
{
/* multiplication version */
/* evaluate f1e(B^2) and B * f1o(B^2) */
_nmod_poly_KS2_pack(v1e, op1, n1e, 2, 2 * b, 0, k1);
_nmod_poly_KS2_pack(v1o, op1 + 1, n1o, 2, 2 * b, b, k1);
/* evaluate f2e(B^2) and B * f2o(B^2) */
_nmod_poly_KS2_pack(v2e, op2, n2e, 2, 2 * b, 0, k2);
_nmod_poly_KS2_pack(v2o, op2 + 1, n2o, 2, 2 * b, b, k2);
/*
compute f1(B) = f1e(B^2) + B * f1o(B^2)
and f2(B) = f2e(B^2) + B * f2o(B^2)
*/
mpn_add_n(v1p, v1e, v1o, k1);
mpn_add_n(v2p, v2e, v2o, k2);
/*
compute |f1(-B)| = |f1e(B^2) - B * f1o(B^2)|
and |f2(-B)| = |f2e(B^2) - B * f2o(B^2)|
(results overwrite the even parts, which are no longer needed)
*/
v3m_neg = signed_mpn_sub_n(v1m, v1e, v1o, k1);
v3m_neg ^= signed_mpn_sub_n(v2m, v2e, v2o, k2);
/*
compute h(B) = f1(B) * f2(B)
compute |h(-B)| = |f1(-B)| * |f2(-B)|
v3m_neg is set if h(-B) is negative
(v3m must be computed first: v3p overwrites the region holding v1m, v2m)
*/
mpn_mul(v3m, v1m, k1, v2m, k2);
mpn_mul(v3p, v1p, k1, v2p, k2);
}
else
{
/* squaring version */
/* evaluate f1e(B^2) and B * f1o(B^2) */
_nmod_poly_KS2_pack(v1e, op1, n1e, 2, 2 * b, 0, k1);
_nmod_poly_KS2_pack(v1o, op1 + 1, n1o, 2, 2 * b, b, k1);
/* compute f1(B) = f1e(B^2) + B * f1o(B^2) */
mpn_add_n(v1p, v1e, v1o, k1);
/* compute |f1(-B)| = |f1e(B^2) - B * f1o(B^2)| (sign is irrelevant when squaring) */
signed_mpn_sub_n(v1m, v1e, v1o, k1);
/*
compute h(B) = f1(B)^2
compute h(-B) = f1(-B)^2
v3m_neg is cleared (since f1(-B)^2 is never negative)
*/
mpn_sqr(v3m, v1m, k1);
mpn_sqr(v3p, v1p, k1);
v3m_neg = 0;
}
/*
he(B^2) and B * ho(B^2) are both at most b * (n3 + 1) bits long (since
the coefficients don't overlap). The buffers used below are at least
b * (n1 + n2 + 2) = b * (n3 + 3) bits long. So we definitely have
enough room for 2 * he(B^2) and 2 * B * ho(B^2).
*/
/* compute 2 * he(B^2) = h(B) + h(-B) */
if (v3m_neg)
mpn_sub_n(v3e, v3p, v3m, k3);
else
mpn_add_n(v3e, v3p, v3m, k3);
/* unpack coefficients of he, and reduce mod m */
_nmod_poly_KS2_unpack(z, v3e, n3e, 2 * b, 1);
_nmod_poly_KS2_reduce(res, 2, z, n3e, w, mod);
/* compute 2 * B * ho(B^2) = h(B) - h(-B) (written in place over v3p) */
if (v3m_neg)
mpn_add_n(v3o, v3p, v3m, k3);
else
mpn_sub_n(v3o, v3p, v3m, k3);
/* unpack coefficients of ho, and reduce mod m */
_nmod_poly_KS2_unpack(z, v3o, n3o, 2 * b, b + 1);
_nmod_poly_KS2_reduce(res + 1, 2, z, n3o, w, mod);
TMP_END;
}
/*
Sets res to poly1 * poly2 using _nmod_poly_mul_KS2B.

The longer operand is always passed first to the underscore routine. If res
aliases one of the inputs, the product is built in a temporary polynomial
that is then swapped into res. A zero-length input gives the zero polynomial.
*/
void
nmod_poly_mul_KS2B(nmod_poly_t res,
                   const nmod_poly_t poly1, const nmod_poly_t poly2)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len, len_out;

    if (len1 == 0 || len2 == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    len_out = len1 + len2 - 1;

    /* order the operands so that the longer one comes first */
    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        /* no aliasing: write straight into res */
        nmod_poly_fit_length(res, len_out);
        _nmod_poly_mul_KS2B(res->coeffs, longer, longer_len,
                            shorter, shorter_len, poly1->mod);
    }
    else
    {
        /* res is also an input: compute into a temporary, then swap */
        nmod_poly_t temp;

        nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out);
        _nmod_poly_mul_KS2B(temp->coeffs, longer, longer_len,
                            shorter, shorter_len, poly1->mod);
        nmod_poly_swap(res, temp);
        nmod_poly_clear(temp);
    }

    res->length = len_out;
    _nmod_poly_normalise(res);
}
/*
Fused multiply-multiply-add: forms a * b with nmod_mul and then accumulates
c * d onto it with NMOD_ADDMUL, all with respect to mod.
*/
static mp_limb_t
nmod_fmma(mp_limb_t a, mp_limb_t b, mp_limb_t c, mp_limb_t d, nmod_t mod)
{
    mp_limb_t acc = nmod_mul(a, b, mod);

    NMOD_ADDMUL(acc, c, d, mod);

    return acc;
}
/*
Dot product of vec1 with vec2 reversed: the sum over k in [0, len) of
vec1[k] * vec2[len - 1 - k], evaluated with NMOD_VEC_DOT using nlimbs.

When nlimbs >= 2, lengths of at most 2 are special-cased: length 2 uses
nmod_fmma, length 1 a single nmod_mul, and anything smaller returns 0.
*/
mp_limb_t
_nmod_vec_dot_rev(mp_srcptr vec1, mp_srcptr vec2, slong len, nmod_t mod, int nlimbs)
{
    mp_limb_t acc;
    slong k;

    if (nlimbs >= 2 && len <= 2)
    {
        switch (len)
        {
            case 2:
                return nmod_fmma(vec1[0], vec2[1], vec1[1], vec2[0], mod);
            case 1:
                return nmod_mul(vec1[0], vec2[0], mod);
            default:
                return 0;
        }
    }

    NMOD_VEC_DOT(acc, k, len, vec1[k], vec2[len - 1 - k], mod, nlimbs);

    return acc;
}
/*
Classical squaring: writes the 2 * len - 1 coefficients of poly^2 to res.

Lengths 1 and 2 are handled directly. If the bound
2 * (FLINT_BITS - mod.norm) + FLINT_BIT_COUNT(len) fits in a single limb,
products are accumulated unreduced in res and reduced once at the end.
Otherwise each output coefficient is built from a reversed dot product
(doubled), plus the square of the middle term for even indices.
*/
void
_nmod_poly_sqr_classical(mp_ptr res, mp_srcptr poly,
                         slong len, nmod_t mod)
{
    const slong out_len = 2 * len - 1;
    slong i, j, k, bits, log_len, nlimbs;

    if (len == 1)
    {
        res[0] = nmod_mul(poly[0], poly[0], mod);
        return;
    }

    if (len == 2)
    {
        const mp_limb_t lo = poly[0];
        const mp_limb_t hi = poly[1];
        const mp_limb_t cross = nmod_mul(lo, hi, mod);

        res[0] = nmod_mul(lo, lo, mod);
        res[1] = nmod_add(cross, cross, mod);
        res[2] = nmod_mul(hi, hi, mod);
        return;
    }

    /* bound on the bit size of an unreduced output coefficient */
    log_len = FLINT_BIT_COUNT(len);
    bits = FLINT_BITS - (slong) mod.norm;
    bits = 2 * bits + log_len;

    if (bits <= FLINT_BITS)
    {
        /* single-limb accumulation, one reduction pass at the end */
        flint_mpn_zero(res, out_len);

        for (i = 0; i < len; i++)
        {
            const mp_limb_t ci = poly[i];
            const mp_limb_t twice_ci = 2 * ci;

            res[2 * i] += ci * ci;

            for (j = i + 1; j < len; j++)
                res[i + j] += poly[j] * twice_ci;
        }

        _nmod_vec_reduce(res, res, out_len, mod);
        return;
    }

    nlimbs = (bits <= 2 * FLINT_BITS) ? 2 : 3;

    for (k = 0; k < out_len; k++)
    {
        /* index range of the off-diagonal terms poly[a] * poly[k - a], a < k - a */
        const slong first = FLINT_MAX(0, k - len + 1);
        const slong last = FLINT_MIN(len - 1, (k + 1) / 2 - 1);
        mp_limb_t acc;

        acc = _nmod_vec_dot_rev(poly + first, poly + k - last, last - first + 1, mod, nlimbs);
        acc = nmod_add(acc, acc, mod);

        /* even index: add the diagonal term poly[k/2]^2 */
        if (k % 2 == 0 && k / 2 < len)
            NMOD_ADDMUL(acc, poly[k / 2], poly[k / 2], mod);

        res[k] = acc;
    }
}
/*
Classical multiplication: writes the len1 + len2 - 1 coefficients of
poly1 * poly2 to res. Squaring is detected when both operands are the same
pointer with the same length.

Strategy, in order:
  - len1 == 1: a single modular product;
  - len2 == 1: scalar multiplication of poly1;
  - coefficient bound fits one limb: accumulate unreduced, reduce once;
  - len2 == 2: scalar multiply plus scalar addmul;
  - otherwise: one reversed dot product per output coefficient.
*/
void
_nmod_poly_mul_classical2(mp_ptr res, mp_srcptr poly1,
                          slong len1, mp_srcptr poly2, slong len2, nmod_t mod)
{
    const slong len_out = len1 + len2 - 1;
    slong i, j, k, bits, log_len, nlimbs;
    int squaring;

    if (len1 == 1)
    {
        res[0] = nmod_mul(poly1[0], poly2[0], mod);
        return;
    }

    if (len2 == 1)
    {
        _nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod);
        return;
    }

    squaring = (poly1 == poly2 && len1 == len2);

    /* bound on the bit size of an unreduced output coefficient */
    log_len = FLINT_BIT_COUNT(len2);
    bits = FLINT_BITS - (slong) mod.norm;
    bits = 2 * bits + log_len;

    if (bits <= FLINT_BITS)
    {
        flint_mpn_zero(res, len_out);

        if (!squaring)
        {
            for (i = 0; i < len1; i++)
            {
                const mp_limb_t ai = poly1[i];

                for (j = 0; j < len2; j++)
                    res[i + j] += ai * poly2[j];
            }
        }
        else
        {
            for (i = 0; i < len1; i++)
            {
                const mp_limb_t ai = poly1[i];
                const mp_limb_t twice_ai = 2 * ai;

                res[2 * i] += ai * ai;

                for (j = i + 1; j < len1; j++)
                    res[i + j] += poly1[j] * twice_ai;
            }
        }

        _nmod_vec_reduce(res, res, len_out, mod);
        return;
    }

    if (len2 == 2)
    {
        _nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod);
        _nmod_vec_scalar_addmul_nmod(res + 1, poly1, len1 - 1, poly2[1], mod);
        /* top coefficient: product of the two leading terms */
        res[len_out - 1] = nmod_mul(poly1[len1 - 1], poly2[len2 - 1], mod);
        return;
    }

    nlimbs = (bits <= 2 * FLINT_BITS) ? 2 : 3;

    if (!squaring)
    {
        for (k = 0; k < len_out; k++)
        {
            const slong top1 = FLINT_MIN(len1 - 1, k);
            const slong top2 = FLINT_MIN(len2 - 1, k);

            res[k] = _nmod_vec_dot_rev(poly1 + k - top2,
                                       poly2 + k - top1,
                                       top1 + top2 - k + 1, mod, nlimbs);
        }
    }
    else
    {
        for (k = 0; k < 2 * len1 - 1; k++)
        {
            const slong first = FLINT_MAX(0, k - len1 + 1);
            const slong last = FLINT_MIN(len1 - 1, (k + 1) / 2 - 1);
            mp_limb_t acc;

            acc = _nmod_vec_dot_rev(poly1 + first, poly1 + k - last, last - first + 1, mod, nlimbs);
            acc = nmod_add(acc, acc, mod);

            /* even index: add the diagonal term */
            if (k % 2 == 0 && k / 2 < len1)
                NMOD_ADDMUL(acc, poly1[k / 2], poly1[k / 2], mod);

            res[k] = acc;
        }
    }
}
/*
Sets res to poly1 * poly2 using _nmod_poly_mul_classical2.

The longer operand is passed first. If res aliases an input, the product is
computed into a temporary that is swapped into res afterwards. A zero-length
input gives the zero polynomial.
*/
void
nmod_poly_mul_classical2(nmod_poly_t res,
                         const nmod_poly_t poly1, const nmod_poly_t poly2)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len, len_out;

    if (len1 == 0 || len2 == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    len_out = len1 + len2 - 1;

    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        nmod_poly_fit_length(res, len_out);
        _nmod_poly_mul_classical2(res->coeffs, longer, longer_len,
                                  shorter, shorter_len, poly1->mod);
    }
    else
    {
        /* aliasing: build the result elsewhere, then swap it in */
        nmod_poly_t temp;

        nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out);
        _nmod_poly_mul_classical2(temp->coeffs, longer, longer_len,
                                  shorter, shorter_len, poly1->mod);
        nmod_poly_swap(res, temp);
        nmod_poly_clear(temp);
    }

    res->length = len_out;
    _nmod_poly_normalise(res);
}
/*
Returns the bit count of the bitwise OR of the first len entries of vec.
Returns FLINT_BITS as soon as the top bit of the running OR is set, without
scanning the remaining entries.
*/
flint_bitcnt_t _nmod_vec_max_bits2(mp_srcptr vec, slong len)
{
    const mp_limb_t top_bit = UWORD(1) << (FLINT_BITS - 1);
    mp_limb_t seen = 0;
    slong k = 0;

    while (k < len)
    {
        seen |= vec[k];
        k++;

        /* nothing can exceed a full limb: stop early */
        if (seen & top_bit)
            return FLINT_BITS;
    }

    return FLINT_BIT_COUNT(seen);
}
/*
Multiplication by Kronecker substitution: packs both inputs with `bits` bits
per coefficient, multiplies the packed integers with mpn_mul (or mpn_sqr when
both operands are the same pointer and length), and unpacks len1 + len2 - 1
coefficients into out.

If bits == 0 a field width is chosen from the modulus size and len2.

NOTE(review): mpn_mul receives (limbs1, limbs2) in that order, so callers
presumably have to pass len1 >= len2 -- confirm.
*/
void
_nmod_poly_mul_KSB(mp_ptr out, mp_srcptr in1, slong len1,
                   mp_srcptr in2, slong len2, flint_bitcnt_t bits, nmod_t mod)
{
    const slong len_out = len1 + len2 - 1;
    const int squaring = (in1 == in2 && len1 == len2);
    slong limbs1, limbs2;
    mp_ptr product, packed1, packed2;
    TMP_INIT;

    if (bits == 0)
    {
        flint_bitcnt_t bits1, bits2, loglen;

#if 0
        bits1 = _nmod_vec_max_bits2(in1, len1);
        bits2 = squaring ? bits1 : _nmod_vec_max_bits2(in2, len2);
#else
        bits1 = FLINT_BITS - (slong) mod.norm;
        bits2 = bits1;
#endif

        loglen = FLINT_BIT_COUNT(len2);
        bits = bits1 + bits2 + loglen;
    }

    /* limbs needed for each packed operand */
    limbs1 = (len1 * bits - 1) / FLINT_BITS + 1;
    limbs2 = (len2 * bits - 1) / FLINT_BITS + 1;

    TMP_START;

    /* layout: [product: limbs1 + limbs2][packed1: limbs1][packed2: limbs2, omitted when squaring] */
    product = TMP_ALLOC(sizeof(mp_limb_t) * (limbs1 + limbs2 + limbs1 + (squaring ? 0 : limbs2)));
    packed1 = product + limbs1 + limbs2;

    _nmod_poly_bit_pack(packed1, in1, len1, bits);

    if (squaring)
    {
        mpn_sqr(product, packed1, limbs1);
    }
    else
    {
        packed2 = packed1 + limbs1;
        _nmod_poly_bit_pack(packed2, in2, len2, bits);
        mpn_mul(product, packed1, limbs1, packed2, limbs2);
    }

    _nmod_poly_bit_unpack(out, len_out, product, bits, mod);

    TMP_END;
}
/*
Sets res to poly1 * poly2 using _nmod_poly_mul_KSB with the given bits
argument (forwarded unchanged).

The longer operand is passed first. If res aliases an input, the product is
computed into a temporary that is swapped into res afterwards. A zero-length
input gives the zero polynomial.
*/
void
nmod_poly_mul_KSB(nmod_poly_t res,
                  const nmod_poly_t poly1, const nmod_poly_t poly2,
                  flint_bitcnt_t bits)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len, len_out;

    if (len1 == 0 || len2 == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    len_out = len1 + len2 - 1;

    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        nmod_poly_fit_length(res, len_out);
        _nmod_poly_mul_KSB(res->coeffs, longer, longer_len,
                           shorter, shorter_len, bits, poly1->mod);
    }
    else
    {
        /* aliasing: build the result elsewhere, then swap it in */
        nmod_poly_t temp;

        nmod_poly_init2_preinv(temp, poly1->mod.n, poly1->mod.ninv, len_out);
        _nmod_poly_mul_KSB(temp->coeffs, longer, longer_len,
                           shorter, shorter_len, bits, poly1->mod);
        nmod_poly_swap(res, temp);
        nmod_poly_clear(temp);
    }

    res->length = len_out;
    _nmod_poly_normalise(res);
}
/*
Timing helpers. These are variants of a repeat-until-measurable timing loop
that store the result in a variable instead of printing it.

TIMEIT_PRINT1: stores __timer->cpu * 0.001 / __reps in __var, i.e. the timer
reading scaled by 0.001 and divided by the repetition count (presumably
seconds per repetition if cpu is in milliseconds -- confirm in profiler.h).
*/
#define TIMEIT_PRINT1(__var, __timer, __reps) \
__var = __timer->cpu*0.001/__reps;
/*
TIMEIT_REPEAT1: opens the measurement loop. Starting from __reps = 1, the
user code that follows the macro is run __reps times between timeit_start and
timeit_stop. The macro deliberately leaves braces open; they are closed by
TIMEIT_END_REPEAT1.
*/
#define TIMEIT_REPEAT1(__timer, __reps) \
do \
{ \
slong __timeit_k; \
__reps = 1; \
while (1) \
{ \
timeit_start(__timer); \
for (__timeit_k = 0; __timeit_k < __reps; __timeit_k++) \
{
/*
TIMEIT_END_REPEAT1: closes the scopes opened by TIMEIT_REPEAT1. After each
timed batch the loop stops once __timer->cpu >= 10; otherwise __reps is
multiplied by 10 and the batch is run again.
*/
#define TIMEIT_END_REPEAT1(__timer, __reps) \
} \
timeit_stop(__timer); \
if (__timer->cpu >= 10) \
break; \
__reps *= 10; \
} \
} while (0);
/*
TIMEIT_START1 / TIMEIT_STOP1(__var): bracket the code to be timed. START1
opens a block declaring the local __timer and __reps; STOP1 closes the loop,
writes the per-repetition time to __var and closes the block. They must be
used as a pair in the same scope.
*/
#define TIMEIT_START1 \
do { \
timeit_t __timer; slong __reps; \
TIMEIT_REPEAT1(__timer, __reps)
#define TIMEIT_STOP1(__var) \
TIMEIT_END_REPEAT1(__timer, __reps) \
TIMEIT_PRINT1(__var, __timer, __reps) \
} while (0);
/*
Selects a Kronecker substitution variant from the modulus bit size and the
length. Returns 1 when len * bits < 800, otherwise 2 when len * bits^2 is
below 100000 (200000 when FLINT_BITS >= 62), otherwise 4.
*/
static int choose_KS2(slong bits, slong len)
{
    const slong ks2_limit = 100000 * (1 + (FLINT_BITS >= 62));

    if (len * bits >= 800)
    {
        if (len * bits * bits >= ks2_limit)
            return 4;

        return 2;
    }

    return 1;
}
/*
Selects a Kronecker substitution variant from the modulus bit size and the
length. Returns 1 when len * bits < 800, otherwise 2 when
len * bits^2 < 100000, otherwise 4.
*/
static int choose_KS(slong bits, slong len)
{
    if (len * bits >= 800)
    {
        if (len * bits * bits >= 100000)
            return 4;

        return 2;
    }

    return 1;
}
/*
Writes the len1 + len2 - 1 coefficients of poly1 * poly2 to res, choosing
between the classical, KSB, KS2B and KS4 algorithms.

Selection, with bits = FLINT_BITS - mod.norm and
cutoff_len = min(len1, 2 * len2):
  - len2 <= 5                                   -> classical
  - 3 * cutoff_len < 2 * max(bits, 10)          -> classical
  - cutoff_len * bits < 800                     -> KSB
  - cutoff_len * (bits + 1)^2 < 100000          -> KS2B
  - otherwise                                   -> KS4

Every wrapper in this file passes the longer operand first (len1 >= len2).

Tuning note: with unbalanced operands, KS tuning seems to respond better to
the length of the longer operand.
*/
void _nmod_poly_mul2(mp_ptr res, mp_srcptr poly1, slong len1,
                     mp_srcptr poly2, slong len2, nmod_t mod)
{
    slong bits, cutoff_len;

    /* very short second operand: classical always wins */
    if (len2 <= 5)
    {
        _nmod_poly_mul_classical2(res, poly1, len1, poly2, len2, mod);
        return;
    }

    bits = FLINT_BITS - (slong) mod.norm;
    cutoff_len = FLINT_MIN(len1, 2 * len2);

    if (3 * cutoff_len < 2 * FLINT_MAX(bits, 10))
        _nmod_poly_mul_classical2(res, poly1, len1, poly2, len2, mod);
    else if (cutoff_len * bits < 800)
        _nmod_poly_mul_KSB(res, poly1, len1, poly2, len2, 0, mod);
    else if (cutoff_len * (bits + 1) * (bits + 1) < 100000)
        _nmod_poly_mul_KS2B(res, poly1, len1, poly2, len2, mod);
    else
        _nmod_poly_mul_KS4(res, poly1, len1, poly2, len2, mod);
}
/*
Sets res to poly1 * poly2 using _nmod_poly_mul2.

The longer operand is passed first. If res aliases an input, the product is
computed into a temporary that is swapped into res afterwards. A zero-length
input gives the zero polynomial.
*/
void nmod_poly_mul2(nmod_poly_t res, const nmod_poly_t poly1, const nmod_poly_t poly2)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len, len_out;

    if (len1 == 0 || len2 == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    len_out = len1 + len2 - 1;

    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        nmod_poly_fit_length(res, len_out);
        _nmod_poly_mul2(res->coeffs, longer, longer_len,
                        shorter, shorter_len, poly1->mod);
    }
    else
    {
        /* aliasing: build the result elsewhere, then swap it in */
        nmod_poly_t temp;

        nmod_poly_init2(temp, poly1->mod.n, len_out);
        _nmod_poly_mul2(temp->coeffs, longer, longer_len,
                        shorter, shorter_len, poly1->mod);
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = len_out;
    _nmod_poly_normalise(res);
}
/*
Classical truncated multiplication: writes the first n coefficients of
poly1 * poly2 to res. Both lengths are first clamped to n. Squaring is
detected when both operands are the same pointer with the same (clamped)
length.

Strategy, in order:
  - n == 1: a single modular product;
  - len2 == 1: scalar multiplication of poly1;
  - coefficient bound fits one limb: accumulate unreduced, reduce once;
  - len2 == 2: scalar multiply plus scalar addmul (plus the top term when
    n equals the full product length);
  - otherwise: one reversed dot product per output coefficient.
*/
void
_nmod_poly_mullow_classical2(mp_ptr res, mp_srcptr poly1,
                             slong len1, mp_srcptr poly2, slong len2, slong n, nmod_t mod)
{
    slong i, j, k, bits, log_len, nlimbs;
    int squaring;

    /* terms of degree >= n cannot contribute */
    len1 = FLINT_MIN(len1, n);
    len2 = FLINT_MIN(len2, n);

    if (n == 1)
    {
        res[0] = nmod_mul(poly1[0], poly2[0], mod);
        return;
    }

    if (len2 == 1)
    {
        _nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod);
        return;
    }

    squaring = (poly1 == poly2 && len1 == len2);

    /* bound on the bit size of an unreduced output coefficient */
    log_len = FLINT_BIT_COUNT(len2);
    bits = FLINT_BITS - (slong) mod.norm;
    bits = 2 * bits + log_len;

    if (bits <= FLINT_BITS)
    {
        flint_mpn_zero(res, n);

        if (!squaring)
        {
            for (i = 0; i < len1; i++)
            {
                const mp_limb_t ai = poly1[i];
                const slong jend = FLINT_MIN(len2, n - i);

                for (j = 0; j < jend; j++)
                    res[i + j] += ai * poly2[j];
            }
        }
        else
        {
            for (i = 0; i < len1; i++)
            {
                const mp_limb_t ai = poly1[i];
                const mp_limb_t twice_ai = 2 * ai;
                const slong jend = FLINT_MIN(len1, n - i);

                /* diagonal term, only if it lies below the truncation */
                if (2 * i < n)
                    res[2 * i] += ai * ai;

                for (j = i + 1; j < jend; j++)
                    res[i + j] += poly1[j] * twice_ai;
            }
        }

        _nmod_vec_reduce(res, res, n, mod);
        return;
    }

    if (len2 == 2)
    {
        _nmod_vec_scalar_mul_nmod(res, poly1, len1, poly2[0], mod);
        _nmod_vec_scalar_addmul_nmod(res + 1, poly1, len1 - 1, poly2[1], mod);

        /* the top coefficient exists only when nothing is truncated */
        if (n == len1 + len2 - 1)
            res[len1 + len2 - 2] = nmod_mul(poly1[len1 - 1], poly2[len2 - 1], mod);

        return;
    }

    nlimbs = (bits <= 2 * FLINT_BITS) ? 2 : 3;

    if (!squaring)
    {
        for (k = 0; k < n; k++)
        {
            const slong top1 = FLINT_MIN(len1 - 1, k);
            const slong top2 = FLINT_MIN(len2 - 1, k);

            res[k] = _nmod_vec_dot_rev(poly1 + k - top2,
                                       poly2 + k - top1,
                                       top1 + top2 - k + 1, mod, nlimbs);
        }
    }
    else
    {
        for (k = 0; k < n; k++)
        {
            const slong first = FLINT_MAX(0, k - len1 + 1);
            const slong last = FLINT_MIN(len1 - 1, (k + 1) / 2 - 1);
            mp_limb_t acc;

            acc = _nmod_vec_dot_rev(poly1 + first, poly1 + k - last, last - first + 1, mod, nlimbs);
            acc = nmod_add(acc, acc, mod);

            /* even index: add the diagonal term */
            if (k % 2 == 0 && k / 2 < len1)
                NMOD_ADDMUL(acc, poly1[k / 2], poly1[k / 2], mod);

            res[k] = acc;
        }
    }
}
/*
Truncated multiplication by Kronecker substitution: clamps both lengths to n,
packs the inputs with `bits` bits per coefficient, multiplies the packed
integers (mpn_sqr when both operands are the same pointer and clamped
length, mpn_mul otherwise) and unpacks the first n coefficients into out.

If bits == 0 a field width is chosen from the modulus size and len2.

NOTE(review): mpn_mul receives (limbs1, limbs2) in that order, so callers
presumably have to pass len1 >= len2 -- confirm.
*/
void
_nmod_poly_mullow_KSB(mp_ptr out, mp_srcptr in1, slong len1,
                      mp_srcptr in2, slong len2, flint_bitcnt_t bits, slong n, nmod_t mod)
{
    slong limbs1, limbs2;
    mp_ptr product, packed1, packed2;
    int squaring;
    TMP_INIT;

    /* terms of degree >= n cannot contribute */
    len1 = FLINT_MIN(len1, n);
    len2 = FLINT_MIN(len2, n);

    squaring = (in1 == in2 && len1 == len2);

    if (bits == 0)
    {
        flint_bitcnt_t bits1, bits2, loglen;

#if 0
        bits1 = _nmod_vec_max_bits2(in1, len1);
        bits2 = squaring ? bits1 : _nmod_vec_max_bits2(in2, len2);
#else
        bits1 = FLINT_BITS - (slong) mod.norm;
        bits2 = bits1;
#endif

        loglen = FLINT_BIT_COUNT(len2);
        bits = bits1 + bits2 + loglen;
    }

    /* limbs needed for each packed operand */
    limbs1 = (len1 * bits - 1) / FLINT_BITS + 1;
    limbs2 = (len2 * bits - 1) / FLINT_BITS + 1;

    TMP_START;

    /* layout: [product: limbs1 + limbs2][packed1: limbs1][packed2: limbs2, omitted when squaring] */
    product = TMP_ALLOC(sizeof(mp_limb_t) * (limbs1 + limbs2 + limbs1 + (squaring ? 0 : limbs2)));
    packed1 = product + limbs1 + limbs2;

    _nmod_poly_bit_pack(packed1, in1, len1, bits);

    if (squaring)
    {
        mpn_sqr(product, packed1, limbs1);
    }
    else
    {
        packed2 = packed1 + limbs1;
        _nmod_poly_bit_pack(packed2, in2, len2, bits);
        mpn_mul(product, packed1, limbs1, packed2, limbs2);
    }

    /* only the low n coefficients are extracted */
    _nmod_poly_bit_unpack(out, n, product, bits, mod);

    TMP_END;
}
/*
Writes the first n coefficients of poly1 * poly2 to res, choosing between the
classical and the KSB truncated multiplication.

Both lengths are clamped to n. The classical routine is used when len2 <= 5
or when n < 10 + bits^2 / 10 with bits = FLINT_BITS - mod.norm; otherwise
the KSB routine is used with an automatically chosen field width.
*/
void _nmod_poly_mullow2(mp_ptr res, mp_srcptr poly1, slong len1,
                        mp_srcptr poly2, slong len2, slong n, nmod_t mod)
{
    int use_classical;

    len1 = FLINT_MIN(len1, n);
    len2 = FLINT_MIN(len2, n);

    use_classical = (len2 <= 5);

    if (!use_classical)
    {
        const slong bits = FLINT_BITS - (slong) mod.norm;

        use_classical = (n < 10 + bits * bits / 10);
    }

    if (use_classical)
        _nmod_poly_mullow_classical2(res, poly1, len1, poly2, len2, n, mod);
    else
        _nmod_poly_mullow_KSB(res, poly1, len1, poly2, len2, 0, n, mod);
}
/*
Sets res to the product poly1 * poly2 truncated to length trunc, using
_nmod_poly_mullow2.

trunc is capped at the full product length. The result is zero if either
input is empty or the (capped) trunc is 0. The longer operand is passed
first; if res aliases an input, a temporary is used and swapped into res.
*/
void nmod_poly_mullow2(nmod_poly_t res,
                       const nmod_poly_t poly1, const nmod_poly_t poly2, slong trunc)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    const slong len_out = len1 + len2 - 1;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len;

    if (trunc > len_out)
        trunc = len_out;

    if (len1 == 0 || len2 == 0 || trunc == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        nmod_poly_fit_length(res, trunc);
        _nmod_poly_mullow2(res->coeffs, longer, longer_len,
                           shorter, shorter_len, trunc, poly1->mod);
    }
    else
    {
        /* aliasing: build the result elsewhere, then swap it in */
        nmod_poly_t temp;

        nmod_poly_init2(temp, poly1->mod.n, trunc);
        _nmod_poly_mullow2(temp->coeffs, longer, longer_len,
                           shorter, shorter_len, trunc, poly1->mod);
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = trunc;
    _nmod_poly_normalise(res);
}
/*
Sets res to the product poly1 * poly2 truncated to length trunc, using
_nmod_poly_mullow_classical2.

trunc is capped at the full product length. The result is zero if either
input is empty or the (capped) trunc is 0. The longer operand is passed
first; if res aliases an input, a temporary is used and swapped into res.
*/
void nmod_poly_mullow_classical2(nmod_poly_t res,
                                 const nmod_poly_t poly1, const nmod_poly_t poly2, slong trunc)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    const slong len_out = len1 + len2 - 1;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len;

    if (trunc > len_out)
        trunc = len_out;

    if (len1 == 0 || len2 == 0 || trunc == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        nmod_poly_fit_length(res, trunc);
        _nmod_poly_mullow_classical2(res->coeffs, longer, longer_len,
                                     shorter, shorter_len, trunc, poly1->mod);
    }
    else
    {
        /* aliasing: build the result elsewhere, then swap it in */
        nmod_poly_t temp;

        nmod_poly_init2(temp, poly1->mod.n, trunc);
        _nmod_poly_mullow_classical2(temp->coeffs, longer, longer_len,
                                     shorter, shorter_len, trunc, poly1->mod);
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = trunc;
    _nmod_poly_normalise(res);
}
/*
Sets res to the product poly1 * poly2 truncated to length trunc, using
_nmod_poly_mullow_KSB with bits = 0 (automatic field width).

trunc is capped at the full product length. The result is zero if either
input is empty or the (capped) trunc is 0. The longer operand is passed
first; if res aliases an input, a temporary is used and swapped into res.
*/
void nmod_poly_mullow_KSB(nmod_poly_t res,
                          const nmod_poly_t poly1, const nmod_poly_t poly2, slong trunc)
{
    const slong len1 = poly1->length;
    const slong len2 = poly2->length;
    const slong len_out = len1 + len2 - 1;
    mp_srcptr longer, shorter;
    slong longer_len, shorter_len;

    if (trunc > len_out)
        trunc = len_out;

    if (len1 == 0 || len2 == 0 || trunc == 0)
    {
        nmod_poly_zero(res);
        return;
    }

    if (len1 >= len2)
    {
        longer = poly1->coeffs;
        longer_len = len1;
        shorter = poly2->coeffs;
        shorter_len = len2;
    }
    else
    {
        longer = poly2->coeffs;
        longer_len = len2;
        shorter = poly1->coeffs;
        shorter_len = len1;
    }

    if (res != poly1 && res != poly2)
    {
        nmod_poly_fit_length(res, trunc);
        _nmod_poly_mullow_KSB(res->coeffs, longer, longer_len,
                              shorter, shorter_len, 0, trunc, poly1->mod);
    }
    else
    {
        /* aliasing: build the result elsewhere, then swap it in */
        nmod_poly_t temp;

        nmod_poly_init2(temp, poly1->mod.n, trunc);
        _nmod_poly_mullow_KSB(temp->coeffs, longer, longer_len,
                              shorter, shorter_len, 0, trunc, poly1->mod);
        nmod_poly_swap(temp, res);
        nmod_poly_clear(temp);
    }

    res->length = trunc;
    _nmod_poly_normalise(res);
}
/*
Bit sizes to benchmark, in order. The list is terminated by a 0 entry: the
loop in main stops at the first 0. For each entry main uses the modulus
2^bits - 1, or UWORD_MAX when the entry is 64.
*/
int checkbits[] = { 2, 4, 8, 16, 28, 32, 60, 64, 0 };
/*
Alternative lists kept for experimentation (swap one in for the line above):
int checkbits[] = { 64, 60, 32, 28, 16, 8, 4, 2, 0 };
int checkbits[] = { 2, 4, 8, 16, 32, 64, 0 };
int checkbits[] = { 64, 8, 2, 0 };
*/
/*
Fills f with n coefficients, each set to n_randlimb(state) reduced modulo
f's modulus. If the resulting length is below n (the top coefficients came
out zero), coefficient n - 1 is set to 1 so that the length is exactly n.
*/
void
randpoly(nmod_poly_t f, flint_rand_t state, slong n)
{
    slong idx = 0;

    nmod_poly_zero(f);

    while (idx < n)
    {
        nmod_poly_set_coeff_ui(f, idx, n_randlimb(state) % f->mod.n);
        idx++;
    }

    /* force the requested length */
    if (n > f->length)
        nmod_poly_set_coeff_ui(f, n - 1, 1);
}
/*
TIMET(res, expr): times expr three times with TIMEIT_START1 / TIMEIT_STOP1
and stores the smallest of the three measurements in res.

The double variables tx, ty and tz must be declared in the calling scope;
they receive the individual measurements. Use as a statement, followed by a
semicolon.

The last line of the macro must NOT end in a backslash: a trailing
continuation would splice whatever follows the macro into its body.
*/
#define TIMET(res, expr) \
do { \
TIMEIT_START1 expr; TIMEIT_STOP1(tx) \
TIMEIT_START1 expr; TIMEIT_STOP1(ty) \
TIMEIT_START1 expr; TIMEIT_STOP1(tz) \
res = FLINT_MIN(tx, FLINT_MIN(ty, tz)); \
} while (0)
int main()
{
nmod_t mod;
nmod_poly_t f, g, h;
flint_rand_t state;
flint_randinit(state);
slong i, j, n, ii, bits;
double t1, t2, tt;
slong iter;
slong iters = 1000;
slong iters2 = 20;
for (ii = 0; (bits = checkbits[ii]) != 0; ii++)
{
for (n = 1; n <= 30000; n = FLINT_MAX(n+1, n*1.1))
{
double tx, ty, tz, told, tnew, told2, tnew2, told10, tnew10, tolds, tnews;
if (bits == 64)
nmod_init(&mod, UWORD_MAX);
else
nmod_init(&mod, (UWORD(1) << bits) - UWORD(1));
nmod_poly_init(f, mod.n);
nmod_poly_init(g, mod.n);
nmod_poly_init(h, mod.n);
printf("%ld %ld ", bits, n); fflush(stdout);
randpoly(f, state, n);
randpoly(g, state, n);
TIMET(told, nmod_poly_mullow(h, f, g, n));
TIMET(tnew, nmod_poly_mullow2(h, f, g, n));
printf("%.3f ", told / tnew);
randpoly(f, state, n);
randpoly(g, state, FLINT_MAX(n / 2, 1));
TIMET(told, nmod_poly_mullow(h, f, g, n));
TIMET(tnew, nmod_poly_mullow2(h, f, g, n));
printf("%.3f ", told / tnew);
randpoly(f, state, n);
TIMET(told, nmod_poly_mullow(h, f, f, n));
TIMET(tnew, nmod_poly_mullow2(h, f, f, n));
printf("%.3f ", told / tnew);
randpoly(f, state, n);
randpoly(g, state, n);
TIMET(told, nmod_poly_mul(h, f, g));
TIMET(tnew, nmod_poly_mul2(h, f, g));
printf("%.3f ", told / tnew);
randpoly(f, state, n);
randpoly(g, state, 2 * n);
TIMET(told, nmod_poly_mul(h, f, g));
TIMET(tnew, nmod_poly_mul2(h, f, g));
printf("%.3f ", told / tnew);
randpoly(f, state, 10 * n);
randpoly(g, state, n);
TIMET(told, nmod_poly_mul(h, f, g));
TIMET(tnew, nmod_poly_mul2(h, f, g));
printf("%.3f ", told / tnew);
randpoly(f, state, n);
TIMET(told, nmod_poly_mul(h, f, f));
TIMET(tnew, nmod_poly_mul2(h, f, f));
printf("%.3f ", told / tnew);
printf("\n");
nmod_poly_clear(f);
nmod_poly_clear(g);
nmod_poly_clear(h);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment