Created July 21, 2024 04:44
-
-
Save tl2cents/63f24e4a2c1005df40c7531c758847ef to your computer and use it in GitHub Desktop.
Exploits for the matrix-product cryptosystems (direct and alternating) and the HITCON 2024 challenge MatProd.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from chall import AlternatingMatrixProductCryptosystem | |
from random import SystemRandom | |
from hashlib import sha256 | |
from itertools import product | |
from Crypto.Cipher import AES | |
from sage.all import matrix, ZZ, Zmod, randint, prod, log, GF, block_matrix, vector | |
from tqdm import tqdm | |
# Break the AlternatingMatrixProductCryptosystem of https://eprint.iacr.org/2023/1745.pdf
# Reference: https://github.com/Neobeo/HackTM2023/blob/main/solve420.sage
def flatter(M):
    """Lattice-reduce the rows of an integer matrix with the external `flatter` tool.

    The matrix is serialised in flatter's ``[[a b ...]\\n[...]]`` text format and
    piped to the ``flatter`` executable. Build it from
    https://github.com/keeganryan/flatter and put it in $PATH. Every integer
    in the tool's output is then read back, row-major, into a matrix of the
    same shape.

    Args:
        M: anything ``matrix(ZZ, M)`` accepts.

    Returns:
        matrix: the reduced basis, with the same dimensions as ``M``.

    Errors from ``subprocess.check_output`` propagate unchanged. One example is
    the error raised when the executable is missing.
    """
    from re import findall
    from subprocess import check_output

    basis = matrix(ZZ, M)
    body = "]\n[".join(" ".join(str(entry) for entry in row) for row in basis)
    payload = f"[[{body}]]".encode()
    raw_output = check_output(["flatter"], input=payload)
    numbers = [int(token) for token in findall(b'-?\\d+', raw_output)]
    return matrix(basis.nrows(), basis.ncols(), numbers)
def gen_alternating_challenge_local(flag=None, paras=(10, 64, 2, 2**553 + 549)):
    """Build a local alternating-cryptosystem challenge like the one in chall.py.

    Args:
        flag (bytes | None): the plaintext flag. When None, a random
            ``flag{<32 hex chars>}`` is generated.
        paras (tuple): (n, k, a, p) for AlternatingMatrixProductCryptosystem.

    Returns:
        tuple: ``(data, msg)``.
            - ``data`` is ``{"challenge": (pub, M), "enc_flag": ..., "nonce": ...}``.
            - The flag is AES-CTR encrypted under ``sha256(str(msg))``.

    Raises:
        ValueError: if decrypting with the private key does not return ``msg``.
    """
    rng = SystemRandom()
    system = AlternatingMatrixProductCryptosystem(*paras)
    private_key, public_key = system.keygen(rng)
    message = system.randmsg(rng)
    ciphertext = system.encrypt(public_key, message)
    # Refuse instances the key owner cannot decrypt (same check as chall.py).
    if system.decrypt(private_key, ciphertext) != message:
        raise ValueError("Decryption failed")
    aes_key = sha256(str(message).encode()).digest()
    if flag is None:
        flag = b"flag{" + rng.randbytes(16).hex().encode() + b"}"
    aes = AES.new(aes_key, AES.MODE_CTR)
    encrypted = aes.encrypt(flag)
    payload = {"challenge": (public_key, ciphertext), "enc_flag": encrypted, "nonce": aes.nonce}
    return payload, message
def estimate_trace_bound(n, a, k):
    """Estimate the trace of a product of ``k`` random small n x n matrices.

    Corollary 1 of https://eprint.iacr.org/2023/1745.pdf gives a strict upper
    bound, but it is far too loose. This function is a heuristic instead. It
    multiplies ``k`` random matrices with entries drawn by ``randint(0, a)``
    and returns twice the trace of one such sample product. The result is
    random, not a guaranteed bound.

    Args:
        n (int): matrix dimension.
        a (int): entry bound passed to ``randint(0, a)``.
        k (int): number of factors.

    Returns:
        The estimate. For ``k == 0`` this is ``n``, the trace of the identity.
    """
    if not k:
        return n
    factors = [
        matrix(ZZ, n, n, [randint(0, a) for _ in range(n * n)])
        for _ in range(k)
    ]
    # Other estimates that were considered: the mean entry, sum(P.list()) // n**2,
    # or the loose bound n * (a*n)**(k-1).
    return prod(factors).trace() * 2
def gen_partial_ciphertext(pubkey, i, j, n_samples):
    """Generate distinct partial ciphertexts over positions ``[i, j)``.

    A partial ciphertext is the product ``pubkey[i][b_i] * ... * pubkey[j-1][b_{j-1}]``
    for some choice of bits. It is returned flattened row-major with ``.list()``.

    Args:
        pubkey (list): [A^0, A^1] pairs, i.e. ``pubkey[idx][bit]`` is a matrix.
        i (int): the start index (inclusive).
        j (int): the end index (exclusive). Expected to be greater than ``i``.
        n_samples (int): the number of partial ciphertexts wanted.

    Returns:
        list: the flattened partial ciphertexts.
            - If ``2**(j - i) <= n_samples``, every bit assignment is enumerated in
              lexicographic order, so at most ``2**(j - i)`` items come back.
            - Otherwise, exactly ``n_samples`` random, pairwise distinct ones are
              returned.
    """
    width = j - i

    def _partial_product(bits):
        # Ordered product pubkey[i][bits[0]] * pubkey[i + 1][bits[1]] * ...
        return prod([pubkey[idx][bit] for idx, bit in enumerate(bits, start=i)]).list()

    # If the bit space holds no more than n_samples strings, enumerate it all.
    # Random sampling could never collect n_samples distinct products there and
    # would loop forever. The old guard 2**(width + 1) < n_samples missed the
    # range 2**width < n_samples <= 2**(width + 1).
    if 2 ** width <= n_samples:
        return [_partial_product(bits) for bits in product([0, 1], repeat=width)]

    samples, seen = [], set()
    while len(samples) < n_samples:
        flat = _partial_product([randint(0, 1) for _ in range(width)])
        key = tuple(flat)  # hashable form, giving O(1) duplicate detection
        if key not in seen:
            seen.add(key)
            samples.append(flat)
    return samples
def modulo_reduction(M, p, verbose=False):
    """Lattice-reduce the row space of ``M`` taken modulo ``p``.

    The lattice is generated by the rows of ``M`` together with ``p * Z^m``. It is
    reduced with flatter when that works, and with Sage's built-in LLL otherwise.

    How the basis is built depends on the shape of ``M``:

    - Wide ``M`` (``nrows < ncols``): the GF(p) echelon form of ``M`` gives the
      top rows, and ``[0 | p*I]`` covers the trailing ``ncols - nrows``
      coordinates. If this stack is rank-deficient, it is re-echelonised with
      pari and zero rows are dropped.
    - Otherwise: the basis is simply ``M`` stacked on ``p * I_ncols``.

    Args:
        M (matrix): the matrix to reduce.
        p (int): the modulus.
        verbose (bool, optional): print progress messages. Defaults to False.

    Returns:
        matrix: the reduced basis.
    """
    rows, cols = M.nrows(), M.ncols()
    if rows >= cols:
        basis = block_matrix(ZZ, [[M], [matrix.identity(cols) * p]])
    else:
        echelon = M.change_ring(GF(p)).echelon_form()
        extra = echelon.ncols() - rows
        modulus_rows = matrix.zero(extra, rows).augment(matrix.identity(extra) * p)
        basis = block_matrix(ZZ, [[echelon], [modulus_rows]])
        if basis.rank() != basis.nrows():
            basis = basis.echelon_form(algorithm="pari0", include_zero_rows=False, include_zero_columns=False)
            basis = basis.change_ring(ZZ)
    if verbose:
        print(f"Starts to do LLL reduction with dimensions {basis.dimensions()}")
    try:
        basis = flatter(basis)
    except Exception as e:
        # Any failure of the external tool is non-fatal; fall back to Sage's LLL.
        print(f"Failed to use flatter: {e}")
        print(f"Starts to do sage built-in LLL reduction")
        basis = basis.LLL()
    if verbose:
        print(f"Ends LLL reduction")
    return basis
def recover_Eij(paras, pubkey, i, j, n_sample=None):
    """Recover the secret Ei*Ej^{-1} relation of an `AlternatingMatrixProductCryptosystem` key.

    Method:

    1. Partial ciphertexts over positions ``[j, i)`` are generated from the
       public key.
    2. The traces of their underlying small-matrix products are expected to
       form a short vector in the lattice spanned by the flattened ciphertexts
       modulo p. Lattice reduction recovers that vector.
    3. Solving for the combination that yields it gives ``sol``, a GF(p) vector.
       For a partial ciphertext P over ``[j, i)``, ``sol * vector(P.list())`` is
       then (mod p) that small trace.

    Args:
        paras (tuple): the parameters of the cryptosystem i.e. (n, k, a, p)
        pubkey (list): [A^0, A^1], two matrix sets.
        i (int): end position (exclusive). Must satisfy ``i > j``.
        j (int): start position (inclusive). Must satisfy ``j >= 0``.
        n_sample (int): the number of partial ciphertexts to generate. If None,
            it is set from the estimated lattice dimension.

    Returns:
        tuple: ``(sol, trace_bound)``, the GF(p) vector and the heuristic trace
        bound it was checked against.

    Raises:
        AssertionError: on invalid ``i``/``j``, on an all-zero reduced basis, or
        when the shortest reduced vector is not within ``trace_bound`` (when
        asserts are enabled).
    """
    assert i > j >= 0, f"Invalid {i = }, {j = }"
    (n, k, a, p) = paras
    # estimate the trace bound of A_j * A_{j+1} ... A_{i-1} where entries <= a
    trace_bound = int(estimate_trace_bound(n, a, i - j))
    estimated_t = int(n**2 * log((p/2), p/trace_bound))
    # to make LLL algorithm work, we need n_samples > estimated_t
    if n_sample is None:
        n_sample = min(estimated_t + 32, estimated_t * 2)
    print(f"Number of samples used: {n_sample} (also the dimension of lattice)")
    print(f"Estimated bit-length of trace bound: {trace_bound.bit_length()}")
    partial_ciphertexts = gen_partial_ciphertext(pubkey, j, i, n_sample)
    # Columns of M are the flattened partial ciphertexts.
    M = matrix(ZZ, partial_ciphertexts).T
    reduced = [v for v in modulo_reduction(M, p) if v != 0]
    assert reduced, "Failed to recover Eij"
    traces_vector = reduced[0]
    # Reduction only determines the vector up to sign. The traces of products of
    # non-negative matrices are non-negative, so flip a wholly non-positive vector.
    if all(x <= 0 for x in traces_vector):
        traces_vector = -traces_vector
    traces_vector_bits = [int(num).bit_length() for num in traces_vector]
    print(f"Average bit-length of recovered traces: {sum(traces_vector_bits)//len(traces_vector_bits)}")
    # Compare magnitudes. The old `num < trace_bound` accepted huge negative entries.
    assert all(abs(num) < trace_bound for num in traces_vector), "Failed to recover Eij"
    sol = M.change_ring(GF(p)).solve_left(traces_vector)
    return sol, trace_bound
def balanced_mod(x, p):
    """Reduce ``x`` modulo ``p`` into the balanced range ``(-p/2, p // 2]``."""
    residue = ZZ(x) % p
    if residue > p // 2:
        residue -= p
    return residue
def _inverse_half_products(pubkey, start, width):
    """Build the meet-in-the-middle tables for bit positions [start, start + width).

    Returns ``(first, second)``, two lists of ``(inverse_partial_product, bits)`` pairs:

    - ``first`` covers every assignment of positions ``[start, start + width//2)``.
    - ``second`` covers every assignment of positions ``[start + width//2, start + width)``.
    """
    half = width // 2

    def table(offset, length):
        return [
            (prod([pubkey[idx][bit] for idx, bit in enumerate(bits, start=offset)]) ** -1, bits)
            for bits in product([0, 1], repeat=length)
        ]

    return table(start, half), table(start + half, width - half)


def break_alternating_cryptosystem(paras, pubkey, C, step_size=16):
    """Break the alternating cryptosystem, decrypting the ciphertext C.

    Bits are recovered ``step_size`` at a time, from position 0 upwards.

    - General chunk ``[i, i + step_size)``: ``recover_Eij`` supplies a GF(p)
      functional for products over ``[i + step_size, k)``. For the correct
      guess, the remaining ciphertext maps to a small trace (up to sign).
      Candidates come from a meet-in-the-middle enumeration of the chunk's bits.
    - Last chunk: brute-forced until the remaining product is the identity.

    Args:
        paras (tuple): the parameters of the cryptosystem i.e. (n, k, a, p)
        pubkey (list): ``pubkey[idx][bit]`` is the public matrix for ``bit`` at position ``idx``.
        C (matrix): the ciphertext ``prod(pubkey[idx][m_idx] for idx in range(k))``.
        step_size (int): the number of bits recovered per step (at most 24). Default 16.

    Returns:
        list: the decrypted message bits, position 0 first. Bit idx carries
        weight 2**idx.

    Raises:
        AssertionError: if ``step_size`` is too large, or if no candidate is found
            for some chunk.
    """
    assert step_size <= 24, "The step size is too large and may use a lot of memory and time"
    _, k, _, p = paras
    recovered_bits = []
    for i in range(0, k, step_size):
        print(f"Recovering bits from {i} to {i + step_size}")
        # the special case of the last bits
        if i + step_size >= k:
            # using direct brute-force to find the last bits
            step_size = k - i
            table1, table2 = _inverse_half_products(pubkey, i, step_size)
            for C1, partial_sol1 in tqdm(table1, leave=False):
                stripped = C1 * C  # shared by every second-half candidate
                for C2, partial_sol2 in table2:
                    if C2 * stripped == 1:
                        print(f"Found partial solution of {i} to {i + step_size}")
                        recovered_bits.extend(partial_sol1 + partial_sol2)
                        print(f"The recovered bits: {recovered_bits}")
                        return recovered_bits
            # explicit raise: unlike `assert False`, it survives `python -O`
            raise AssertionError("Failed to find the last bits")
        # The general case: recover the trace functional for positions [i + step_size, k)
        Eki, trace_bound = recover_Eij(paras, pubkey, k, i + step_size)
        table1, table2 = _inverse_half_products(pubkey, i, step_size)
        find_sol = False
        for C1, partial_sol1 in tqdm(table1, leave=False):
            stripped = C1 * C
            for C2, partial_sol2 in table2:
                partial_ciphertext = C2 * stripped
                trace_mA = abs(balanced_mod(Eki * vector(partial_ciphertext.list()), p))
                if trace_mA <= trace_bound:
                    print(f"Trace_mA : {int(trace_mA).bit_length()} bits, estimated trace_bound: {trace_bound.bit_length()} bits")
                    print(f"Found partial solution of {i} to {i + step_size}")
                    # right solution: continue with the remaining product
                    C = partial_ciphertext
                    recovered_bits.extend(partial_sol1 + partial_sol2)
                    print(f"Current recovered bits: {recovered_bits}")
                    find_sol = True
                    break
            if find_sol:
                break
        if not find_sol:
            # Previously the chunk was silently skipped, returning wrong/short bits.
            raise AssertionError(f"Failed to find bits {i} to {i + step_size}")
        print()
    return recovered_bits
def test_break_alternating_cryptosystem():
    """End-to-end check: generate a fresh alternating instance and recover its message."""
    # A smaller, quicker instance: (n, k, a, p) = 10, 64, 2, 2**553 + 549
    paras = (10, 128, 2, 2**553 + 549)
    data, msg = gen_alternating_challenge_local(None, paras)
    pubkey, C = data["challenge"]
    m_bits = break_alternating_cryptosystem(paras, pubkey, C)
    # Bit idx has weight 2**idx (the little-endian encode in chall.py), so the
    # most significant bit goes first when building the binary string.
    msg_ = int("".join(str(bit) for bit in reversed(m_bits)), 2)
    print(f"Original message: {msg}")
    print(f"Recovered message: {msg_}")
    assert msg == msg_, "Failed to recover the message"


if __name__ == "__main__":
    test_break_alternating_cryptosystem()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from chall import alternating, direct, DirectMatrixProductCryptosystem | |
from random import Random, SystemRandom | |
from hashlib import sha256 | |
from Crypto.Cipher import AES | |
from sage.all import load, save, matrix, ZZ, Zmod | |
# Break the DirectMatrixProductCryptosystem of https://eprint.iacr.org/2023/1745.pdf
def gen_direct_challenge_local(paras=(10, 35, 2, 2**302 + 307)):
    """Generate a local DirectMatrixProductCryptosystem challenge.

    Args:
        paras (tuple): (n, k, a, p). The default is the recommended size used in chall.py.

    Returns:
        tuple: ``((pub, M), msg)``, i.e. the public key, the ciphertext of ``msg``,
        and ``msg`` itself.

    Raises:
        ValueError: if decrypting with the private key does not return ``msg``.
    """
    system = DirectMatrixProductCryptosystem(*paras)
    rng = SystemRandom()
    private_key, public_key = system.keygen(rng)
    message = system.randmsg(rng)
    ciphertext = system.encrypt(public_key, message)
    if system.decrypt(private_key, ciphertext) != message:
        raise ValueError("Decryption failed")
    return (public_key, ciphertext), message
def dfs_search_message(C, pubkey):
    """Yield permutations ``perm`` with ``C == pubkey[perm[0]] * ... * pubkey[perm[-1]]``.

    Factors are peeled off the left of ``C``, depth first.

    - Index ``i`` is explored only if multiplying by ``pubkey[i]^{-1}`` does not
      increase the trace. This relies on Sage's ordering of GF(p) elements,
      presumably by integer representative.
    - Once every index but one has been used, the remainder must equal the
      public matrix of that single unused index.

    Args:
        C (matrix): the ciphertext.
        pubkey (list): the public matrices.

    Yields:
        list: 0-based index permutations consistent with ``C``.
    """
    pubkey_inv = [A ** (-1) for A in pubkey]
    k = len(pubkey)

    def dfs_search(current_c, current_path, used):
        if len(current_path) == k - 1:
            # Leaf: exactly one index is unused, and the remainder must be that factor.
            # Checking only this index avoids reporting one that is already in the path.
            (last,) = set(range(k)) - used
            if current_c == pubkey[last]:
                yield current_path + [last]
            return  # nothing left to peel; do not recurse further
        current_trace = current_c.trace()
        for i, A_inv in enumerate(pubkey_inv):
            if i in used:
                continue
            next_c = A_inv * current_c  # computed once, used for test and recursion
            if next_c.trace() <= current_trace:
                yield from dfs_search(next_c, current_path + [i], used | {i})

    yield from dfs_search(C, [], frozenset())
def break_direct_cryptosystem(paras, pubkey, C):
    """Break the direct cryptosystem, decrypting the ciphertext C.

    Args:
        paras (tuple): the parameters of the cryptosystem i.e. (n, k, a, p)
        pubkey (list): the public matrices.
        C (matrix): the ciphertext, i.e. the product of ``pubkey`` permuted by sigma.

    Returns:
        The decoded message, i.e. the rank of the first non-empty permutation
        yielded by ``dfs_search_message``. Returns None if the search yields none.
    """
    n, k, a, p = paras
    perm = next((candidate for candidate in dfs_search_message(C, pubkey) if candidate), None)
    if perm is None:
        return None
    print(f"Found perm {perm}")
    return DirectMatrixProductCryptosystem(n, k, a, p).decode(perm)
def test_break_direct_cryptosystem():
    """End-to-end check of the direct-cryptosystem attack on a freshly generated instance."""
    direct_paras = (10, 35, 2, 2**302 + 307)
    (pubkey, C), msg = gen_direct_challenge_local(direct_paras)
    sol = break_direct_cryptosystem(direct_paras, pubkey, C)
    assert sol == msg, f"Failed: {sol} != {msg}"
    print("Passed test_break_direct_cryptosystem")


if __name__ == "__main__":
    test_break_direct_cryptosystem()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from sage.all import * | |
from random import Random, SystemRandom | |
from hashlib import sha256 | |
from Crypto.Cipher import AES | |
# Matrix-product cryptosystems from https://eprint.iacr.org/2023/1745.pdf
class BaseMatrixProductCryptosystem:
    """Parameters and random-matrix samplers shared by the matrix-product cryptosystems.

    Attributes:
        n: matrix dimension.
        k: number of factors in a ciphertext (message positions).
        a: entries of "drawf" matrices are drawn from ``randint(0, a)``.
        p: modulus. ``F = GF(p)`` is the default ring for sampled matrices.
    """

    def __init__(self, n: int, k: int, a: int, p: int):
        self.n, self.k, self.a, self.p = n, k, a, p
        self.F = GF(p)

    def _sample_matrix(self, draw_entry, check, ring):
        """Rejection-sample an n x n integer matrix and convert it to ``ring``.

        Entries come from ``draw_entry()`` in row-major order. When ``check`` is
        true, the matrix is redrawn until its determinant is nonzero mod p.
        ``ring=None`` means ``self.F``.
        """
        target_ring = self.F if ring is None else ring
        entry_count = self.n * self.n
        while True:
            candidate = matrix(ZZ, self.n, self.n, [draw_entry() for _ in range(entry_count)])
            determinant = candidate.det()
            if check and determinant % self.p == 0:
                continue
            return candidate.change_ring(target_ring)

    def rand_drawf(self, rand: Random, check=True, ring=None):
        """Random "small" matrix with entries ``rand.randint(0, a)``."""
        return self._sample_matrix(lambda: rand.randint(0, self.a), check, ring)

    def rand_elf(self, rand: Random, check=True, ring=None):
        """Random matrix with entries drawn by ``rand.randrange(0, p)``."""
        return self._sample_matrix(lambda: rand.randrange(0, self.p), check, ring)
class DirectMatrixProductCryptosystem(BaseMatrixProductCryptosystem):
    """Direct matrix-product cryptosystem: a message is a permutation of k factors.

    Keys:

    - Private: small matrices A_0..A_{k-1} and D, plus a uniform matrix E.
    - Public: ``pub[i] = E * A_i * D * E^{-1}``.

    Encryption maps a message m in ``[0, k!)`` to the permutation sigma of rank
    m. The ciphertext is ``pub[sigma[0]] * ... * pub[sigma[k-1]]``.
    """

    def keygen(self, rand: Random):
        """Return ``(priv, pub)`` with ``priv = (As, D, E, E^{-1})``."""
        As = [self.rand_drawf(rand) for _ in range(self.k)]
        D = self.rand_drawf(rand)
        E = self.rand_elf(rand)
        Ei = E.inverse()
        priv = (As, D, E, Ei)
        pub = [E * A * D * Ei for A in As]
        return priv, pub

    def encrypt_perm(self, pub, perm):
        """Multiply the public matrices in the order given by ``perm`` (0-based indices)."""
        M = pub[perm[0]]
        for i in range(1, self.k):
            M = M * pub[perm[i]]
        return M

    def decompose(self, M, As, D, L):
        """Recover the permutation from ``R = A_{s0} D A_{s1} D ... D A_{s(k-1)}``, depth first.

        ``L`` holds the indices peeled off so far. Candidate A_i is accepted when
        ``D^{-1} A_i^{-1} M`` has at least n(n-1) entries whose integer
        representatives do not exceed those of M.

        Returns:
            list | None: the full index list, or None if no branch ends on a
            factor equal to the remaining matrix.
        """
        if len(L) == len(As):
            return
        if len(As) - 1 == len(L):
            # exactly one index is left: it must equal what remains
            idx = next(iter(set(range(len(As))) - set(L)))
            if As[idx] == M:
                return L + [idx]
            return
        threshold = self.n * (self.n - 1)
        for i, A in enumerate(As):
            if i in L:
                continue
            try:
                Mp = D.solve_right(A.solve_right(M))
            except ValueError:
                continue
            # The paper requires every entry to shrink. The original author
            # found that test often fails before the end, so here only n(n-1)
            # of the n^2 entries must be no larger.
            smaller_cnt = sum(1 for x, y in zip(Mp.list(), M.list()) if int(x) <= int(y))
            if smaller_cnt >= threshold:
                ret = self.decompose(Mp, As, D, L + [i])
                if ret is not None:
                    return ret

    def decrypt_perm(self, priv, M):
        """Undo the E/D conjugation and decompose. Returns the permutation or None."""
        As, D, E, Ei = priv
        R = Ei * M * E * ~D
        return self.decompose(R, As, D, [])

    def encode(self, m):
        """Map ``m`` in ``[0, k!)`` to the 0-based permutation of rank ``m``.

        Raises:
            ValueError: if ``m`` is outside ``[0, k!)``.
        """
        P = Permutations(self.k)
        # Valid ranks are 0 .. k! - 1. The old `m > cardinality` check let m == k! through.
        if m < 0 or m >= P.cardinality():
            raise ValueError("Invalid message")
        return [x - 1 for x in P.unrank(m)]

    def decode(self, p):
        """Inverse of ``encode``: the rank of the 0-based permutation ``p``."""
        P = Permutations(self.k)
        return P.rank([x + 1 for x in p])

    def encrypt(self, pub, m):
        return self.encrypt_perm(pub, self.encode(m))

    def decrypt(self, priv, M):
        """Return the decrypted message, or None if decomposition fails."""
        ret = self.decrypt_perm(priv, M)
        if ret is not None:
            return self.decode(ret)

    def randmsg(self, rand: Random):
        """Uniform message in ``[0, k!)``."""
        return rand.randrange(0, factorial(self.k))
class AlternatingMatrixProductCryptosystem(BaseMatrixProductCryptosystem):
    """Alternating matrix-product cryptosystem: one public matrix pair per message bit.

    Keys:

    - Private: small matrix pairs ``As[i] = (A_i^0, A_i^1)`` and uniform
      matrices E_0..E_k.
    - Public: ``pub[i][b] = E_i * A_i^b * E_{i+1}^{-1}``.

    A k-bit message m has bit i equal to ``(m >> i) & 1``. It encrypts to
    ``pub[0][m_0] * ... * pub[k-1][m_{k-1}]``, in which the inner E factors
    cancel.
    """

    def rand_pair_drawf(self, rand: Random, lookup: dict):
        """Draw two distinct small matrices that share the same nonzero determinant.

        ``lookup`` maps determinant -> an earlier unmatched ZZ matrix. It is shared
        across calls (see ``keygen``), so a new draw can pair with any earlier one.
        Matched entries are removed. Equal determinants mean
        ``det(pub[i][0]) == det(pub[i][1])``.

        Returns:
            tuple: ``(new_matrix, earlier_matrix)``, both over GF(p).
        """
        while True:
            A = self.rand_drawf(
                rand, check=False, ring=ZZ
            )  # computing determinant in ZZ is so much faster than in GF(p) ...
            d = A.det()
            if d == 0:
                continue
            if d in lookup and lookup[d] != A:
                AA = lookup.pop(d)
                return A.change_ring(self.F), AA.change_ring(self.F)
            lookup[d] = A

    def keygen(self, rand: Random):
        """Return ``(priv, pub)`` with ``priv = (E_0, E_k, As)`` and ``pub[i][b]`` as above."""
        lookup = {}
        As = [self.rand_pair_drawf(rand, lookup) for _ in range(self.k)]
        Es = [self.rand_elf(rand) for _ in range(self.k + 1)]
        ABars = []
        for i in range(self.k):
            cur = []
            for b in (0, 1):
                cur.append(Es[i] * As[i][b] * ~Es[i + 1])
            ABars.append(cur)
        priv = (Es[0], Es[self.k], As)
        pub = ABars
        return priv, pub

    def encrypt_bits(self, pub, bits):
        """Multiply ``pub[i][bits[i]]`` for i = 0 .. k-1, left to right."""
        M = pub[0][bits[0]]
        for i in range(1, self.k):
            M = M * pub[i][bits[i]]
        return M

    def decompose(self, M, As):
        """Greedily peel ``A_i^b`` off the left of M for i = 0 .. k-1.

        At each position, b = 0 is tried first and then b = 1. A choice is
        accepted when ``A^{-1} M`` has at least n(n-1) entries whose integer
        representatives do not exceed those of M.

        Returns:
            list | None: the bits, or None if some position accepts neither bit.
        """
        threshold = self.n * (self.n - 1)
        bits = []
        for i in range(self.k):
            for b in (0, 1):
                A = As[i][b]
                try:
                    Mp = A.solve_right(M)
                except ValueError:
                    continue
                # The paper requires every entry to shrink. The original author
                # found that test often fails before the end, so here only n(n-1)
                # of the n^2 entries must be no larger.
                smaller_cnt = sum(1 for x, y in zip(Mp.list(), M.list()) if int(x) <= int(y))
                if smaller_cnt >= threshold:
                    bits.append(b)
                    M = Mp
                    break
            else:
                return
        return bits

    def decrypt_bits(self, priv, M):
        """Strip E_0 and E_k, then decompose. Returns the bits or None."""
        E0, Ek, As = priv
        R = ~E0 * M * Ek
        return self.decompose(R, As)

    def encode(self, m):
        """Little-endian k-bit expansion of ``m`` in ``[0, 2**k)``.

        Raises:
            ValueError: if ``m`` is outside ``[0, 2**k)``.
        """
        # The old `m > 2**k` check accepted m == 2**k and silently encoded it as 0.
        if m < 0 or m >= 2**self.k:
            raise ValueError("Invalid message")
        return [(m >> i) & 1 for i in range(self.k)]

    def decode(self, p):
        """Inverse of ``encode``: ``sum(bit_i << i)``."""
        return sum(x << i for i, x in enumerate(p))

    def encrypt(self, pub, m):
        return self.encrypt_bits(pub, self.encode(m))

    def decrypt(self, priv, M):
        """Return the decrypted message, or None if decomposition fails."""
        ret = self.decrypt_bits(priv, M)
        if ret is not None:
            return self.decode(ret)

    def randmsg(self, rand: Random):
        """Uniform k-bit message."""
        return rand.getrandbits(self.k)
# Recommended size, 128-bit security
direct = DirectMatrixProductCryptosystem(10, 35, 2, 2**302 + 307)
# Recommended size, 128-bit security
alternating = AlternatingMatrixProductCryptosystem(10, 128, 2, 2**553 + 549)

if __name__ == "__main__":
    rng = SystemRandom()
    key_hash = sha256()
    challenges = []
    for system in (direct, alternating):
        private_key, public_key = system.keygen(rng)
        message = system.randmsg(rng)
        ciphertext = system.encrypt(public_key, message)
        # never publish an instance the key owner cannot decrypt
        if system.decrypt(private_key, ciphertext) != message:
            raise ValueError("Decryption failed")
        # the AES key commits to both messages, direct first, then alternating
        key_hash.update(str(message).encode())
        challenges.append((public_key, ciphertext))

    with open("flag.txt", "rb") as flag_file:
        flag = flag_file.read().strip()
    cipher = AES.new(key_hash.digest(), AES.MODE_CTR)
    save(
        {"challenges": challenges, "enc_flag": cipher.encrypt(flag), "nonce": cipher.nonce},
        "output.sobj",
    )
# Note from the original author: the challenge output was produced by running
# this script with SageMath 10.3.
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.