Skip to content

Instantly share code, notes, and snippets.

@malb
Last active October 14, 2022 12:12
Show Gist options
  • Save malb/0bb42a3ab1983b6decff098dc61c9101 to your computer and use it in GitHub Desktop.
Save malb/0bb42a3ab1983b6decff098dc61c9101 to your computer and use it in GitHub Desktop.
NTTs for Power-of-two Rings

NTTs for Power-of-two Rings

from sage.all import ZZ, ceil, is_prime, parent, GF, vector, matrix

def omegaf(n, ell=None):
    """
    Return an `n`-th root of `1` in the prime field `GF(qf(n, ell))`.

    :param n: order of the requested root; also forwarded to ``qf``
    :param ell: minimal bitsize of the modulus; forwarded to ``qf``

    NOTE(review): the transforms using this value presumably need a
    *primitive* `n`-th root of unity; whether ``nth_root`` returns one
    for the element `1` should be confirmed against the Sage documentation.
    """
    field = GF(qf(n, ell))
    one = field(1)
    return one.nth_root(n)

def qf(n, ell=None):
    """
    Return a prime `q` s.t. `q ≡ 1 (mod n)` and `log_2(q) ≥ ell`.

    The result is the first prime of the form `i*n + 1`, searching upwards
    from `i = ceil(2^ell / n)` (or from `i = 1` when ``ell`` is not given).

    :param n: a power of two
    :param ell: minimal bitsize of `q`
    :raises ValueError: if ``n`` is not positive (the search could never
        terminate in that case)

    EXAMPLES::

        sage: qf(8)
        17
        sage: qf(8, ell=7)
        137
    """
    if n < 1:
        raise ValueError("n must be a positive integer")

    if ell:
        # Exact ceiling division: ``2**ell/n`` would be float division for
        # plain Python ints and lose precision (or overflow) for large ell.
        i = ZZ(-(-2**ell // n))
    else:
        i = 1

    while not is_prime(i*n + 1):
        i += 1

    # q ≡ 1 (mod n) holds by construction, so no runtime check is needed.
    return i*n + 1

NTT

def ntt_naive(v, w):
    """
    Number Theoretic Transform (naive implementation)

    Interprets ``v`` as the coefficient list of a polynomial over its base
    ring and evaluates that polynomial at `w^0, w^1, ..., w^(n-1)` where
    `n = len(v)`.

    :param v: a vector (in time domain)
    :param w: an n-th root
    :returns: the vector of the `n` evaluations, over the base ring of ``v``

    EXAMPLE::

        sage: from util import qf, omegaf
        sage: from ntt import ntt_naive, ntt_dnc
        sage: v = random_vector(GF(qf(8)), 8)
        sage: ntt_naive(v, omegaf(8)) == ntt_dnc(v, omegaf(8))
        True

    """
    field = v.base_ring()
    size = len(v)
    poly = PolynomialRing(field, "x")(v.list())

    evaluations = []
    for exponent in range(size):
        evaluations.append(poly(w**exponent))

    return vector(field, size, evaluations)


def intt_naive(v, w):
    """
    Inverse Number Theoretic Transform (naive implementation)

    Interprets ``v`` as the coefficient list of a polynomial over its base
    ring, evaluates it at `w^0, w^-1, ..., w^-(n-1)` and divides every
    evaluation by `n = len(v)`.

    :param v: a vector (in frequency domain)
    :param w: an n-th root
    :returns: the vector of the `n` scaled evaluations, over the base ring
        of ``v``

    EXAMPLE::

        sage: from util import qf, omegaf
        sage: from ntt import intt_naive, intt_dnc
        sage: v = random_vector(GF(qf(8)), 8)
        sage: intt_naive(v, omegaf(8)) == intt_dnc(v, omegaf(8))
        True

    """
    field = v.base_ring()
    size = len(v)
    poly = PolynomialRing(field, "x")(v.list())
    scale = field(size)

    # The division stays inside the loop so that an empty input never
    # triggers a division at all.
    samples = []
    for exponent in range(size):
        samples.append(poly(w**(-exponent)) / scale)

    return vector(field, size, samples)


def ntt_dnc(v, w):
    """
    Number Theoretic Transform (recursive divide-and-conquer implementation)

    Splits ``v`` into its even- and odd-indexed entries, transforms both
    halves recursively with `w^2` and recombines them with the butterfly
    `y[k] = y_e[k] + w^k y_o[k]`, `y[k + n/2] = y_e[k] - w^k y_o[k]`.
    The doctest of ``ntt_naive`` compares the two implementations.

    :param v: a vector (in time domain) whose length is a power of two
    :param w: an n-th root
    :returns: the transformed vector; for length 0 or 1 this is ``v`` itself
    :raises ValueError: if an odd length greater than one is encountered,
        i.e. ``len(v)`` is not a power of two

    """
    n = len(v)
    K = v.base_ring()

    if n <= 1:
        return v
    if n % 2:
        raise ValueError("length of v must be a power of two")

    # Floor division keeps the index an integer under Python 3, where
    # ``n/2`` would be a float that ``range`` rejects.
    half = n // 2

    y = vector(K, n)
    y_e = ntt_dnc(v[0:n:2], w**2)
    y_o = ntt_dnc(v[1:n:2], w**2)
    for k in range(half):
        t = w**k * y_o[k]
        y[k] = y_e[k] + t
        y[k + half] = y_e[k] - t

    return y


def intt_dnc(v, w):
    """
    Inverse Number Theoretic Transform (recursive divide-and-conquer
    implementation).

    Splits ``v`` into its even- and odd-indexed entries, inverts both halves
    recursively with `w^2`, recombines them with the butterfly
    `y[k] = y_e[k] + w^-k y_o[k]`, `y[k + n/2] = y_e[k] - w^-k y_o[k]` and
    halves the result. One halving per recursion level gives the overall
    `1/n` scaling. The doctest of ``intt_naive`` compares the two
    implementations.

    :param v: a vector (in frequency domain) whose length is a power of two
    :param w: an n-th root
    :returns: the inverse-transformed vector; for length 0 or 1 this is
        ``v`` itself
    :raises ValueError: if an odd length greater than one is encountered,
        i.e. ``len(v)`` is not a power of two

    """
    n = len(v)
    K = v.base_ring()

    if n <= 1:
        return v
    if n % 2:
        raise ValueError("length of v must be a power of two")

    # Floor division keeps the index an integer under Python 3, where
    # ``n/2`` would be a float that ``range`` rejects.
    half = n // 2

    y = vector(K, n)
    y_e = intt_dnc(v[0::2], w**2)
    y_o = intt_dnc(v[1::2], w**2)
    for k in range(half):
        t = w**-k * y_o[k]
        y[k] = y_e[k] + t
        y[k + half] = y_e[k] - t

    return y/K(2)
v = random_vector(GF(qf(8)), 8)
ntt_naive(v, omegaf(8)) == ntt_dnc(v, omegaf(8))
True

Negacyclic NTT

def ncntt_naive(v, w):
    """
    Negacyclic Number Theoretic Transform (naive implementation)

    Scales the `i`-th entry of ``v`` by `w^i`, reads the scaled entries as
    polynomial coefficients and evaluates that polynomial at
    `w^0, w^2, ..., w^(2(n-1))` where `n = len(v)`.

    :param v: a vector (in time domain)
    :param w: a root of unity; the examples pass a root of order `2*len(v)`
        (``omegaf(16)`` for a vector of length 8)
    :returns: the vector of the `n` evaluations, over the base ring of ``v``

    EXAMPLE::

        sage: from util import qf, omegaf
        sage: from ntt import ncntt_naive, ncntt_dnc
        sage: v = random_vector(GF(qf(16)), 8)
        sage: ncntt_naive(v, omegaf(16)) == ncntt_dnc(v, omegaf(16))
        True

    """
    field = v.base_ring()
    size = len(v)

    # Pre-multiplying coefficient i by w^i turns the cyclic evaluation
    # below into the negacyclic one.
    twisted_coeffs = []
    for idx, coeff in enumerate(v):
        twisted_coeffs.append(coeff*w**idx)
    twisted = PolynomialRing(field, "x")(twisted_coeffs)

    evaluations = []
    for idx in range(size):
        evaluations.append(twisted(w**(2*idx)))

    return vector(field, size, evaluations)


def incntt_naive(v, w):
    """
    Negacyclic Inverse Number Theoretic Transform (naive implementation)

    Reads ``v`` as polynomial coefficients, evaluates the polynomial at
    `w^0, w^-2, ..., w^(-2(n-1))`, divides each evaluation by `n = len(v)`
    and finally scales the `i`-th result by `w^-i`.

    :param v: a vector (in frequency domain)
    :param w: a root of unity; the examples pass a root of order `2*len(v)`
        (``omegaf(16)`` for a vector of length 8)
    :returns: the resulting vector over the base ring of ``v``

    EXAMPLE::

        sage: from util import qf, omegaf
        sage: from ntt import incntt_naive, incntt_dnc
        sage: v = random_vector(GF(qf(16)), 8)
        sage: incntt_naive(v, omegaf(16)) == incntt_dnc(v, omegaf(16))
        True

    """
    field = v.base_ring()
    size = len(v)
    poly = PolynomialRing(field, "x")(v.list())

    untwisted = []
    for idx in range(size):
        scaled = poly(w**(-2*idx))/field(size)
        # Undo the w^i pre-scaling of the forward negacyclic transform.
        untwisted.append(scaled*w**-idx)

    return vector(field, size, untwisted)


def ncntt_seq(v, omega):
    """
    Negacyclic NTT (sequential implementation)

    Computes entry `i` of the result directly as
    `sum_j omega^(2*i*j + j) * v[j]`.

    :param v: the vector that we want to perform an ntt on
    :param omega: 2nth-root of unity
    :returns: the transformed vector over the base ring of ``v``

    EXAMPLE::

        sage: from util import qf, omegaf
        sage: from ntt import ncntt_seq, ncntt_dnc
        sage: v = random_vector(GF(qf(16)), 8)
        sage: ncntt_seq(v, omegaf(16)) == ncntt_dnc(v, omegaf(16))
        True

    """
    size = len(v)
    field = v.base_ring()

    entries = []
    for row in range(size):
        terms = (omega**(2*row*col + col) * v[col] for col in range(size))
        # Start from the field's zero so the sum is accumulated in the field.
        entries.append(sum(terms, field(0)))

    return vector(field, size, entries)


def incntt_seq(w, omega):
    """
    Negacyclic inverse NTT (sequential implementation)

    Computes entry `i` of the result directly as
    `n^-1 * sum_j omega^(-2*i*j - i) * w[j]` with `n = len(w)`.

    :param w: vector that we want to inverse ntt
    :param omega: 2nth-root of unity
    :returns: the inverse-transformed vector over the base ring of ``w``

    EXAMPLE::

        sage: from util import qf, omegaf
        sage: from ntt import incntt_seq, incntt_dnc
        sage: v = random_vector(GF(qf(16)), 8)
        sage: incntt_seq(v, omegaf(16)) == incntt_dnc(v, omegaf(16))
        True

    """
    size = len(w)
    field = w.base_ring()

    entries = []
    for row in range(size):
        terms = (omega**(-2*row*col - row) * w[col] for col in range(size))
        total = sum(terms, field(0))
        # The inverse of n is taken per entry, so an empty input never
        # attempts to invert zero.
        entries.append(~field(size) * total)

    return vector(field, size, entries)


def ncntt_dnc(v, omega):
    """
    Negacyclic NTT (recursive implementation)

    Splits ``v`` into its even- and odd-indexed entries, transforms both
    halves recursively with `omega^2` and recombines them with
    `w[i] = w_e[i] + omega^(2i+1) w_o[i]`,
    `w[i + n/2] = w_e[i] - omega^(2i+1) w_o[i]`.

    :param v: the vector that we want to perform an ntt on; its length must
        be a power of two
    :param omega: 2nth-root of unity
    :returns: the transformed vector; for length 0 or 1 this is ``v`` itself
    :raises ValueError: if an odd length greater than one is encountered,
        i.e. ``len(v)`` is not a power of two
    """
    n = len(v)
    K = v.base_ring()

    if n <= 1:
        return v
    if n % 2:
        raise ValueError("length of v must be a power of two")

    # Floor division keeps the index an integer under Python 3, where
    # ``n/2`` would be a float that ``range`` rejects.
    half = n // 2

    w = vector(K, n)
    w_e = ncntt_dnc(vector(v[0:n:2]), omega**2)
    w_o = ncntt_dnc(vector(v[1:n:2]), omega**2)
    for i in range(half):
        t = omega**(2*i + 1) * w_o[i]
        w[i] = w_e[i] + t
        w[i + half] = w_e[i] - t

    return w


def incntt_dnc(w, omega):
    """
    Inverse Negacyclic NTT (recursive implementation)

    Splits ``w`` into its even- and odd-indexed entries, inverts both halves
    recursively with `omega^2` and recombines them as

    - `v[i]       = ( omega^i v_e[i] + omega^-i v_o[i]) / 2`
    - `v[i + n/2] = (-omega^(i + n/2) v_e[i] - omega^(-i - n/2) v_o[i]) / 2`

    for `0 <= i < n/2`. One halving per recursion level gives the overall
    `1/n` scaling.

    :param w: the vector that we want to perform an inverse ntt on; its
        length must be a power of two
    :param omega: 2nth-root of unity
    :returns: the inverse-transformed vector; for length 0 or 1 this is
        ``w`` itself
    :raises ValueError: if an odd length greater than one is encountered,
        i.e. ``len(w)`` is not a power of two
    """
    n = len(w)
    K = w.base_ring()

    if n <= 1:
        return w
    if n % 2:
        raise ValueError("length of w must be a power of two")

    # Floor division keeps indices and exponents integral under Python 3,
    # where ``n/2`` would be a float that ``range`` rejects.
    half = n // 2

    v = vector(K, n)
    v_e = incntt_dnc(vector(w[0:n:2]), omega**2)
    v_o = incntt_dnc(vector(w[1:n:2]), omega**2)

    for i in range(half):
        v[i] = (omega**i * v_e[i] + omega**(-i) * v_o[i]) / K(2)
        v[i + half] = (-omega**(i + half) * v_e[i]
                       - omega**(-i - half) * v_o[i]) / K(2)

    return v
v = random_vector(GF(qf(16)), 8)
incntt_naive(v, omegaf(16)) == incntt_dnc(v, omegaf(16))
True
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment