Created December 17, 2016 23:53
-
-
Save kelbyludwig/201d08e3e8e9a4f3764f366398f12a47 to your computer and use it in GitHub Desktop.
GGH and Lattice Helpers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class GGHKeyPair:
    """
    A GGH (Goldreich-Goldwasser-Halevi) public-key encryption key pair.

    Keys are used with row-vector conventions: the rows of a key matrix are
    the lattice basis vectors, and a message vector m encrypts as
    m*public_key + r.

    Attributes:
        private_key: the "good" basis V (high hadamard ratio).
        public_key:  the "bad" basis W, obtained from V by left-multiplying
                     random unimodular matrices, so it spans the same lattice.
        n:           the lattice dimension.
    """

    def __init__(self, n=2, d=128, high_ratio_req=0.8, low_ratio_req=0.3):
        """Create a fresh random key pair of dimension n (see generate)."""
        self.generate(n, d, high_ratio_req, low_ratio_req)
        self.n = n

    def set_keypair(self, priv, pub):
        """
        Set specific matrices for the private and public GGH keys.

        The dimension n is taken from the length of the public key's first row.
        """
        self.public_key = pub
        self.private_key = priv
        self.n = len(pub[0])

    def generate(self, n, d, high_ratio_req, low_ratio_req):
        """
        Generate a random n-dimensional key pair and store it on self.

        The private basis V is an n x n integer matrix drawn with
        random_matrix(ZZ, n, x=-d, y=d), resampled until it is nonsingular and
        its hadamard ratio is at least high_ratio_req. The public basis starts
        at V and is repeatedly left-multiplied by random unimodular matrices
        until its hadamard ratio is at most low_ratio_req.

        Note: both loops run until their threshold is met, so unreachable
        thresholds never terminate (e.g. for n=1 a basis is always perfectly
        orthogonal, so the public-key loop cannot finish).
        """
        # Private key: sample until we get an invertible, nearly orthogonal
        # basis.
        while True:
            V = random_matrix(ZZ, n, x=-d, y=d)
            if V.determinant() == 0:
                # Singular matrix: its rows are not a lattice basis. Try again.
                continue
            if hadamard_ratio(V) >= high_ratio_req:
                break

        # Public key: multiply the private basis by random unimodular matrices
        # (which preserve the lattice) until the hadamard ratio reaches a
        # certain "badness" threshold, i.e. becomes low enough.
        W = V
        while hadamard_ratio(W) > low_ratio_req:
            U = random_matrix(ZZ, n, upper_bound=d, algorithm='unimodular')
            W = U*W
        self.public_key = W
        self.private_key = V
        # Keep n consistent with the keys when generate is called directly.
        self.n = n

    def encrypt(self, m, r=None, d=5):
        """
        Encrypt the message vector m, returning m*public_key + r.

        If no error vector r is supplied, one is drawn with
        random_vector(ZZ, len(m), x=-d, y=d).
        """
        if r is None:
            # Note: if the basis vectors are too small, r could give the
            # off-lattice vector (i.e. the ciphertext) several close on-lattice
            # vectors, which can make decryption recover the wrong message.
            r = random_vector(ZZ, len(m), x=-d, y=d)
        return (m*self.public_key) + r

    def decrypt(self, c):
        """
        Decrypt the ciphertext vector c.

        Babai rounding with the private basis finds a lattice point close to c
        (m*public_key when the error was small enough); multiplying it by the
        public key's inverse then recovers m.
        """
        mW = babai(c, self.private_key)
        return mW*self.public_key.inverse()
def test_ggh(dim=2, n=10):
    """
    Verify the correctness of the GGH system end to end.

    Runs n independent round trips. Each one creates a fresh dim-dimensional
    key pair, draws a message with random_vector(ZZ, keypair.n, x=0, y=3),
    encrypts it and decrypts it again. Each step is announced on stdout,
    prefixed by the trial number. Finally, the number of round trips that
    recovered the original message is printed.
    """
    successes = 0
    for trial in range(n):
        print("%d generating keypair..." % trial)
        keypair = GGHKeyPair(n=dim)

        print("%d generating message..." % trial)
        plaintext = random_vector(ZZ, keypair.n, x=0, y=3)

        print("%d encrypting message..." % trial)
        ciphertext = keypair.encrypt(plaintext)

        print("%d decrypting message..." % trial)
        recovered = keypair.decrypt(ciphertext)

        if recovered == plaintext:
            successes += 1
    print("number correct: %d/%d" % (successes, n))
def hadamard_ratio(basis):
    """
    hadamard_ratio returns the hadamard ratio of the supplied square basis
    matrix, whose rows are the basis vectors.

    For an n x n basis with rows v_1, ..., v_n the ratio is

        H = (|det(basis)| / (||v_1|| * ... * ||v_n||)) ** (1/n)

    By Hadamard's inequality it is a number between 0 and 1 that quantifies
    the orthogonality of the basis. The higher the hadamard ratio, the more
    orthogonal (read: better) the basis.

    Raises ValueError for non-square or empty (0 x 0) matrices.
    """
    if basis.nrows() != basis.ncols():
        raise ValueError("no non-square matrices allowed")
    n = basis.nrows()
    if n == 0:
        raise ValueError("an empty basis has no hadamard ratio")
    BR = basis.change_ring(RR)
    # The sign of the determinant only reflects row order; without abs() a
    # negative value raised to a fractional power becomes a complex number.
    det = abs(BR.determinant())
    norm_product = RR(1)
    for row in BR.rows():
        norm_product *= row.norm()
    if norm_product == 0:
        # A zero row makes the "basis" degenerate (det is 0 as well); report
        # the worst possible ratio instead of 0/0.
        return RR(0)
    return (det / norm_product) ** (1.0 / n)
def babai(w, basis):
    """
    babai approximates the point of the lattice spanned by the rows of basis
    that is closest to the supplied vector w, using Babai's rounding
    technique. It expresses w in basis coordinates, rounds each coordinate to
    the nearest integer, and maps the rounded coordinates back onto the
    lattice.

    babai only works well when given a "good" (nearly orthogonal) basis.
    Returns the lattice vector C*basis for the integer coefficient vector C.
    """
    # Solve t*basis = w in the basis' own ring instead of 53-bit RR. For an
    # integer basis and integer w this is an exact rational solve, so large
    # ciphertext entries cannot be misrounded by floating-point error.
    t = basis.solve_left(vector(w))
    C = vector(ZZ, [round(coord) for coord in t])
    return C*basis
def plot_2d_lattice(v1, v2, xmin=-10, xmax=10, ymin=-10, ymax=10, show_basis_vectors=True):
    """
    plot_2d_lattice will return a sage plot of a lattice with v1 and v2 as
    basis vectors.

    Every lattice point i*v1 + j*v2 (i, j integers) with xmin <= x <= xmax and
    ymin <= y <= ymax is plotted, however skewed the basis is. If
    show_basis_vectors is false, the basis vectors will not be shown on the
    plot. Only the first two coordinates of v1 and v2 are read.
    """
    pts = []
    for i, j in _lattice_window_coefficients(v1, v2, xmin, xmax, ymin, ymax):
        pt = i*v1 + j*v2
        x, y = pt[0], pt[1]
        # The candidate coefficients may overshoot the window slightly; keep
        # only points that really lie inside it (bounds are inclusive).
        if x < xmin or x > xmax or y < ymin or y > ymax:
            continue
        pts.append(pt)
    the_plot = plot(points(pts))
    if show_basis_vectors:
        the_plot += plot(v1) + plot(v2)
    return the_plot


def _lattice_window_coefficients(v1, v2, xmin, xmax, ymin, ymax):
    """
    Yield integer pairs (i, j) covering every lattice point i*v1 + j*v2 in
    the window [xmin, xmax] x [ymin, ymax]. A few extra pairs, whose points
    fall just outside the window, may also be yielded.
    """
    import math

    a, b = v1[0], v1[1]
    c, d = v2[0], v2[1]
    det = a*d - b*c
    if det == 0:
        # Linearly dependent vectors cannot be inverted into basis
        # coordinates; fall back to fixed coefficient ranges.
        for i in range(xmin, xmax):
            for j in range(ymin, ymax):
                yield i, j
        return

    # In basis coordinates i = (x*d - y*c) / det, which is linear in (x, y),
    # so its extremes over the rectangular window occur at the corners.
    corners = ((xmin, ymin), (xmin, ymax), (xmax, ymin), (xmax, ymax))
    i_coords = [float(x*d - y*c) / float(det) for x, y in corners]
    i_first = int(math.floor(min(i_coords)))
    i_last = int(math.ceil(max(i_coords)))

    for i in range(i_first, i_last + 1):
        # With i fixed, x = i*a + j*c and y = i*b + j*d each restrict j to an
        # interval; intersect the two. Since det != 0, c and d are not both
        # zero, so the intersection is bounded whenever it is feasible.
        j_lo, j_hi = float('-inf'), float('inf')
        feasible = True
        for lo, hi, offset, step in ((xmin, xmax, i*a, c), (ymin, ymax, i*b, d)):
            if step == 0:
                # This coordinate does not depend on j at all.
                if not (lo <= offset <= hi):
                    feasible = False
                continue
            t1 = float(lo - offset) / float(step)
            t2 = float(hi - offset) / float(step)
            j_lo = max(j_lo, min(t1, t2))
            j_hi = min(j_hi, max(t1, t2))
        if not feasible:
            continue
        # floor/ceil widen the interval so float rounding cannot drop a
        # boundary point; the caller filters out any overshoot.
        for j in range(int(math.floor(j_lo)), int(math.ceil(j_hi)) + 1):
            yield i, j
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment