Skip to content

Instantly share code, notes, and snippets.

@tl2cents
Created July 21, 2024 04:44
Show Gist options
  • Save tl2cents/63f24e4a2c1005df40c7531c758847ef to your computer and use it in GitHub Desktop.
Save tl2cents/63f24e4a2c1005df40c7531c758847ef to your computer and use it in GitHub Desktop.
Exploits for the matrix-product (mat-prod) cryptosystem and the HITCON 2024 challenge "matprod".
from chall import AlternatingMatrixProductCryptosystem
from random import SystemRandom
from hashlib import sha256
from itertools import product
from Crypto.Cipher import AES
from sage.all import matrix, ZZ, Zmod, randint, prod, log, GF, block_matrix, vector
from tqdm import tqdm
# Break AlternatingMatrixProductCryptosystem of https://eprint.iacr.org/2023/1745.pdf
# https://github.com/Neobeo/HackTM2023/blob/main/solve420.sage
def flatter(M):
    """Lattice-reduce the rows of ``M`` with the external ``flatter`` executable.

    The matrix is serialised in flatter's text format (``[[a b ...]`` rows joined
    by ``]\\n[`` and closed with ``]]``), piped to ``flatter`` on stdin, and every
    integer found in its stdout is read back row-major into a matrix with the
    same shape as ``M``.
    """
    # compile https://github.com/keeganryan/flatter and put it in $PATH
    from re import findall
    from subprocess import check_output

    lattice = matrix(ZZ, M)
    rendered_rows = [" ".join(str(entry) for entry in row) for row in lattice]
    payload = "[[" + "]\n[".join(rendered_rows) + "]]"
    raw_output = check_output(["flatter"], input=payload.encode())
    entries = [int(token) for token in findall(rb"-?\d+", raw_output)]
    return matrix(lattice.nrows(), lattice.ncols(), entries)
def gen_alternating_challenge_local(flag=None, paras=(10, 64, 2, 2**553 + 549)):
    """Create a local instance of the alternating matrix-product challenge.

    Args:
        flag (bytes, optional): plaintext to encrypt; when omitted a random
            ``flag{<32 hex chars>}`` is generated.
        paras (tuple): ``(n, k, a, p)`` passed to
            ``AlternatingMatrixProductCryptosystem``.

    Returns:
        tuple: ``(data, msg)``. ``data`` maps ``"challenge"`` to ``(pub, M)``,
        ``"enc_flag"`` to the AES-CTR encrypted flag and ``"nonce"`` to the CTR
        nonce; ``msg`` is the secret integer whose SHA-256 of ``str(msg)`` is
        the AES key.

    Raises:
        ValueError: if decrypting with the private key does not round-trip.
    """
    rng = SystemRandom()
    system = AlternatingMatrixProductCryptosystem(*paras)
    priv, pub = system.keygen(rng)
    msg = system.randmsg(rng)
    ciphertext = system.encrypt(pub, msg)
    if system.decrypt(priv, ciphertext) != msg:
        raise ValueError("Decryption failed")
    aes_key = sha256(str(msg).encode()).digest()
    if flag is None:
        flag = b"flag{" + rng.randbytes(16).hex().encode() + b"}"
    cipher = AES.new(aes_key, AES.MODE_CTR)
    enc_flag = cipher.encrypt(flag)
    data = {"challenge": (pub, ciphertext), "enc_flag": enc_flag, "nonce": cipher.nonce}
    return data, msg
def estimate_trace_bound(n, a, k):
    """Heuristically bound the trace of a product of ``k`` small ``n x n`` matrices.

    Samples ``k`` matrices with entries drawn uniformly from ``[0, a]``,
    multiplies them, and returns twice the trace of that single sampled product.
    For ``k == 0`` the product is empty (the identity), so ``n`` is returned.
    """
    # https://eprint.iacr.org/2023/1745.pdf
    # A strict bound such as n * (a*n)**(k-1) (cf. Corollary 1) is far too loose,
    # and no tight closed form is known, so one sampled trace, doubled
    # (presumably as a safety margin), is used as the estimate.
    if k == 0:
        return n
    factors = [
        matrix(ZZ, n, n, [randint(0, a) for _ in range(n * n)])
        for _ in range(k)
    ]
    return prod(factors).trace() * 2
def gen_partial_ciphertext(pubkey, i, j, n_samples):
    """Generate up to ``n_samples`` distinct partial ciphertexts over indices ``[i, j)``.

    A partial ciphertext is the flattened product
    ``pubkey[i][b_i] * pubkey[i+1][b_{i+1}] * ... * pubkey[j-1][b_{j-1}]``
    for some bit pattern ``b``.

    Args:
        pubkey (list): [A^0, A^1], two matrix sets; ``pubkey[idx][bit]`` is used.
        i (int): the start index (inclusive)
        j (int): the end index (exclusive)
        n_samples (int): the number of partial ciphertexts to generate

    Returns:
        list: flattened partial ciphertexts (each the ``.list()`` of a product).
        When the whole space of ``2**(j - i)`` bit patterns is no larger than
        ``n_samples``, every pattern is enumerated in lexicographic order (so
        fewer than ``n_samples`` entries may be returned). Otherwise,
        ``n_samples`` distinct products of random bit patterns are returned.
    """
    width = j - i
    if 2**width <= n_samples:
        # Random sampling can never collect more distinct products than there are
        # bit patterns (and is wasteful when they are equal), so enumerate them all.
        return [
            prod([pubkey[idx][bit] for idx, bit in enumerate(bits, start=i)]).list()
            for bits in product([0, 1], repeat=width)
        ]
    Cs = []
    seen = set()  # hashable copies of the entries of Cs, for O(1) duplicate checks
    while len(Cs) < n_samples:
        randbits = [randint(0, 1) for _ in range(width)]
        C = prod([pubkey[idx][bit] for idx, bit in enumerate(randbits, start=i)]).list()
        key = tuple(C)
        if key not in seen:
            seen.add(key)
            Cs.append(C)
    return Cs
def modulo_reduction(M, p, verbose=False):
    """ Perform LLL reduction on the matrix M with modulo p.

    Builds an integer lattice from the rows of ``M`` taken modulo ``p`` and
    reduces it, preferring the external ``flatter`` tool and falling back to
    Sage's built-in ``LLL`` if calling ``flatter`` raises.

    Args:
        M (matrix): the matrix to reduce
        p (int): the modulo
        verbose (bool, optional): whether to print the debug information. Defaults to False.
    Returns:
        matrix: The reduced matrix
    """
    n, m = M.nrows(), M.ncols()
    if n < m:
        # Fewer rows than columns: take the echelon form of M over GF(p) as the
        # generators and add p * e_j for the last m - n coordinates.
        # NOTE(review): this stacked matrix is a basis of the mod-p lattice only if
        # the echelon form's pivots are exactly the first n columns (M of full row
        # rank with an invertible leading block) — confirm for other inputs.
        Me = M.change_ring(GF(p)).echelon_form()
        delta = Me.ncols() - n
        zero_mat = matrix.zero(delta, n)
        pI = matrix.identity(delta) * p
        L = block_matrix(ZZ, [[Me], [zero_mat.augment(pI)]])
        if L.rank() != L.nrows():
            # Dependent rows (e.g. zero rows of Me when M is rank deficient): take the
            # integer echelon (Hermite) form and drop the resulting zero rows.
            L = L.echelon_form(algorithm="pari0", include_zero_rows=False, include_zero_columns=False)
        L = L.change_ring(ZZ)
    else:
        # At least as many rows as columns: stack M on top of p * I_m. This is a
        # generating set rather than a basis, so the reduced matrix may contain
        # zero rows.
        pI = matrix.identity(m) * p
        L = block_matrix(ZZ, [[M], [pI]])
    if verbose:
        print(f"Starts to do LLL reduction with dimensions {L.dimensions()}")
    try:
        L = flatter(L)
    except Exception as e:
        # flatter missing from $PATH or failing: fall back to Sage's LLL.
        print(f"Failed to use flatter: {e}")
        print(f"Starts to do sage built-in LLL reduction")
        L = L.LLL()
    if verbose: print(f"Ends LLL reduction")
    return L
def recover_Eij(paras, pubkey, i, j, n_sample=None):
    """Recover the secret Ei*Ej^{-1} from `AlternatingMatrixProductCryptosystem`'s public key.

    Partial ciphertexts over the index range ``[j, i)`` are stacked as the
    columns of a matrix ``M``. The first nonzero vector of the reduced mod-``p``
    lattice spanned by the rows of ``M`` is taken as the vector of (small) traces.
    The linear functional reproducing it is then solved for over GF(p).

    Args:
        paras (tuple): the parameters of the cryptosystem i.e. (n, k, a, p)
        pubkey (list): [A^0, A^1], two matrix sets.
        i (int): end index (exclusive) of the partial products; must exceed ``j``.
        j (int): start index (inclusive) of the partial products; must be >= 0.
        n_sample (int): the number of partial ciphertexts to generate, if None, it will be set as estimated value.

    Returns:
        tuple: ``(sol, trace_bound)``.
        ``sol`` is a vector over GF(p) with ``sol * M == traces_vector``. The
        recovered short vector may be the negated trace vector, so ``sol`` is
        correct up to sign; callers should compare magnitudes.
        ``trace_bound`` is the heuristic bound the traces were checked against.

    Raises:
        AssertionError: on invalid indices, or when lattice reduction does not
            produce a nonzero vector whose entries are all below the trace bound
            in absolute value.
    """
    assert i > j >= 0, f"Invalid {i = }, {j = }"
    (n, k, a, p) = paras
    # estimate the trace bound of A_j * A_{j+1} ... A_{i-1}, whose entries are <= a
    trace_bound = int(estimate_trace_bound(n, a, i - j))
    # heuristic lattice dimension: n^2 * log(p/2) / log(p / trace_bound)
    estimated_t = int(n**2 * log((p / 2), p / trace_bound))
    # to make LLL algorithm work, we need n_samples > estimated_t
    if n_sample is None:
        n_sample = min(estimated_t + 32, estimated_t * 2)
    print(f"Number of samples used: {n_sample} (also the dimension of lattice)")
    print(f"Estimated bit-length of trace bound: {trace_bound.bit_length()}")
    partial_ciphertexts = gen_partial_ciphertext(pubkey, j, i, n_sample)
    # build lattice: one column per partial ciphertext
    M = matrix(ZZ, partial_ciphertexts).T
    L = modulo_reduction(M, p)
    nonzero_rows = [v for v in L if v != 0]
    assert nonzero_rows, "Failed to recover Eij: lattice reduction returned only zero vectors"
    traces_vector = nonzero_rows[0]
    # Reduction may return the negated trace vector, so bound the magnitude:
    # a plain `num < trace_bound` would accept arbitrarily large negative entries.
    checks = [abs(num) < trace_bound for num in traces_vector]
    traces_vector_bits = [int(num).bit_length() for num in traces_vector]
    print(f"Average bit-length of recovered traces: {sum(traces_vector_bits)//len(traces_vector_bits)}")
    assert all(checks), "Failed to recover Eij"
    sol = M.change_ring(GF(p)).solve_left(traces_vector)
    return sol, trace_bound
def balanced_mod(x, p):
    """Return the residue ``r`` of ``x`` modulo ``p`` with ``-p/2 < r <= p // 2``.

    ``x`` is first converted with ``ZZ``, so the result is a Sage Integer.
    """
    residue = ZZ(x) % p
    if residue > p // 2:
        residue -= p
    return residue
def _inverse_partial_products(pubkey, start, length, identity):
    """Tabulate ``(P ** -1, bits)`` for every ``length``-bit pattern ``bits``.

    ``P`` is ``pubkey[start][bits[0]] * ... * pubkey[start + length - 1][bits[-1]]``.
    ``identity`` stands in for the empty product when ``length == 0``.
    """
    table = []
    for bits in product([0, 1], repeat=length):
        factors = [pubkey[idx][bit] for idx, bit in enumerate(bits, start=start)]
        partial = prod(factors) if factors else identity
        table.append((partial ** -1, bits))
    return table


def break_alternating_cryptosystem(paras, pubkey, C, step_size=16):
    """ Break the alternating cryptosystem, decrypting the ciphertext C.

    Bits are recovered ``step_size`` at a time, peeling the guessed prefix off
    ``C`` after each step. Each chunk's candidates are split into two half
    tables of precomputed inverse partial products, so every candidate costs a
    single matrix multiplication.
    For every chunk but the last, ``recover_Eij`` provides a functional ``Eki`` on
    the suffix ``[i + step_size, k)``. A candidate is accepted when
    ``|balanced_mod(Eki . residual)|`` is within the trace bound. The final
    chunk is brute-forced until the residual is the identity matrix.

    Args:
        paras (tuple): the parameters of the cryptosystem i.e. (n, k, a, p)
        pubkey (list): pairs ``[A_i^0, A_i^1]`` of public matrices, one per bit
        C (matrix): the ciphertext: C = prod(pubkey[i][bit_i] for i in range(k))
        step_size (int): the step size of recovering bits (1..24), default is 16

    Returns:
        list: the recovered bits; entry ``i`` selects ``pubkey[i][bit]``.

    Raises:
        AssertionError: on an invalid step size, or when no candidate of some
            chunk passes the check (raised explicitly, so it survives ``-O``).
    """
    assert step_size >= 1, "The step size must be positive"
    assert step_size <= 24, "The step size is too large and may use a lot of memory and time"
    (n, k, a, p) = paras
    identity = C.parent().identity_matrix()
    recovered_bits = []
    for i in range(0, k, step_size):
        print(f"Recovering bits from {i} to {i + step_size}")
        if i + step_size >= k:
            # the special case of the last bits: brute-force the remainder directly
            step_size = k - i
            half = step_size // 2
            table1 = _inverse_partial_products(pubkey, i, half, identity)
            table2 = _inverse_partial_products(pubkey, i + half, step_size - half, identity)
            for C1, partial_sol1 in tqdm(table1, leave=False):
                C1C = C1 * C  # invariant across the inner loop
                for C2, partial_sol2 in table2:
                    if C2 * C1C == 1:
                        print(f"Found partial solution of {i} to {i + step_size}")
                        recovered_bits.extend(partial_sol1 + partial_sol2)
                        print(f"The recovered bits: {recovered_bits}")
                        return recovered_bits
            raise AssertionError("Failed to find the last bits")
        # The general case
        Eki, trace_bound = recover_Eij(paras, pubkey, k, i + step_size)
        half = step_size // 2
        table1 = _inverse_partial_products(pubkey, i, half, identity)
        table2 = _inverse_partial_products(pubkey, i + half, step_size - half, identity)
        found = None
        for C1, partial_sol1 in tqdm(table1, leave=False):
            C1C = C1 * C  # invariant across the inner loop
            for C2, partial_sol2 in table2:
                partial_ciphertext = C2 * C1C
                trace_mA = abs(balanced_mod(Eki * vector(partial_ciphertext.list()), p))
                if trace_mA <= trace_bound:
                    print(f"Trace_mA : {int(trace_mA).bit_length()} bits, estimated trace_bound: {trace_bound.bit_length()} bits")
                    print(f"Found partial solution of {i} to {i + step_size}")
                    found = (partial_ciphertext, partial_sol1 + partial_sol2)
                    break
            if found is not None:
                break
        if found is None:
            # Continuing with an unpeeled C would misalign every later chunk.
            raise AssertionError(f"Failed to find partial solution of {i} to {i + step_size}")
        # right solution: continue from the peeled ciphertext
        C, partial_sol = found
        recovered_bits.extend(partial_sol)
        print(f"Current recovered bits: {recovered_bits}")
        print()
    return recovered_bits
def test_break_alternating_cryptosystem():
    """End-to-end check: generate a local alternating challenge, break it, and
    compare the recovered integer with the original message."""
    # a smaller variant: (n, k, a, p) = 10, 64, 2, 2**553 + 549
    paras = (10, 128, 2, 2**553 + 549)
    data, msg = gen_alternating_challenge_local(None, paras)
    pubkey, C = data["challenge"]
    recovered = break_alternating_cryptosystem(paras, pubkey, C)
    # bit 0 comes first, so reverse the bits to read them as a binary string
    msg_ = int("".join(str(bit) for bit in reversed(recovered)), 2)
    print(f"Original message: {msg}")
    print(f"Recovered message: {msg_}")
    assert msg == msg_, "Failed to recover the message"
# Run the end-to-end alternating-cryptosystem attack when executed as a script.
if __name__ == "__main__":
    test_break_alternating_cryptosystem()
from chall import alternating, direct, DirectMatrixProductCryptosystem
from random import Random, SystemRandom
from hashlib import sha256
from Crypto.Cipher import AES
from sage.all import load, save, matrix, ZZ, Zmod
# Break DirectMatrixProductCryptosystem of https://eprint.iacr.org/2023/1745.pdf
def gen_direct_challenge_local(paras=(10, 35, 2, 2**302 + 307)):
    """Create a local instance of the direct matrix-product challenge.

    Args:
        paras (tuple): ``(n, k, a, p)`` passed to ``DirectMatrixProductCryptosystem``.

    Returns:
        tuple: ``((pub, M), msg)`` — the public key, the ciphertext of a random
        message, and that message.

    Raises:
        ValueError: if decrypting with the private key does not round-trip.
    """
    rng = SystemRandom()
    system = DirectMatrixProductCryptosystem(*paras)
    priv, pub = system.keygen(rng)
    msg = system.randmsg(rng)
    ciphertext = system.encrypt(pub, msg)
    if system.decrypt(priv, ciphertext) != msg:
        raise ValueError("Decryption failed")
    return (pub, ciphertext), msg
def dfs_search_message(C, pubkey):
    """Yield every permutation ``perm`` of ``range(len(pubkey))`` whose ordered
    product ``pubkey[perm[0]] * ... * pubkey[perm[-1]]`` equals ``C``.

    Depth-first search that peels one public matrix off the left of ``C`` at a
    time. A branch is pruned unless peeling does not increase the trace, which
    is the heuristic that a correct peel shrinks the (small) underlying product.
    At the last level, the single remaining index is accepted only if its
    matrix equals what is left.
    """
    pubkey_inv = [A ** (-1) for A in pubkey]
    k = len(pubkey)

    def dfs_search(current_c, current_path):
        if len(current_path) == k - 1:
            # Exactly one index is unused; compare against that matrix only, so a
            # repeated public matrix can never yield a non-permutation.
            (remaining,) = set(range(k)) - set(current_path)
            if current_c == pubkey[remaining]:
                yield current_path + [remaining]
            return
        for i in range(k):
            if i in current_path:
                continue
            candidate = pubkey_inv[i] * current_c
            if candidate.trace() <= current_c.trace():
                yield from dfs_search(candidate, current_path + [i])

    yield from dfs_search(C, [])
def break_direct_cryptosystem(paras, pubkey, C):
    """ Break the direct cryptosystem, decrypting the ciphertext C.

    Args:
        paras (tuple): the parameters of the cryptosystem i.e. (n, k, a, p)
        pubkey (list): list of matrix: bar A
        C (matrix): the ciphertext: C = prod(sigma, A)

    Returns:
        The decrypted message, i.e. the first permutation sigma found by
        ``dfs_search_message``, decoded to its integer rank with
        ``DirectMatrixProductCryptosystem.decode``. Returns None when the search
        yields no permutation.
    """
    (n, k, a, p) = paras
    for perm in dfs_search_message(C, pubkey):
        if perm:
            print(f"Found perm {perm}")
            direct_cipher = DirectMatrixProductCryptosystem(n, k, a, p)
            return direct_cipher.decode(perm)
    return None
def test_break_direct_cryptosystem():
    """End-to-end check: break a freshly generated direct challenge and compare
    the recovered message with the original one."""
    direct_paras = (10, 35, 2, 2**302 + 307)
    (pubkey, C), msg = gen_direct_challenge_local(direct_paras)
    recovered = break_direct_cryptosystem(direct_paras, pubkey, C)
    assert recovered == msg, f"Failed: {recovered} != {msg}"
    print("Passed test_break_direct_cryptosystem")
# Run the end-to-end direct-cryptosystem attack when executed as a script.
if __name__ == "__main__":
    test_break_direct_cryptosystem()
from sage.all import *
from random import Random, SystemRandom
from hashlib import sha256
from Crypto.Cipher import AES
# https://eprint.iacr.org/2023/1745.pdf
class BaseMatrixProductCryptosystem:
    """Shared parameters and random-matrix samplers for the matrix-product schemes.

    ``n`` is the matrix dimension, ``k`` the number of message-dependent factors,
    ``a`` the largest entry of the small matrices, and ``p`` the modulus of the
    field ``self.F = GF(p)``.
    """

    def __init__(self, n: int, k: int, a: int, p: int):
        self.n = n
        self.k = k
        self.a = a
        self.p = p
        self.F = GF(p)

    def _rand_matrix(self, draw_entry, check, ring):
        """Draw ``n x n`` integer matrices with entries from ``draw_entry()``
        until one is acceptable, and return it converted to ``ring``
        (``self.F`` when ``ring`` is None).

        With ``check=True`` a matrix is acceptable only if its determinant is
        nonzero modulo ``p``. With ``check=False`` the first draw is returned
        without computing a determinant.
        """
        if ring is None:
            ring = self.F
        while True:
            M = matrix(
                ZZ,
                self.n,
                self.n,
                [draw_entry() for _ in range(self.n * self.n)],
            )
            # Short-circuit: skip the costly determinant when no check is requested
            # (rand_pair_drawf computes its own determinant afterwards).
            if not check or M.det() % self.p != 0:
                return M.change_ring(ring)

    def rand_drawf(self, rand: Random, check=True, ring=None):
        """Sample a "small" matrix with entries uniform in ``[0, a]``."""
        return self._rand_matrix(lambda: rand.randint(0, self.a), check, ring)

    def rand_elf(self, rand: Random, check=True, ring=None):
        """Sample a matrix with entries uniform in ``[0, p)``."""
        return self._rand_matrix(lambda: rand.randrange(0, self.p), check, ring)
class DirectMatrixProductCryptosystem(BaseMatrixProductCryptosystem):
    """The "direct" scheme: a message is a permutation of the ``k`` public matrices.

    Messages are integers in ``[0, k!)``, encoded as 0-based permutations of
    ``range(k)``. The ciphertext is the product of the public matrices taken in
    permuted order.
    """

    def keygen(self, rand: Random):
        """Return ``(priv, pub)``.

        ``priv = (As, D, E, Ei)``: ``k`` small matrices ``As``, a small matrix ``D``,
        and a random invertible (mod p) matrix ``E`` with its inverse ``Ei``.
        ``pub = [E * A * D * Ei for A in As]``.
        """
        As = [self.rand_drawf(rand) for _ in range(self.k)]
        D = self.rand_drawf(rand)
        E = self.rand_elf(rand)
        Ei = E.inverse()
        priv = (As, D, E, Ei)
        pub = [E * A * D * Ei for A in As]
        return priv, pub

    def encrypt_perm(self, pub, perm):
        """Return ``pub[perm[0]] * pub[perm[1]] * ... * pub[perm[k-1]]``."""
        M = pub[perm[0]]
        for i in range(1, self.k):
            M = M * pub[perm[i]]
        return M

    def decompose(self, M, As, D, L):
        """Recursively peel ``A * D`` factors off the left of ``M``.

        ``L`` holds the indices peeled so far. Returns the completed index list,
        or None when no consistent ordering is found from this state.
        """
        if len(L) == len(As):
            return
        if len(As) - 1 == len(L):
            # Only one factor is left; it must equal the remaining matrix exactly.
            idx = next(iter(set(range(len(As))) - set(L)))
            if As[idx] == M:
                return L + [idx]
            return
        threshold = self.n * (self.n - 1)
        for i, A in enumerate(As):
            if i in L:
                continue
            try:
                # Mp = D^{-1} * A^{-1} * M
                Mp = D.solve_right(A.solve_right(M))
            except ValueError:
                continue
            # if all(int(x) <= int(y) for x, y in zip(Mp.list(), M.list())):
            # this is a bit different from the paper
            # because the provided decryption algorithm in the paper often fail to decrypt till the end
            smaller_cnt = len(
                [1 for x, y in zip(Mp.list(), M.list()) if int(x) <= int(y)]
            )
            # Accept the peel when at least n*(n-1) entries did not grow.
            if smaller_cnt >= threshold:
                ret = self.decompose(Mp, As, D, L + [i])
                if ret is not None:
                    return ret

    def decrypt_perm(self, priv, M):
        """Strip the outer secret matrices and decompose; returns a permutation or None."""
        As, D, E, Ei = priv
        R = Ei * M * E * ~D
        return self.decompose(R, As, D, [])

    def encode(self, m):
        """Map ``m`` in ``[0, k!)`` to a 0-based permutation of ``range(k)``.

        Raises:
            ValueError: if ``m`` is outside ``[0, k!)``.
        """
        P = Permutations(self.k)
        # Valid ranks are 0 .. k! - 1; m == k! has no corresponding permutation.
        if m < 0 or m >= P.cardinality():
            raise ValueError("Invalid message")
        return [x - 1 for x in P.unrank(m)]

    def decode(self, p):
        """Inverse of ``encode``: return the rank of the 0-based permutation ``p``."""
        P = Permutations(self.k)
        return P.rank([x + 1 for x in p])

    def encrypt(self, pub, m):
        return self.encrypt_perm(pub, self.encode(m))

    def decrypt(self, priv, M):
        """Return the decrypted integer message, or None if decomposition fails."""
        ret = self.decrypt_perm(priv, M)
        if ret is not None:
            return self.decode(ret)

    def randmsg(self, rand: Random):
        """Return a uniformly random message in ``[0, k!)``."""
        return rand.randrange(0, factorial(self.k))
class AlternatingMatrixProductCryptosystem(BaseMatrixProductCryptosystem):
    """The "alternating" scheme: each message bit selects one of two public matrices.

    Messages are ``k``-bit integers. Bit ``i`` (least significant first) selects
    ``pub[i][bit]``, and the ciphertext is the product of the selected matrices.
    """

    def rand_pair_drawf(self, rand: Random, lookup: dict):
        """Return two distinct small matrices with the same nonzero determinant.

        Candidates are drawn over ZZ and remembered in ``lookup`` (determinant ->
        matrix). ``lookup`` is shared across calls by ``keygen``, so a new draw may
        pair with one from an earlier call, and the matched entry is removed.
        Returns ``(new, old)`` converted to ``GF(p)``.
        """
        while True:
            A = self.rand_drawf(
                rand, check=False, ring=ZZ
            )  # computing determinant in ZZ is so much faster than in GF(p) ...
            d = A.det()
            if d == 0:
                continue
            if d in lookup and lookup[d] != A:
                AA = lookup.pop(d)
                return A.change_ring(self.F), AA.change_ring(self.F)
            lookup[d] = A

    def keygen(self, rand: Random):
        """Return ``(priv, pub)``.

        ``pub[i][b] = E_i * A_i^b * E_{i+1}^{-1}`` for ``k`` same-determinant pairs
        ``(A_i^0, A_i^1)`` and ``k + 1`` random invertible matrices ``E_i``.
        ``priv = (E_0, E_k, As)``.
        """
        lookup = {}
        As = [self.rand_pair_drawf(rand, lookup) for _ in range(self.k)]
        Es = [self.rand_elf(rand) for _ in range(self.k + 1)]
        ABars = []
        for i in range(self.k):
            cur = []
            for b in (0, 1):
                cur.append(Es[i] * As[i][b] * ~Es[i + 1])
            ABars.append(cur)
        priv = (Es[0], Es[self.k], As)
        pub = ABars
        return priv, pub

    def encrypt_bits(self, pub, bits):
        """Return ``pub[0][bits[0]] * pub[1][bits[1]] * ... * pub[k-1][bits[k-1]]``."""
        M = pub[0][bits[0]]
        for i in range(1, self.k):
            M = M * pub[i][bits[i]]
        return M

    def decompose(self, M, As):
        """Peel ``A_i^b`` off the left of ``M`` for ``i = 0 .. k-1``.

        Returns the recovered bit list, or None if at some position neither
        candidate is accepted.
        """
        threshold = self.n * (self.n - 1)
        bits = []
        for i in range(self.k):
            for b in (0, 1):
                A = As[i][b]
                try:
                    # Mp = A^{-1} * M
                    Mp = A.solve_right(M)
                except ValueError:
                    continue
                # if all(int(x) <= int(y) for x, y in zip(Mp.list(), M.list())):
                # this is a bit different from the paper
                # because the provided decryption algorithm in the paper often fail to decrypt till the end
                smaller_cnt = len(
                    [1 for x, y in zip(Mp.list(), M.list()) if int(x) <= int(y)]
                )
                # Accept the peel when at least n*(n-1) entries did not grow.
                if smaller_cnt >= threshold:
                    bits.append(b)
                    M = Mp
                    break
            else:
                return
        return bits

    def decrypt_bits(self, priv, M):
        """Strip the outer secret matrices and decompose; returns bits or None."""
        E0, Ek, As = priv
        R = ~E0 * M * Ek
        return self.decompose(R, As)

    def encode(self, m):
        """Return the ``k`` bits of ``m`` in ``[0, 2**k)``, least significant first.

        Raises:
            ValueError: if ``m`` is outside ``[0, 2**k)``.
        """
        # m == 2**k needs k + 1 bits and would otherwise silently encode as 0.
        if m < 0 or m >= 2**self.k:
            raise ValueError("Invalid message")
        return [(m >> i) & 1 for i in range(self.k)]

    def decode(self, p):
        """Inverse of ``encode``: rebuild the integer from little-endian bits ``p``."""
        return sum(x << i for i, x in enumerate(p))

    def encrypt(self, pub, m):
        return self.encrypt_bits(pub, self.encode(m))

    def decrypt(self, priv, M):
        """Return the decrypted integer message, or None if decomposition fails."""
        ret = self.decrypt_bits(priv, M)
        if ret is not None:
            return self.decode(ret)

    def randmsg(self, rand: Random):
        """Return a uniformly random ``k``-bit message."""
        return rand.getrandbits(self.k)
# Recommended sizes, 128-bit security
direct = DirectMatrixProductCryptosystem(n=10, k=35, a=2, p=2**302 + 307)
alternating = AlternatingMatrixProductCryptosystem(n=10, k=128, a=2, p=2**553 + 549)
if __name__ == "__main__":
    # Generate one challenge per scheme, hash both plaintext messages (in order)
    # into an AES key, encrypt the flag with it, and save everything.
    rng = SystemRandom()
    hasher = sha256()
    challenges = []
    for system in (direct, alternating):
        secret, public = system.keygen(rng)
        message = system.randmsg(rng)
        ciphertext = system.encrypt(public, message)
        if system.decrypt(secret, ciphertext) != message:
            raise ValueError("Decryption failed")
        hasher.update(str(message).encode())
        challenges.append((public, ciphertext))

    with open("flag.txt", "rb") as fh:
        flag = fh.read().strip()

    cipher = AES.new(hasher.digest(), AES.MODE_CTR)
    enc_flag = cipher.encrypt(flag)
    payload = {"challenges": challenges, "enc_flag": enc_flag, "nonce": cipher.nonce}
    save(payload, "output.sobj")
# Additional note: this script was run with SageMath 10.3 to generate the challenge output.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment