Skip to content

Instantly share code, notes, and snippets.

@lmyyao
Created November 27, 2023 11:11
Show Gist options
  • Save lmyyao/c33e8a9537ef700f06f52794ba74bbd0 to your computer and use it in GitHub Desktop.
Save lmyyao/c33e8a9537ef700f06f52794ba74bbd0 to your computer and use it in GitHub Desktop.
Kyber KEM demo (toy SageMath implementation)
import numpy as np
# Toy Kyber-style parameters. This is a SageMath script: ``R.<x> = ...`` and
# ``^`` (exponentiation) are Sage preparser syntax, not plain Python.

# Prime modulus: the smallest prime greater than 2^256 (odd).
P = next_prime(2^256)
# Ring dimension: elements live in GF(P)[x] / (x^N + 1).
N = 64
# Module rank: secrets, errors and the public matrix are built from k ring elements.
k = 16
# Default sampling range for gen_small_poly_vector's "small" coefficients.
BOUND = 32
# Encoding scale for a 1 bit: round(P/2), computed exactly in integer arithmetic.
# The previous ``int(P/2 + 0.5)`` went through floating point (Sage's 0.5 is a
# 53-bit real literal), which silently drops the low bits of this ~256-bit value.
# Since P is odd, (P + 1) // 2 is exactly P/2 rounded half-up.
HALF_P = int((P + 1) // 2)
R.<x> = PolynomialRing(GF(P), 'x')
T.<z> = R.quotient(x^N+1)
def gen_small_poly_vector(size, bound=BOUND, check=True):
    """Sample ``size`` ring elements of T with small random coefficients.

    Each element is built from N coefficients drawn by
    ``np.random.randint(-bound, bound, N)``, i.e. uniformly from the half-open
    range [-bound, bound) (NumPy's upper bound is exclusive).

    When ``check`` is true, the whole batch is redrawn until all ``size``
    elements are pairwise distinct.

    Returns the bare ring element when ``size == 1``; otherwise a Sage
    ``vector`` of the sampled elements.
    """
    while True:
        polys = [
            T(list(np.random.randint(-bound, bound, N)))
            for _ in range(size)
        ]
        # Accept the batch unless distinctness is requested and violated.
        if not check or len(set(polys)) == len(polys):
            break
    return polys[0] if size == 1 else vector(polys)
def keygen():
    """Generate a key pair.

    Draws a uniformly random k x k matrix over T, then a small secret vector
    and a small error vector (in that order).

    Returns ``(secret, (matrix, matrix * secret + error))``: the secret key and
    the public key pair.
    """
    public_matrix = random_matrix(T, k, k)
    secret = gen_small_poly_vector(k)
    noise = gen_small_poly_vector(k)
    public_vector = public_matrix * secret + noise
    return secret, (public_matrix, public_vector)
def encrypt(pk, m):
    """Encrypt the coefficient list ``m`` under the public key ``pk``.

    ``pk`` is the ``(A, t)`` pair produced by keygen. A small blinding vector
    r, a small error vector e1 and a small error polynomial e2 are sampled in
    that order. Each coefficient of ``m`` is scaled by HALF_P.

    Returns ``(u, v)`` with ``u = A^T r + e1`` and
    ``v = t . r + e2 + HALF_P * m``.
    """
    mat, tvec = pk
    blind = gen_small_poly_vector(k)
    err_u = gen_small_poly_vector(k)
    err_v = gen_small_poly_vector(1)
    encoded = T(m) * HALF_P
    u = mat.T * blind + err_u
    v = tvec * blind + err_v + encoded
    return u, v
def decrypt(sk, m):
    """Undo the public-key mask of a ciphertext and return the noisy plaintext.

    ``m`` is the ciphertext pair ``(u, v)`` (the name is kept for callers).
    The result is ``v - sk * u``. It is still noisy and is not rounded to bits
    here; check_same does that rounding.
    """
    u, v = m
    return v - sk * u
def round(val, center, bound):
    """Decode one noisy coefficient to a bit.

    Returns 1 when ``val`` is strictly closer to ``center`` (the encoding of a
    1 bit) than to 0 / ``bound`` (the encoding of a 0 bit, viewed cyclically).
    Ties decode to 0.

    Assumes ``val`` is a canonical residue in ``[0, bound)``, e.g. the ``int()``
    lift of a GF(bound) coefficient as done in check_same.

    NOTE(review): this name shadows the builtin ``round``; it is kept because
    check_same calls it by this name.
    """
    # Builtin abs is exact on arbitrary-precision ints (~2^256 here) and does
    # not route the value through NumPy's object-dtype conversion.
    dist_center = abs(center - val)
    # Cyclic distance to 0 (equivalently to ``bound``).
    dist_bound = min(val, bound - val)
    return 1 if dist_center < dist_bound else 0
def check_same(decry_m, m):
    """Round each coefficient of ``decry_m`` to a bit and compare with ``m``.

    Every coefficient (from ``decry_m.list()``) is lifted with ``int()`` and
    decoded by ``round(value, HALF_P, P)``. Returns whether the resulting bit
    list equals ``m``.
    """
    decoded_bits = [round(int(coeff), HALF_P, P) for coeff in decry_m.list()]
    return decoded_bits == m
if __name__ == "__main__":
    # Round-trip 100 random messages through keygen / encrypt / decrypt and
    # count how many decode back to the original bits.
    count = 0
    for i in range(100):
        print(i, end=",")
        # Random N-coefficient 0/1 message as plain Python ints, so that T(m)
        # coerces ordinary integers rather than numpy.int64 scalars.
        m = [int(bit) for bit in np.random.randint(0, 2, N)]
        sk, pk = keygen()
        me = encrypt(pk, m)
        dm = decrypt(sk, me)
        if check_same(dm, m):
            count += 1
            # ``true`` only exists as a Sage global alias; ``True`` is portable.
            print(True)
    print("count: ", count)
@lmyyao
Copy link
Author

lmyyao commented Nov 27, 2023

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment