# Chrstm/acyclic_group.py

Last active Dec 15, 2021
ByteCTF 2021 Final Crypto - Acyclic Group
 #!/usr/bin/env python3
"""Challenge server for "Acyclic Group" (ByteCTF 2021 Final, crypto).

After a SHA-256 proof of work the client plays 256 rounds.  In each round the
server draws a hidden modulus ``n`` and exponent ``e``, answers queries with
``pow(x, e, n)`` and accepts the round as soon as the client sends ``n``
itself.  The flag is printed when more than 80% of the rounds were won.
"""
from hashlib import sha256
from random import choices, getrandbits, randint
import signal
import string

from flag import FLAG

# Prime factors every modulus is built from.
PRIMES = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
    59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131,
]
# Each prime appears with an exponent drawn uniformly from 1..MAX_EXPONENT.
MAX_EXPONENT = 16
# Number of inputs (oracle queries and/or the final guess) read per round.
INPUTS_PER_ROUND = 3


def proof_of_work() -> bool:
    """Ask the client for a suffix whose SHA-256 with a random nonce starts with "000000".

    Returns:
        True when the hex digest of ``nonce + suffix`` starts with six zero
        hex digits; False otherwise or when the client closes the stream.
    """
    alphabet = string.ascii_letters + string.digits
    nonce = "".join(choices(alphabet, k=8))
    print(f'SHA256("{nonce}" + ?) starts with "000000"')
    try:
        suffix = input().strip()
    except EOFError:
        # Same policy as in test(): a vanished client simply fails.
        return False
    message = (nonce + suffix).encode()
    return sha256(message).hexdigest().startswith("000000")


def gen_modulus() -> int:
    """Return the product of every prime in PRIMES raised to a random power in 1..16."""
    n = 1
    for p in PRIMES:
        n *= p ** randint(1, MAX_EXPONENT)
    return n


def test() -> bool:
    """Play one round against the client.

    A fresh modulus ``n`` and exponent ``e`` are drawn.  Up to three decimal
    integers are read from stdin: an input equal to ``n`` wins the round
    immediately, any other integer ``x`` is answered with ``pow(x, e, n)``.

    Returns:
        True if the client sent ``n``; False on EOF, on malformed input, or
        once all three inputs are used up.
    """
    n = gen_modulus()
    e = randint(1, n)
    for _ in range(INPUTS_PER_ROUND):
        try:
            num = input().strip()
        except EOFError:
            return False
        # isdecimal() rather than isnumeric(): isnumeric() also accepts
        # characters such as "²" or "½" that int() cannot parse.
        if not num.isdecimal():
            return False
        try:
            value = int(num)
        except ValueError:
            # e.g. the interpreter's limit on int/str conversion length.
            return False
        if value == n:
            return True
        print(pow(value, e, n))
    return False


def main():
    """Run the proof of work, then 256 rounds, and reveal the flag on success."""
    signal.alarm(60)  # time limit for the proof of work
    if not proof_of_work():
        return
    print("Listen...I have some acyclic groups for you..."
          "No noise this time...God bless you get them...")
    passed = 0
    T = 256
    signal.alarm(T)  # replaces the previous alarm: T seconds for all T rounds
    for i in range(T):
        print("Round", i)
        if test():
            passed += 1
            print("GOOD SHOT, MY FRIEND")
        else:
            print("CALM DOWN, MY FRIEND")
    if passed > T * 0.8:
        print("CONGRATULATIONS", FLAG)


if __name__ == "__main__":
    main()
 # Solver for the "Acyclic Group" challenge (ByteCTF 2021 Final, crypto).
"""Recover a hidden smooth modulus ``n`` from a ``pow(x, e, n)`` oracle.

``n`` is a product of the 32 primes in ``primes``, each raised to a power in
1..MAX_EXPONENT.  Run as a script to benchmark the solver against a local
mock oracle; call ``solve_real_challenge()`` to attack a live service.
"""
import re
import socket
import time
from hashlib import sha256
from random import randint
from typing import Callable, Union

try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, *args, **kwargs):
        """Stand-in used when tqdm is missing: iterate without a progress bar."""
        return iterable


def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b`` (Euclid's algorithm)."""
    while b:
        a, b = b, a % b
    return a


class Socket:
    """Small convenience wrapper around a connected TCP client socket."""

    count: int = 0  # not referenced anywhere in this script

    def __init__(self, host, port):
        """Open a TCP connection to ``(host, port)``."""
        self.s = socket.socket()
        self.s.connect((host, port))

    def send(self, m, output=False, end=b'\n'):
        """Send ``m`` followed by ``end``.

        Args:
            m: payload; ``str`` values are encoded as Latin-1 first.
            output: when true, echo the bytes being sent to stdout.
            end: terminator appended to the payload.
        """
        if isinstance(m, str):
            m = m.encode('Latin1')
        payload = m + end
        if output:
            print("send:", payload)
        # sendall(): plain send() may transmit only part of the buffer.
        self.s.sendall(payload)

    def recv(self, rec_bytes=8192, output=False):
        """Read one chunk of at most ``rec_bytes`` bytes, decoded as Latin-1."""
        m = self.s.recv(rec_bytes).decode('Latin1')
        if output:
            print("recv:\n{}\n".format(m))
        return m

    def close(self):
        """Close the underlying socket."""
        self.s.close()

    def pow(self):
        """Solve the server's SHA-256 proof of work and send the answer.

        Reads the banner, extracts the 8-character nonce and brute-forces a
        decimal suffix ``i`` such that ``SHA256(nonce + str(i))`` starts with
        "000000" in hex.  Returns None in every case, including when no
        suffix below 10**9 works.
        """
        print("solving POW ...")
        start_time = time.time()
        resp = self.recv()
        found = re.search(r'SHA256\("(\w{8})"', resp)
        if found is None:
            raise RuntimeError(f"unexpected proof-of-work banner: {resp!r}")
        prefix = found.group(1)
        for i in range(1000000000):
            message = (prefix + str(i)).encode("Latin-1")
            if sha256(message).hexdigest().startswith("000000"):
                self.send(str(i))
                print(f"POW cost {time.time() - start_time:.2f} seconds")
                time.sleep(1)
                return


primes = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
    59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131,
]
N = len(primes)
MAX_EXPONENT = 16

# primes_product_max_exp = prod(primes) ** MAX_EXPONENT; every possible
# modulus divides it.
primes_product_max_exp = 1
# primes_exp[i][j] = primes[i] ** j for j in 0..MAX_EXPONENT
primes_exp = []
for p in primes:
    x = [1] * (MAX_EXPONENT + 1)
    for j in range(1, MAX_EXPONENT + 1):
        x[j] = x[j - 1] * p
    primes_exp.append(x)
    primes_product_max_exp *= x[MAX_EXPONENT]


def solve_round(oracle):
    """Guess the hidden modulus ``n`` behind ``oracle(x) = pow(x, e, n)``.

    Two probes are sent, one for each of the last two primes ``p_i``:
    ``m = primes_product_max_exp // p_i**MAX_EXPONENT - 1``.  For every other
    prime ``p_j`` the true power ``p_j**a_j`` divides ``m + 1``, so the answer
    ``c`` is congruent to ``(-1)**e`` modulo ``p_j**a_j``.  The largest ``k``
    with ``c % p_j**k`` equal to 1 or -1 is therefore at least ``a_j`` and is
    used as the guess for it (minimum over both probes).  Whether the
    residues were mostly +1 or -1 tells the parity of ``e``.

    * ``e`` odd: one more query with ``primes_product_max_exp - 1`` (which is
      -1 modulo ``n``) returns ``n - 1``.
    * ``e`` even: both answers are 1 modulo ``p_j**a_j`` for the first
      ``N - 2`` primes, so their difference is a multiple of those powers; a
      gcd with it trims over-estimated exponents.  The last two primes are
      each seen by a single probe only and keep their raw guess.

    The result is a guess and may be wrong when an exponent is over-estimated.

    NOTE(review): the odd branch performs a third oracle query before the
    caller submits the modulus, i.e. four inputs in the round, while the
    challenge server shown earlier in this file reads at most three inputs
    per round - confirm this branch against the live service.

    Args:
        oracle: callable mapping an int ``x`` to ``pow(x, e, n)``.

    Returns:
        The guessed modulus.
    """
    guess_exp = [MAX_EXPONENT] * N
    results = []
    odd, even = 0, 0
    for i in range(N - 2, N):
        m = primes_product_max_exp // primes_exp[i][MAX_EXPONENT] - 1
        c = oracle(m)
        results.append(c)
        for j in range(N):
            if j == i:
                continue
            exp = 0
            for k in range(1, MAX_EXPONENT + 1):
                # Keep going while c % (primes[j] ** k) is still -1 or 1.
                mod = primes_exp[j][k]
                rem = c % mod
                if rem != 1 and rem != mod - 1:
                    break
                exp = k
                if mod == 2:
                    # 1 and -1 coincide modulo 2: no information on parity.
                    continue
                if rem == 1:
                    even += 1
                else:
                    odd += 1
            guess_exp[j] = min(guess_exp[j], exp)

    if odd > even:
        # `e` is odd. Send `-1` so that n = response + 1.
        return oracle(primes_product_max_exp - 1) + 1

    # `e` is even.
    g = abs(results[0] - results[1])
    res = 1
    for i in range(N - 2):
        res *= primes_exp[i][guess_exp[i]]
    res = gcd(res, g)
    for i in range(N - 2, N):
        res *= primes_exp[i][guess_exp[i]]
    return res


def solve_real_challenge():
    """Attack the service on localhost:23334 until it reveals the flag.

    Each attempt connects, solves the proof of work and plays 256 rounds with
    ``solve_round``; the attempt is repeated from scratch until the last
    response contains "CONGRATULATIONS".

    NOTE(review): the oracle parses each ``recv()`` chunk as a single integer,
    which presumably relies on the server's messages arriving one per read -
    confirm against the live service.
    """
    T = 256
    while True:
        s: Socket = Socket("localhost", 23334)
        try:
            s.pow()
            s.recv()

            def oracle(x: int):
                s.send(str(x))
                return int(s.recv().strip())

            passed = 0
            resp = ""
            start_time = time.time()
            for _ in range(T):
                modulus = solve_round(oracle)
                s.send(str(modulus))
                resp = s.recv()
                if "GOOD SHOT" in resp:
                    passed += 1
        finally:
            # Release the connection even when a round raises.
            s.close()
        print(f"passed {passed}/{T}, cost {time.time() - start_time:.2f} seconds")
        if "CONGRATULATIONS" in resp:
            print(resp)
            return
        print("try again\n")


def benchmark_mock_test(rounds: int = 10000):
    """Measure the success rate of ``solve_round`` against a local oracle.

    Args:
        rounds: number of random (modulus, exponent) pairs to try.

    Prints the pass count, the pass ratio and the elapsed time.
    """
    def gen_modulus() -> int:
        n = 1
        for i in range(N):
            n *= primes_exp[i][randint(1, MAX_EXPONENT)]
        return n

    def gen_oracle(n: int) -> Callable[[int], Union[int, str]]:
        e = randint(1, n)

        def oracle(m: int):
            return pow(m, e, n)

        return oracle

    t = time.time()
    passed = 0
    for _ in tqdm(range(rounds)):
        n = gen_modulus()
        o = gen_oracle(n)
        res = solve_round(o)
        if res == n:
            passed += 1
    ratio = passed / rounds if rounds else 0.0
    print(f"{passed}/{rounds}: {ratio:.3f}")
    print(f"cost {time.time() - t:.1f} seconds")


if __name__ == "__main__":
    # solve_real_challenge()
    benchmark_mock_test()