# GGH and Lattice Helpers
class GGHKeyPair:
    """
    A GGH (Goldreich-Goldwasser-Halevi) lattice key pair.

    The private key V is a random integer basis whose Hadamard ratio is at
    least high_ratio_req (a "good", nearly orthogonal basis). The public key
    W is obtained by left-multiplying V by random unimodular matrices until
    its Hadamard ratio drops to at most low_ratio_req (a "bad" basis for the
    same lattice).
    """

    def __init__(self, n=2, d=128, high_ratio_req=0.8, low_ratio_req=0.3):
        """
        Generate a fresh key pair.

        n: lattice dimension (the matrices are n x n).
        d: bound used for the entries of the random matrices drawn during
           generation.
        high_ratio_req: minimum Hadamard ratio required of the private basis.
        low_ratio_req: the public basis is re-randomized until its Hadamard
           ratio is at most this value.
        """
        self.generate(n, d, high_ratio_req, low_ratio_req)
        self.n = n

    # set_keypair can be used to set specific matrices for the public
    # and private GGH keys.
    def set_keypair(self, priv, pub):
        """
        Install explicit private/public bases and set n to the number of
        entries in the first row of the public key.
        """
        self.public_key = pub
        self.private_key = priv
        self.n = len(pub[0])

    def generate(self, n, d, high_ratio_req, low_ratio_req):
        """
        Generate and store a new private/public basis pair (see __init__ for
        the meaning of the parameters).
        """
        # Draw random integer matrices until one is invertible (nonzero
        # determinant) and has a high enough Hadamard ratio.
        while True:
            V = random_matrix(ZZ, n, x=-d, y=d)
            if V.determinant() == 0:
                # non-invertible matrix. try again.
                continue
            if hadamard_ratio(V) >= high_ratio_req:
                break
            # hadamard ratio was too low. try again.

        # Multiply the private basis by randomly generated unimodular matrices
        # until the Hadamard ratio reaches a certain "badness" threshold (i.e.
        # a low enough Hadamard ratio). Unimodular factors keep the lattice
        # unchanged while making the basis worse.
        W = V
        while hadamard_ratio(W) > low_ratio_req:
            U = random_matrix(ZZ, n, upper_bound=d, algorithm='unimodular')
            W = U*W
        self.public_key = W
        self.private_key = V

    def encrypt(self, m, r=None, d=5):
        """
        Encrypt message vector m as c = m*W + r, where W is the public key.

        r: perturbation vector; when omitted, a random integer vector with
           entries drawn by random_vector(ZZ, len(m), x=-d, y=d) is used.
        d: bound for the random perturbation (ignored when r is given).
        """
        if r is None:
            # Note: If the basis vectors are too small, r could cause the off-lattice vector (i.e. the ciphertext)
            # to have several close on-lattice vectors (which will cause decryption false positives)
            r = random_vector(ZZ, len(m), x=-d, y=d)
        return (m*self.public_key) + r

    def decrypt(self, c):
        """
        Recover the message from ciphertext c: round c to a nearby lattice
        point with Babai's method using the private basis, then solve for the
        message coefficients with respect to the public basis.
        """
        mW = babai(c, self.private_key)
        return mW*self.public_key.inverse()
# test_ggh runs end-to-end round trips to check the GGH implementation.
def test_ggh(dim=2, n=10):
    """
    Perform n encrypt/decrypt round trips, each with a freshly generated key
    pair of dimension dim and a random message vector, printing progress for
    every step and finally the number of round trips whose decryption
    matched the original message.
    """
    successes = 0
    for trial in range(n):
        print("%d generating keypair..." % trial)
        keypair = GGHKeyPair(n=dim)
        print("%d generating message..." % trial)
        message = random_vector(ZZ, keypair.n, x=0, y=3)
        print("%d encrypting message..." % trial)
        ciphertext = keypair.encrypt(message)
        print("%d decrypting message..." % trial)
        recovered = keypair.decrypt(ciphertext)
        successes += 1 if recovered == message else 0
    print("number correct: %d/%d" % (successes, n))
def hadamard_ratio(basis):
    """
    hadamard_ratio returns the Hadamard ratio of the supplied square basis
    matrix B with rows b_1, ..., b_n:

        H(B) = (|det(B)| / (||b_1|| * ... * ||b_n||)) ** (1/n)

    The Hadamard ratio is a number between 0 and 1 that quantifies the
    orthogonality of the basis. The higher the Hadamard ratio, the more
    orthogonal (read: better) the basis.

    Raises Exception if basis is not square.
    """
    from functools import reduce  # not a builtin on Python 3

    n = len(basis.rows())
    if n != len(basis.columns()):
        raise Exception("no non-square matrices allowed")
    BR = basis.change_ring(RR)
    # abs(): the determinant's sign depends only on row order; a negative
    # base raised to a fractional power would not be a real number.
    d = abs(BR.determinant())
    # Iterating a matrix yields its rows.
    ls = [v.norm() for v in BR]
    l = reduce(lambda x, y: 1.0*x*y, ls)
    # The n-th root normalizes for dimension (previously hard-coded as 1/3).
    return (d/l)**(1.0/n)
def babai(w, basis):
    """
    Approximate the point of the lattice spanned by the rows of basis that is
    closest to w, using Babai's rounding technique. w is expressed in basis
    coordinates over RR, each coordinate is rounded to the nearest integer,
    and the rounded coordinates are mapped back through basis. The result is
    only reliable when basis is "good" (nearly orthogonal).
    """
    real_basis = basis.change_ring(RR)
    target = Matrix(RR, w)
    # solve_left gives a one-row matrix X with X * real_basis == target.
    coords = real_basis.solve_left(target)[0]
    nearest_coeffs = vector(ZZ, list(map(round, coords)))
    return nearest_coeffs*basis
def plot_2d_lattice(v1, v2, xmin=-10, xmax=10, ymin=-10, ymax=10, show_basis_vectors=True):
    """
    plot_2d_lattice will return a sage plot of a lattice with v1 and v2 as basis vectors.

    Every lattice point i*v1 + j*v2 (i, j integers) whose coordinates satisfy
    xmin <= x <= xmax and ymin <= y <= ymax is drawn. If show_basis_vectors is
    true, plot(v1) and plot(v2) are overlaid on the points.
    """
    pts = _lattice_points_in_box(v1, v2, xmin, xmax, ymin, ymax)
    the_plot = plot(points(pts))
    if show_basis_vectors:
        the_plot += plot(v1) + plot(v2)
    return the_plot


def _lattice_points_in_box(v1, v2, xmin, xmax, ymin, ymax):
    """Return (x, y) for every i*v1 + j*v2, i and j integers, inside the closed box."""
    import math

    det = v1[0]*v2[1] - v1[1]*v2[0]
    if det == 0:
        # Degenerate (linearly dependent) basis: coefficients are not unique,
        # so scan the coefficient ranges given by the box bounds themselves.
        i_range = range(xmin, xmax + 1)
        j_range = range(ymin, ymax + 1)
    else:
        # Cramer's rule gives the coefficients (i, j) with i*v1 + j*v2 == (x, y).
        # The map is linear, so its extremes over the box occur at the corners.
        inv_det = 1.0 / det
        corners = [(x, y) for x in (xmin, xmax) for y in (ymin, ymax)]
        i_vals = [(x*v2[1] - y*v2[0]) * inv_det for x, y in corners]
        j_vals = [(v1[0]*y - v1[1]*x) * inv_det for x, y in corners]
        # Pad by one on each side so floating-point rounding can never drop a
        # boundary point; the bounds check below discards any extras.
        i_range = range(int(math.floor(min(i_vals))) - 1, int(math.ceil(max(i_vals))) + 2)
        j_range = range(int(math.floor(min(j_vals))) - 1, int(math.ceil(max(j_vals))) + 2)

    pts = []
    for i in i_range:
        for j in j_range:
            x = i*v1[0] + j*v2[0]
            y = i*v1[1] + j*v2[1]
            if xmin <= x <= xmax and ymin <= y <= ymax:
                pts.append((x, y))
    return pts
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment