{{ message }}

Instantly share code, notes, and snippets.

# Chrstm/acyclic_group.py

Last active Dec 15, 2021
ByteCTF 2021 Final Crypto - Acyclic Group
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 #!/usr/bin/env python3
"""ByteCTF 2021 Final, Crypto "Acyclic Group" -- challenge server.

Every round picks a secret modulus ``n`` (the product of the primes 2..131,
each raised to a random exponent in [1, 16]) and a secret exponent ``e`` in
[1, n].  The player may send three inputs per round: sending ``n`` wins the
round, any other decimal integer ``x`` is answered with ``pow(x, e, n)``, and
anything else ends the round as a loss.  More than 80% of 256 rounds must be
won to get the flag.
"""
from hashlib import sha256
from random import choices, getrandbits, randint
import signal
import string
from flag import FLAG


def proof_of_work() -> bool:
    """Return True iff the client's suffix makes SHA256(nonce + suffix) start with "000000".

    The nonce is 8 characters drawn from ASCII letters and digits.  A client
    that closes stdin before answering fails the check instead of crashing the
    server with an uncaught EOFError.
    """
    alphabet = string.ascii_letters + string.digits
    nonce = "".join(choices(alphabet, k=8))
    print(f'SHA256("{nonce}" + ?) starts with "000000"')
    try:
        answer = input().strip()
    except EOFError:
        return False
    message = (nonce + answer).encode()
    return sha256(message).digest().hex().startswith("000000")


def gen_modulus() -> int:
    """Return the product of each prime in 2..131 raised to a random power in [1, 16]."""
    primes = [
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
        59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131,
    ]
    n = 1
    for p in primes:
        n *= p ** randint(1, 16)
    return n


def test() -> bool:
    """Play one round; return True iff the client sends ``n`` within three inputs."""
    n = gen_modulus()
    e = randint(1, n)
    for _ in range(3):
        try:
            num = input().strip()
        except EOFError:
            return False
        # Accept only plain decimal digit strings.  ``str.isnumeric`` also
        # accepts characters such as "½" or "²" that ``int`` rejects, and
        # ``int`` refuses over-long digit strings on Python versions with the
        # integer string-conversion limit; both used to raise an uncaught
        # ValueError and kill the session.
        if not num.isdecimal():
            return False
        try:
            num = int(num)
        except ValueError:
            return False
        if num == n:
            return True
        print(pow(num, e, n))
    return False


def main():
    """Run the proof of work, then T rounds; print the flag if more than 80% are won."""
    signal.alarm(60)  # time limit for the proof of work
    if not proof_of_work():
        return
    print("Listen...I have some acyclic groups for you..."
          "No noise this time...God bless you get them...")
    passed = 0
    T = 256
    signal.alarm(T)  # re-arm: T seconds for all rounds together
    for i in range(T):
        print("Round", i)
        if test():
            passed += 1
            print("GOOD SHOT, MY FRIEND")
        else:
            print("CALM DOWN, MY FRIEND")
    if passed > T * 0.8:
        print("CONGRATULATIONS", FLAG)


main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 # Solver for ByteCTF 2021 Final, Crypto "Acyclic Group".
import re
import socket
import time
from hashlib import sha256
from random import randint
from tqdm import tqdm
from typing import Callable, Union


def gcd(a, b):
    """Return the greatest common divisor of non-negative integers ``a`` and ``b``."""
    while b:
        a, b = b, a % b
    return a


class Socket:
    """Minimal blocking TCP client that exchanges Latin-1 text."""

    count: int = 0

    def __init__(self, host, port):
        self.s = socket.socket()
        self.s.connect((host, port))

    def send(self, m, output=False, end=b'\n'):
        """Send ``m`` (a str is Latin-1 encoded) followed by ``end``; echo it if ``output``."""
        if isinstance(m, str):
            m = m.encode('Latin1')
        m += end
        if output:
            print("send:", m)
        # ``send`` may transmit only part of the buffer; ``sendall`` retries
        # until everything is written.
        self.s.sendall(m)

    def recv(self, rec_bytes=8192, output=False):
        """Return one ``recv`` call's data (at most ``rec_bytes``) decoded as Latin-1.

        NOTE(review): a single ``recv`` may return a partial or merged message;
        this code assumes each server reply arrives in one piece.
        """
        m = self.s.recv(rec_bytes).decode('Latin1')
        if output:
            print("recv:\n{}\n".format(m))
        return m

    def close(self):
        self.s.close()

    def pow(self):
        """Solve the proof of work: find i with SHA256(nonce + str(i)) hex starting "000000"."""
        print("solving POW ...")
        start_time = time.time()
        resp = self.recv()
        prefix = re.findall(r'SHA256\("(\w{8})"', resp)[0]
        for i in range(1000000000):
            message = (prefix + str(i)).encode("Latin-1")
            if sha256(message).digest().hex().startswith("000000"):
                self.send(str(i))
                print(f"POW cost {time.time() - start_time:.2f} seconds")
                time.sleep(1)
                return


primes = [
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53,
    59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131,
]
N = len(primes)
MAX_EXPONENT = 16
# prod(primes) ** MAX_EXPONENT: a multiple of every modulus the server can draw.
primes_product_max_exp = 1
primes_exp = []  # primes_exp[i][j] = primes[i] ** j
for p in primes:
    x = [1] * (MAX_EXPONENT + 1)
    for j in range(1, MAX_EXPONENT + 1):
        x[j] = x[j - 1] * p
    primes_exp.append(x)
    primes_product_max_exp *= x[MAX_EXPONENT]


def solve_round(oracle):
    """Guess the round's secret modulus using at most two oracle queries.

    ``oracle(x)`` must return ``pow(x, e, n)`` for the round's secret ``e`` and
    ``n``.  The server reads only three inputs per round and the final guess is
    one of them, so no more than two queries may be spent here.

    Query ``i`` (one of the last two primes) sends ``m = P // p_i**16 - 1`` with
    ``P = primes_product_max_exp``.  Since ``m`` is ``-1`` modulo ``p_j**16`` for
    every ``j != i``, the reply is ``(-1)**e`` modulo ``p_j**a_j`` (``a_j`` being
    the exponent of ``p_j`` in ``n``).  The largest ``k`` for which the reply is
    ``+-1`` modulo ``p_j**k`` is therefore at least ``a_j`` and serves as the
    exponent guess; the signs of those residues vote on the parity of ``e``.

    * ``e`` odd (decided after the first query): ``P - 1`` is ``-1`` modulo
      ``n``, so the second query returns ``n - 1`` and ``n`` is exact.
    * ``e`` even: after the second query, the guesses for all primes but the
      last two are trimmed with ``gcd(guess, |c1 - c2|)``, because both replies
      are ``1`` modulo those prime powers.
    """
    guess_exp = [MAX_EXPONENT] * N
    results = []
    odd, even = 0, 0
    for i in range(N - 2, N):
        m = (primes_product_max_exp // primes_exp[i][MAX_EXPONENT]) - 1
        c: int = oracle(m)
        results.append(c)
        for j in range(N):
            if i == j:
                continue
            exp = 0
            for k in range(1, MAX_EXPONENT + 1):
                # Check c % (primes[j] ** k) == -1 or 1
                mod = primes_exp[j][k]
                rem = c % mod
                if rem != 1 and rem != mod - 1:
                    break
                exp = k
                if mod == 2:
                    # Cannot deduce the parity of `e` if mod is 2
                    continue
                if rem == 1:
                    even += 1
                if rem == mod - 1:
                    odd += 1
            guess_exp[j] = min(guess_exp[j], exp)
        if i == N - 2 and odd > even:
            # `e` is odd. Spend the second (last) query on `-1` so that
            # n = response + 1. Deciding after the first query keeps the round
            # within the server's budget of two queries plus the guess.
            return oracle(primes_product_max_exp - 1) + 1
    # `e` is even.
    g = abs(results[0] - results[1])
    res = 1
    for i in range(N - 2):
        res *= primes_exp[i][guess_exp[i]]
    res = gcd(res, g)
    for i in range(N - 2, N):
        res *= primes_exp[i][guess_exp[i]]
    return res


def solve_real_challenge():
    """Connect to the local challenge server and retry until the flag is printed."""
    T = 256
    while True:
        s: Socket = Socket("localhost", 23334)
        try:
            s.pow()
            s.recv()

            def oracle(x: int):
                s.send(str(x))
                return int(s.recv().strip())

            passed = 0
            resp = ""
            start_time = time.time()
            for _ in range(T):
                modulus = solve_round(oracle)
                s.send(str(modulus))
                resp = s.recv()
                if "GOOD SHOT" in resp:
                    passed += 1
        finally:
            # Close the connection even if the protocol fails mid-run.
            s.close()
        print(f"passed {passed}/{T}, cost {time.time() - start_time:.2f} seconds")
        if "CONGRATULATIONS" in resp:
            print(resp)
            return
        print("try again\n")


def benchmark_mock_test():
    """Measure solve_round's success rate against locally simulated rounds."""
    def gen_modulus() -> int:
        n = 1
        for i in range(N):
            n *= primes_exp[i][randint(1, 16)]
        return n

    def gen_oracle(n: int) -> Callable[[int], Union[int, str]]:
        e = randint(1, n)
        calls = 0

        def oracle(m: int):
            nonlocal calls
            calls += 1
            # Mirror the real server: three inputs per round, one of which is
            # the guess, so a solver may make at most two queries.
            assert calls <= 2, "solve_round exceeded the 2-query budget"
            return pow(m, e, n)
        return oracle

    t = time.time()
    T = 10000
    passed = 0
    for _ in tqdm(range(T)):
        n = gen_modulus()
        o = gen_oracle(n)
        res = solve_round(o)
        if res == n:
            passed += 1
    print(f"{passed}/{T}: {passed / T:.3f}")
    print(f"cost {time.time() - t:.1f} seconds")


# solve_real_challenge()
benchmark_mock_test()