Skip to content

Instantly share code, notes, and snippets.

@Chrstm
Last active December 15, 2021 09:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Chrstm/4b831b4e795eecb99945169c56ec46b4 to your computer and use it in GitHub Desktop.
Save Chrstm/4b831b4e795eecb99945169c56ec46b4 to your computer and use it in GitHub Desktop.
ByteCTF 2021 Final Crypto - Acyclic Group
#!/usr/bin/env python3
from hashlib import sha256
from random import choices, getrandbits, randint
import signal
import string
from flag import FLAG
def proof_of_work() -> bool:
    """Ask the client for a suffix so SHA-256(nonce + suffix) starts with "000000".

    A random 8-character alphanumeric nonce is printed; one line is read from
    stdin, stripped, and appended to the nonce before hashing.

    Returns:
        True iff the hex digest begins with six zero hex digits.
    """
    charset = string.ascii_letters + string.digits
    challenge = "".join(choices(charset, k=8))
    print(f'SHA256("{challenge}" + ?) starts with "000000"')
    suffix = input().strip()
    digest = sha256((challenge + suffix).encode()).hexdigest()
    return digest.startswith("000000")
def gen_modulus() -> int:
    """Return a random modulus: each of the first 32 primes raised to a power in [1, 16]."""
    small_primes = (
        2, 3, 5, 7, 11, 13, 17, 19,
        23, 29, 31, 37, 41, 43, 47, 53,
        59, 61, 67, 71, 73, 79, 83, 89,
        97, 101, 103, 107, 109, 113, 127, 131,
    )
    # Exponents are drawn in prime order, one randint call per prime.
    exponents = [randint(1, 16) for _ in small_primes]
    modulus = 1
    for prime, k in zip(small_primes, exponents):
        modulus *= prime ** k
    return modulus
def test() -> bool:
    """Play one round: answer up to three client inputs; report whether n was found.

    A fresh modulus n and exponent e are drawn. The client may send up to three
    decimal integers. Sending n itself wins the round; any other value x is
    answered with pow(x, e, n). EOF or an input that is not a plain decimal
    integer forfeits the round.

    Returns:
        True if the client sent n within three inputs, False otherwise.
    """
    n = gen_modulus()
    e = randint(1, n)
    for _ in range(3):
        try:
            num = input().strip()
        except EOFError:
            return False
        # isdecimal rather than isnumeric: isnumeric also accepts characters
        # such as '²' or '½' that int() rejects, which used to crash the round.
        if not num.isdecimal():
            return False
        try:
            num = int(num)
        except ValueError:
            # e.g. more digits than sys.get_int_max_str_digits() allows (3.11+)
            return False
        if num == n:
            return True
        res = pow(num, e, n)
        print(res)
    return False
def main():
    """Run the challenge: proof of work, then 256 rounds; reveal FLAG above 80% wins."""
    signal.alarm(60)  # time limit for the proof of work
    if not proof_of_work():
        return
    print("Listen...I have some acyclic groups for you..."
          "No noise this time...God bless you get them...")
    rounds = 256
    # Re-arming replaces the earlier alarm: 256 seconds for all rounds together.
    signal.alarm(rounds)
    wins = 0
    for idx in range(rounds):
        print("Round", idx)
        won = test()
        print("GOOD SHOT, MY FRIEND" if won else "CALM DOWN, MY FRIEND")
        wins += 1 if won else 0
    if wins > rounds * 0.8:
        print("CONGRATULATIONS", FLAG)
# Start the challenge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
import math
import re
import socket
import time
from hashlib import sha256
from random import randint
from typing import Callable, Union

from tqdm import tqdm
def gcd(a, b):
    """Return the greatest common divisor of integers a and b.

    Delegates to math.gcd, which is implemented in C and always returns a
    non-negative result. The plain Euclidean loop it replaces could return a
    negative value for negative inputs (e.g. gcd(4, -6) gave -2).
    gcd(0, 0) is 0.
    """
    return math.gcd(a, b)
class Socket:
    """Small TCP client wrapper used to talk to the challenge server.

    Text is encoded/decoded as Latin-1. Usable as a context manager, which
    closes the connection on exit.
    """

    count: int = 0  # NOTE(review): not referenced anywhere in this script

    def __init__(self, host, port):
        self.s = socket.socket()
        self.s.connect((host, port))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()

    def send(self, m, output=False, end=b'\n'):
        """Send `m` (str or bytes-like) followed by `end`.

        Args:
            m: payload; a str is encoded as Latin-1.
            output: if True, print the raw bytes being sent.
            end: terminator appended to the payload.
        """
        if isinstance(m, str):
            m = m.encode('Latin1')
        # Build a new object instead of `m += end`, which would mutate a
        # caller-supplied bytearray in place.
        m = m + end
        if output:
            print("send:", m)
        # socket.send may transmit only part of the buffer; sendall retries
        # until everything is written.
        self.s.sendall(m)

    def recv(self, rec_bytes=8192, output=False):
        """Return whatever a single recv call yields (at most `rec_bytes`), as Latin-1 text.

        Returns "" once the peer has closed the connection.
        """
        m = self.s.recv(rec_bytes).decode('Latin1')
        if output:
            print("recv:\n{}\n".format(m))
        return m

    def close(self):
        self.s.close()

    def pow(self):
        """Solve the server's proof of work: find i with SHA256(nonce + str(i)) starting "000000".

        Raises:
            ValueError: if the received text contains no SHA256("<8 chars>" challenge.
            RuntimeError: if no suffix below 10**9 satisfies the challenge.
        """
        print("solving POW ...")
        start_time = time.time()
        resp = self.recv()
        match = re.search(r'SHA256\("(\w{8})"', resp)
        if match is None:
            raise ValueError(f"no proof-of-work challenge found in {resp!r}")
        prefix = match.group(1)
        # Hash the fixed prefix once and clone the state for each candidate.
        base = sha256(prefix.encode("Latin-1"))
        for i in range(1000000000):
            h = base.copy()
            h.update(str(i).encode("Latin-1"))
            if h.hexdigest().startswith("000000"):
                self.send(str(i))
                print(f"POW cost {time.time() - start_time:.2f} seconds")
                # presumably gives the server time to answer before the next
                # recv — confirm
                time.sleep(1)
                return
        raise RuntimeError("no proof-of-work solution found")
# The small primes whose powers make up every modulus the server generates.
primes = [
    2, 3, 5, 7, 11, 13, 17, 19,
    23, 29, 31, 37, 41, 43, 47, 53,
    59, 61, 67, 71, 73, 79, 83, 89,
    97, 101, 103, 107, 109, 113, 127, 131,
]
N = len(primes)
MAX_EXPONENT = 16  # largest exponent the server's gen_modulus can draw

# Lookup table: primes_exp[i][j] == primes[i] ** j for 0 <= j <= MAX_EXPONENT.
primes_exp = [[p ** j for j in range(MAX_EXPONENT + 1)] for p in primes]

# Product of primes[i] ** MAX_EXPONENT over every i, i.e. a common multiple
# of all moduli the server can produce.
primes_product_max_exp = 1
for row in primes_exp:
    primes_product_max_exp *= row[MAX_EXPONENT]
def solve_round(oracle):
    """Reconstruct the round's hidden modulus n from two oracle queries.

    `oracle(m)` is assumed to return pow(m, e, n), with
    n = prod(primes[j] ** a_j), 1 <= a_j <= MAX_EXPONENT, and e unknown.

    Method:
    - For each of the last two prime indices i, query
      m_i = primes_product_max_exp // primes[i] ** MAX_EXPONENT - 1.
    - For every j != i, m_i ≡ -1 (mod primes[j] ** a_j). So the reply c_i is
      ≡ (-1) ** e ≡ ±1 modulo each of those prime powers.
    - The largest k with c_i ≡ ±1 (mod primes[j] ** k) is an upper bound on
      a_j; taking the minimum over both replies tightens it.
    - For the first N - 2 primes the bound is also intersected with
      |c_0 - c_1|, which primes[j] ** a_j divides for every such j.

    Why only two queries: the server reads at most three inputs per round,
    the last being the answer. Two queries plus the answer fit that budget.
    The reconstruction needs no parity of e because it only uses c ≡ ±1 and
    the difference of the replies. The former odd-e branch spent a third
    query and so exceeded the budget.

    Returns:
        The reconstructed modulus. This is a best guess: an exponent can
        still be overestimated when a reply is ±1 modulo a higher prime power
        by chance.
    """
    guess_exp = [MAX_EXPONENT] * N
    results = []
    for i in range(N - 2, N):
        m = primes_product_max_exp // primes_exp[i][MAX_EXPONENT] - 1
        c: int = oracle(m)
        results.append(c)
        for j in range(N):
            if j != i:
                guess_exp[j] = min(guess_exp[j], _pm_one_exponent(c, j))

    # c_0 - c_1 ≡ 0 (mod primes[j] ** a_j) for every j < N - 2; the gcd trims
    # exponents that both replies overestimated.
    g = abs(results[0] - results[1])
    res = 1
    for i in range(N - 2):
        res *= primes_exp[i][guess_exp[i]]
    res = gcd(res, g)
    # The last two primes were each observed in only one reply.
    for i in range(N - 2, N):
        res *= primes_exp[i][guess_exp[i]]
    return res


def _pm_one_exponent(c, j):
    """Return the largest k <= MAX_EXPONENT with c ≡ ±1 (mod primes[j] ** k), or 0.

    Powers are checked in increasing order; the scan stops at the first
    power for which c % primes[j] ** k is neither 1 nor primes[j] ** k - 1.
    """
    exp = 0
    for k in range(1, MAX_EXPONENT + 1):
        mod = primes_exp[j][k]
        if c % mod not in (1, mod - 1):
            break
        exp = k
    return exp
def solve_real_challenge():
    """Play against the live server at localhost:23334 until the flag is printed.

    Each attempt:
    - opens a fresh connection and solves the proof of work;
    - plays T = 256 rounds with `solve_round`;
    - prints the server's final message and returns if it contains
      "CONGRATULATIONS", otherwise reconnects and retries.
    """
    T = 256
    while True:
        s: Socket = Socket("localhost", 23334)
        try:
            s.pow()
            s.recv()  # greeting printed after a successful proof of work

            def oracle(x: int) -> int:
                s.send(str(x))
                return int(s.recv().strip())

            passed = 0
            resp = ""
            start_time = time.time()
            for _ in range(T):
                modulus = solve_round(oracle)
                s.send(str(modulus))
                resp = s.recv()
                if "GOOD SHOT" in resp:
                    passed += 1
            if "CONGRATULATIONS" not in resp:
                # The flag is printed after the last verdict and may arrive in
                # a later segment; recv returns "" once the server has closed.
                resp += s.recv()
        finally:
            # Release the connection even if a reply cannot be parsed.
            s.close()
        print(f"passed {passed}/{T}, cost {time.time() - start_time:.2f} seconds")
        if "CONGRATULATIONS" in resp:
            print(resp)
            return
        print("try again\n")
def benchmark_mock_test(T: int = 10000):
    """Estimate how often `solve_round` recovers the modulus against a local mock.

    Each of the `T` rounds builds its modulus like the server's gen_modulus
    (primes[i] ** randint(1, MAX_EXPONENT) for every prime) and draws
    e = randint(1, n). The server reads at most three inputs per round, and
    the last one is the guess. So, as on the real server, a round in which
    `solve_round` makes more than two oracle queries counts as a failure.

    Args:
        T: number of mock rounds to play.
    """
    max_queries = 2  # three inputs per round, minus the final answer

    def gen_modulus() -> int:
        n = 1
        for i in range(N):
            n *= primes_exp[i][randint(1, MAX_EXPONENT)]
        return n

    def run_round() -> bool:
        """Play one mock round; True if n was recovered within the query budget."""
        n = gen_modulus()
        e = randint(1, n)
        queries = 0

        def oracle(m: int) -> int:
            nonlocal queries
            queries += 1
            return pow(m, e, n)

        res = solve_round(oracle)
        return queries <= max_queries and res == n

    t = time.time()
    passed = sum(run_round() for _ in tqdm(range(T)))
    rate = passed / T if T else 0.0
    print(f"{passed}/{T}: {rate:.3f}")
    print(f"cost {time.time() - t:.1f} seconds")
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    # Swap these two lines to attack the live server instead of the local mock.
    # solve_real_challenge()
    benchmark_mock_test()
Display the source blob
Display the rendered blob
Raw
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment