Skip to content

Instantly share code, notes, and snippets.

@carl-mastrangelo
Last active October 8, 2022 17:37
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save carl-mastrangelo/2926b8024d5a2cef53808ee41736777c to your computer and use it in GitHub Desktop.
NTRU Prime encoder
# From page 16 of https://ntruprime.cr.yp.to/nist/ntruprime-20201007.pdf
# Linked from https://www.imperialviolet.org/2021/08/26/qrencoding.html
def rebase(innums, indenoms, limit, newbase):
    """Encode mixed-radix values as a stream of base-``newbase`` symbols.

    Adjacent values are merged pairwise into one value whose radix is the
    product of the two radices.  While a merged radix is at least ``limit``,
    its low base-``newbase`` symbol is emitted and the radix shrinks to
    ``ceil(radix / newbase)``.  The halved list is processed the same way,
    level by level, until a single value remains.  That value is then
    emitted in full, one symbol at a time, until its radix reaches 1.
    Symbols are emitted least significant first within each merged value.

    Args:
        innums: the values to encode; ``innums[i]`` must satisfy
            ``0 <= innums[i] < indenoms[i]``.
        indenoms: the exclusive upper bound (radix) of each value.
        limit: merged pairs emit symbols until their radix drops below
            this; must be at least 2.
        newbase: the base of the output symbols; must be at least 2.

    Returns:
        A list of ints in ``range(newbase)``.  Decode it with ``unrebase``
        using the same ``indenoms``, ``limit`` and base.

    Raises:
        ValueError: if ``innums`` and ``indenoms`` differ in length, if
            ``limit`` or ``newbase`` is below 2, or if a value lies outside
            ``[0, radix)``.
    """
    if len(innums) != len(indenoms):
        raise ValueError(
            "innums and indenoms must have the same length ({} != {})".format(
                len(innums), len(indenoms)))
    # A base or limit below 2 can make the shrinking loops below never end.
    if newbase < 2:
        raise ValueError("newbase must be at least 2, got {}".format(newbase))
    if limit < 2:
        raise ValueError("limit must be at least 2, got {}".format(limit))
    # Out-of-range values would encode silently but decode to different values.
    for index, (num, denom) in enumerate(zip(innums, indenoms)):
        if not 0 <= num < denom:
            raise ValueError(
                "innums[{}] = {} is outside [0, {})".format(index, num, denom))

    nums, denoms = list(innums), list(indenoms)
    syms = []
    if len(denoms) == 0:
        # An empty list never reaches the single-value case below.
        return syms

    # Merge adjacent pairs, level by level, until one value remains.
    while len(denoms) > 1:
        nextnums, nextdenoms = [], []
        for i in range(0, len(denoms) - 1, 2):
            # Combine two values into one with radix denoms[i] * denoms[i+1].
            num = nums[i] + nums[i + 1] * denoms[i]
            denom = denoms[i] * denoms[i + 1]
            # Emit low symbols while the merged radix is still large.
            # num // newbase < ceil(denom / newbase), so num stays in range.
            while denom >= limit:
                syms.append(num % newbase)
                num //= newbase
                denom = (denom + newbase - 1) // newbase
            nextnums.append(num)
            nextdenoms.append(denom)
        if len(denoms) % 2 == 1:
            # An unpaired last value is carried to the next level unchanged.
            nextnums.append(nums[-1])
            nextdenoms.append(denoms[-1])
        nums, denoms = nextnums, nextdenoms

    # Emit the final value completely: one symbol per step until radix is 1.
    num, denom = nums[0], denoms[0]
    while denom > 1:
        syms.append(num % newbase)
        num //= newbase
        denom = (denom + newbase - 1) // newbase
    return syms
def unrebase(syms, denoms, limit, base):
    """Decode a symbol stream back into mixed-radix values.

    Works level by level in the same order the symbols were laid out:

    1. Each adjacent pair of radices is multiplied together.
    2. While that product is at least ``limit``, one symbol is read and the
       product shrinks to ``ceil(product / base)``. This recovers the pair's
       low part and its weight.
    3. The shrunken products, plus any unpaired last radix, are decoded
       recursively from the symbols that follow.
    4. Each pair's combined value ``low + weight * high`` is split back into
       its two members.

    A single remaining radix consumes every leftover symbol as a
    little-endian base-``base`` number.

    Args:
        syms: the symbol sequence, least significant first.
        denoms: the radix of each value to recover.
        limit: the pair-merging threshold the symbols were produced with.
        base: the symbol base the symbols were produced with.

    Returns:
        A list with one value per entry of ``denoms``, each reduced modulo
        its radix.

    Raises:
        IndexError: if ``syms`` runs out while a merged pair still needs
            symbols.
    """
    def decode_from(start, radices):
        # Decode ``radices`` from the symbols at positions ``start`` onward.
        if len(radices) == 0:
            return []
        if len(radices) == 1:
            # Little-endian base-``base`` number, evaluated by Horner's rule.
            total = 0
            for digit in reversed(syms[start:]):
                total = total * base + digit
            return [total % radices[0]]

        pairs = list(zip(radices[0::2], radices[1::2]))
        cursor = start
        lows, merged = [], []
        for lo_radix, hi_radix in pairs:
            span, low, weight = lo_radix * hi_radix, 0, 1
            # Replay the shrinking of ``span`` to learn how many symbols
            # belong to this pair, accumulating them as its low part.
            while span >= limit:
                low += syms[cursor] * weight
                weight *= base
                cursor += 1
                span = (span + base - 1) // base
            lows.append((low, weight))
            merged.append(span)
        has_leftover = len(radices) % 2 == 1
        if has_leftover:
            merged.append(radices[-1])

        highs = decode_from(cursor, merged)
        values = []
        for (lo_radix, hi_radix), (low, weight), high in zip(pairs, lows, highs):
            combined = low + weight * high
            values.append(combined % lo_radix)
            values.append((combined // lo_radix) % hi_radix)
        if has_leftover:
            # The unpaired radix was decoded directly at the next level.
            values.append(highs[-1])
        return values

    return decode_from(0, denoms)
# Round-trip demo: eight byte-sized values -> base-58 symbols -> values again.
sample_values = [255, 0, 255, 255, 255, 255, 255, 254]
value_radices = [256] * len(sample_values)
symbol_base = 58
encoded = rebase(sample_values, value_radices, symbol_base * symbol_base, symbol_base)
print(encoded)
decoded = unrebase(encoded, value_radices, symbol_base * symbol_base, symbol_base)
print(decoded)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment