Skip to content

Instantly share code, notes, and snippets.

@defund
Last active April 5, 2022 22:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save defund/02d24a4b8a75849e2e9456d7c398cc0d to your computer and use it in GitHub Desktop.
Save defund/02d24a4b8a75849e2e9456d7c398cc0d to your computer and use it in GitHub Desktop.
hxp CTF 2021
import collections
import itertools
import json
import os
from multiprocessing import Pool

from Crypto.Hash import SHA512
from tqdm import tqdm
# Disable Sage's default proof (certification) mode in every subsystem so
# computations may use faster, non-certified algorithms.
proof.all(False)

# CSIDH-style parameters: the odd primes 3, 5, ..., 113 (prime_range stops
# before 117) and p = 4 * prod(ls) - 1, so that p + 1 = 4 * prod(ls).
ls = list(prime_range(3, 117))
p = 4 * prod(ls) - 1

# All-zero byte string whose length is the byte length of p.
# NOTE(review): not referenced anywhere else in this script — likely a leftover.
base = b'\x00' * ((int(p).bit_length() + 7) // 8)

# Univariate polynomial ring over GF(p) with generator t (same ring and
# generator as the preparser form `R.<t> = GF(p)[]`); t is used by
# montgomery_coefficient to build the Weierstrass cubic.
R = PolynomialRing(GF(p), 't')
t = R.gen()
def montgomery_coefficient(E):
    """Return the Montgomery coefficient A of the curve E.

    E is converted to short Weierstrass form y^2 = x^3 + a*x + b. With r the
    root of that cubic in GF(p) and s = sqrt(3*r^2 + a), the result is
    A = -3 * r / s, with the sign flipped when s is itself a square
    (presumably to pick one canonical representative — TODO confirm).

    Raises:
        ValueError: from the tuple unpacking, unless the cubic has exactly
            one root in GF(p).
    """
    _, _, _, a, b = E.short_weierstrass_model().a_invariants()
    cubic = t**3 + a*t + b
    (root,) = cubic.roots(multiplicities=False)
    scale = sqrt(3 * root**2 + a)
    sign = -1 if is_square(scale) else 1
    return -3 * sign * root / scale
def csidh(pub, priv):
    """Apply the CSIDH-style isogeny walk given by ``priv`` to the curve ``pub``.

    Args:
        pub: Montgomery coefficient A of the starting curve
            y^2 = x^3 + A*x^2 + x over GF(p).
        priv: one integer exponent per prime in ``ls``. Positive entries are
            applied as degree-l isogenies on E; the magnitudes of negative
            entries are applied on E's quadratic twist.

    Returns:
        The Montgomery coefficient of the final curve, computed by
        ``montgomery_coefficient``.
    """
    E = EllipticCurve(GF(p), [0, pub, 0, 1, 0])
    # First pass: positive exponents; second pass: magnitudes of negative ones.
    for es in ([max(0,+e) for e in priv], [max(0,-e) for e in priv]):
        while any(es):
            P = E.random_point()
            # Product of the primes that still need at least one isogeny step.
            k = prod(l for l,e in zip(ls,es) if e)
            # p + 1 = 4 * prod(ls), so this strips the factor 4 and every prime
            # not in k (assumes #E(GF(p)) = p + 1 — TODO confirm), leaving P of
            # order dividing k.
            P *= (p+1) // k
            for i,(l,e) in enumerate(zip(ls,es)):
                if not e: continue
                k //= l
                # Q = (k/l) * P, so its order is 1 or l.
                Q = k*P
                # No l-torsion component in this point: es[i] is not decremented,
                # so the while loop retries this prime with a fresh random point.
                if Q == E(0): continue
                phi = E.isogeny(Q)
                E,P = phi.codomain(), phi(P)
                es[i] -= 1
        # Move to the twist; after both passes E has been twisted twice.
        E = E.quadratic_twist()
    return montgomery_coefficient(E)
def get_cluster(coeff):
    """Map each curve one or two isogeny steps from ``coeff`` to its step vector.

    Enumerates every exponent vector with exactly one or two nonzero entries,
    each -1 or +1 (single-entry vectors first, positions in lexicographic
    order, signs in (-1, +1) order). It applies ``csidh`` from ``coeff`` to
    each one and records ``int(resulting coefficient) -> step vector as a list
    of ints``. If two vectors reach the same coefficient, the later one wins.

    Returns:
        ``(coeff, cluster)``; echoing ``coeff`` lets unordered parallel
        results be matched back to their input.
    """
    width = len(ls)
    neighbours = {}
    for weight in (1, 2):
        for positions in itertools.combinations(range(width), weight):
            for signs in itertools.product((-1, 1), repeat=weight):
                step = [0] * width
                for pos, sign in zip(positions, signs):
                    step[pos] = sign
                neighbours[int(csidh(coeff, step))] = [int(x) for x in step]
    return coeff, neighbours
def derive_graph(clusters):
    """Link samples whose neighbourhood clusters overlap and record their offsets.

    Args:
        clusters: mapping ``sample coefficient -> cluster``. Each cluster maps
            a neighbouring coefficient to the exponent step vector that
            reaches it from the sample, as built by ``get_cluster``. Neighbour
            keys only need to compare equal across clusters (after a JSON
            round-trip they are strings).

    Returns:
        ``(graph, shifts)``:
        - ``graph`` is a Sage ``Graph`` on the sample coefficients, with an
          edge between every two samples that share a neighbour.
        - ``shifts`` maps each ordered edge ``(u, v)`` to the vector
          ``cluster_u[c] - cluster_v[c]`` for one shared neighbour ``c``. The
          reverse edge stores the negation.

    Prints a warning if the resulting graph is not connected.
    """
    graph = Graph()
    graph.add_vertices(clusters)
    shifts = dict()

    # Inverted index: neighbouring coefficient -> samples whose cluster contains
    # it, listed in ``clusters`` iteration order. Only samples sharing a
    # neighbour can be linked. This replaces intersecting the key sets of every
    # pair of clusters, which is quadratic in the number of samples.
    owners = collections.defaultdict(list)
    for coeff, cluster in clusters.items():
        for neighbour in cluster:
            owners[neighbour].append(coeff)

    for neighbour, holders in owners.items():
        # ``holders`` follows ``clusters`` order, so coeff1 precedes coeff2 there,
        # matching the orientation a pairwise scan over ``clusters`` would use.
        for coeff1, coeff2 in itertools.combinations(holders, 2):
            if (coeff1, coeff2) in shifts:
                continue  # already linked through another shared neighbour
            graph.add_edge(coeff1, coeff2)
            shift = vector(clusters[coeff1][neighbour]) - vector(clusters[coeff2][neighbour])
            shifts[(coeff1, coeff2)] = shift
            shifts[(coeff2, coeff1)] = -shift

    if not graph.is_connected():
        print('warning: graph is not connected')
    return graph, shifts
def validate(center, graph, shifts):
    """Tally each sample's per-prime offset from ``center`` and test consistency.

    For every vertex reachable from ``center``, this sums the ``shifts`` of
    the edges along a shortest path to get that vertex's exponent offset
    relative to ``center``. Each coordinate of the offset is counted, weighted
    by how many times the vertex occurs in the global ``samples`` list.

    Args:
        center: vertex (sample coefficient) used as the reference point.
        graph: Sage graph produced by ``derive_graph``.
        shifts: mapping from ordered edge to offset vector, from ``derive_graph``.

    Returns:
        A list with one ``defaultdict(int)`` per prime in ``ls``, mapping an
        observed offset to its weighted count. It is returned only if, in
        every coordinate, the observed offsets are all >= 0 or all <= 0.
        Otherwise ``None`` is returned, including when no other vertex is
        reachable from ``center`` (there is no evidence to check).
    """
    # One pass over the samples instead of a list.count() per vertex.
    sample_counts = collections.Counter(samples)
    errors = [collections.defaultdict(int) for _ in range(len(ls))]
    for vertex in graph.get_vertices():
        path = graph.shortest_path(center, vertex)
        # [center] for the center itself, [] for unreachable vertices: nothing to tally.
        if len(path) <= 1:
            continue
        shift = sum(shifts[edge] for edge in zip(path, path[1:]))
        for e, s in zip(errors, shift):
            e[s] += sample_counts[vertex]
    # An isolated center leaves the tallies empty; min()/max() would raise
    # ValueError and abort the caller's search over centers.
    if not all(errors):
        return None
    if all(min(e) == 0 or max(e) == 0 for e in errors):
        return errors
    return None
def get_search_space(error):
    """Turn the per-prime offset tallies into an iterator of candidate keys.

    For each prime l in ``ls`` paired with its tally (offset -> weighted count):
      * an observed offset of +2 pins the exponent to -2; otherwise an offset
        of -2 pins it to +2;
      * otherwise an offset of +1 (checked before -1) leaves the exponents
        -1/-2 (respectively +1/+2). The magnitude 2 is tried first when that
        offset's count exceeds a threshold;
      * otherwise the exponent is 0.
    The threshold is ``len(samples)`` times the midpoint of 1/l and
    1 - (1 - 1/l)**2. This is presumably the expected hit rate for exponent
    magnitude 1 vs 2 if each isogeny step is skipped with probability 1/l —
    TODO confirm against the challenge.

    Prints the number of candidates and returns an ``itertools.product`` over
    the per-prime candidate tuples.
    """
    choices = []
    for tally, l in zip(error, ls):
        threshold = len(samples) * (1/l + 1 - (1 - 1/l)**2) / 2
        if 2 in tally or -2 in tally:
            choices.append((-2,) if 2 in tally else (2,))
        elif 1 in tally or -1 in tally:
            # Exponent sign is opposite to the observed offset's sign.
            sign = -1 if 1 in tally else 1
            likely_double = tally[-sign] > threshold
            choices.append((2*sign, sign) if likely_double else (sign, 2*sign))
        else:
            choices.append((0,))
    print(f'Searching {prod(map(len, choices))} possibilities...')
    return itertools.product(*choices)
def decrypt(priv):
    """Try ``priv`` as the private key; return the plaintext if it looks like the flag.

    The key is serialised as comma-separated explicitly signed integers
    (e.g. ``+1,-2,+0``) and hashed with SHA-512. The digest is XORed byte by
    byte with the global ciphertext ``enc``, truncated to the shorter of the two.

    Returns:
        The plaintext bytes if they start with ``b'hxp{'``, otherwise ``None``.
    """
    keystream = SHA512.new(','.join(f'{e:+}' for e in priv).encode()).digest()
    plaintext = bytes(map(int.__xor__, enc, keystream))
    return plaintext if plaintext.startswith(b'hxp{') else None
# Solve script: every file lives under a directory named after the ciphertext hex.
enc_hex = 'b7c4aa9ba15650fcb2e90c26b25b0d67e3bb175f5491bf056295f8d1f55ee398edfde6be3b3c'
enc = bytes.fromhex(enc_hex)

# Observed public coefficients (with repeats). Kept as module globals because
# validate/get_search_space read `samples` and the decrypt workers read `enc`.
with open(f'{enc_hex}/samples.json') as f:
    samples = json.load(f)
coeffs = set(samples)
# Multiplicity of each coefficient; avoids an O(len(samples)) list.count per vertex.
sample_counts = collections.Counter(samples)

# Phase 1: compute each distinct sample's neighbourhood cluster in parallel
# and cache it as JSON (the output directory is created if missing).
os.makedirs(f'{enc_hex}/clusters', exist_ok=True)
with Pool(processes=4) as pool:
    for coeff, cluster in tqdm(pool.imap(get_cluster, coeffs), total=len(coeffs)):
        with open(f'{enc_hex}/clusters/{coeff}.json', 'w') as f:
            json.dump(cluster, f)

# Phase 2: reload the clusters (JSON turns their keys into strings, consistently
# for all of them), link overlapping samples, and checkpoint graph and shifts.
clusters = {}
for coeff in tqdm(coeffs):
    with open(f'{enc_hex}/clusters/{coeff}.json') as f:
        clusters[coeff] = json.load(f)
graph, shifts = derive_graph(clusters)
save(graph, f'{enc_hex}/graph')
save(shifts, f'{enc_hex}/shifts')
graph = load(f'{enc_hex}/graph')
shifts = load(f'{enc_hex}/shifts')

# Phase 3: try the most frequent samples first as the reference point. For each
# one that validates, brute-force the candidate keys and print every flag found.
# NOTE: the search continues over the remaining centers after a flag is printed.
for vertex in sorted(graph.get_vertices(), key=lambda u: -sample_counts[u]):
    if errors := validate(vertex, graph, shifts):
        with Pool(processes=4) as pool:
            for flag in filter(None, tqdm(pool.imap(decrypt, get_search_space(errors), 2**20))):
                print(flag)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment