Created
June 27, 2024 16:53
-
-
Save theoremoon/341ea3bc0287a02a19bcd89ee021f9a0 to your computer and use it in GitHub Desktop.
Solution scripts for IDEA (solve.py was written in cooperation with keymoon; exploration3.rs was written by y011d4)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use rayon::prelude::*; | |
use serde::{Deserialize, Serialize}; | |
use serde_json::Result; | |
/// Multiplies two 16-bit words modulo 2^16 + 1 (IDEA-style multiplication).
///
/// An operand equal to `0` is treated as 2^16. The product is reduced
/// modulo 65537 and then modulo 65536, so a product of 2^16 is returned as `0`.
fn _mul(x: u16, y: u16) -> u16 {
    const WORD: u64 = 1 << 16;
    const MODULUS: u64 = WORD + 1;
    // Zero is the encoding of 2^16 in this arithmetic.
    let lift = |v: u16| if v == 0 { WORD } else { u64::from(v) };
    (lift(x) * lift(y) % MODULUS % WORD) as u16
}
/// Adds two 16-bit words modulo 2^16.
///
/// Uses wrapping arithmetic explicitly: a plain `x + y` panics on overflow
/// in builds with overflow checks enabled (e.g. debug/test builds).
fn _add(x: u16, y: u16) -> u16 {
    x.wrapping_add(y)
}
/// Computes `base^exponent mod modulus` by square-and-multiply.
///
/// The result is truncated to 16 bits on return, so a value of 2^16
/// (possible when `modulus` is 65537) comes back as `0`.
///
/// Intermediate products are held in `u128` so that squaring cannot
/// overflow even for moduli larger than 2^32.
///
/// # Panics
///
/// Panics if `modulus` is zero (remainder by zero).
fn _modpow(base: u16, exponent: usize, modulus: usize) -> u16 {
    let modulus = modulus as u128;
    let mut exponent = exponent;
    let mut base = u128::from(base) % modulus;
    // `1 % modulus` keeps the result correct (zero) when modulus == 1.
    let mut result = 1 % modulus;
    while exponent > 0 {
        if exponent & 1 == 1 {
            result = result * base % modulus;
        }
        exponent >>= 1;
        base = base * base % modulus;
    }
    result as u16
}
fn _inv(x: u16) -> u16 { | |
if x == 0 { | |
return 0; | |
} | |
_modpow(x, 65535, 2_usize.pow(16) + 1) | |
} | |
/// Splits a 64-bit block into its four 16-bit words, most significant first.
fn decomposition(x: u64) -> (u16, u16, u16, u16) {
    // The `as u16` cast keeps only the low 16 bits of the shifted value.
    let word = |shift: u32| (x >> shift) as u16;
    (word(48), word(32), word(16), word(0))
}
#[derive(Serialize, Deserialize)] | |
struct Input { | |
k16: u16, | |
k17: u16, | |
k18: u16, | |
k19: u16, | |
k20: u16, | |
k21: u16, | |
pt: u64, | |
ct: u64, | |
} | |
#[derive(Serialize, Deserialize)] | |
struct Output { | |
key: u128, | |
} | |
struct IDEA { | |
rounds: usize, | |
keys: Vec<Vec<u16>>, | |
} | |
impl IDEA { | |
fn new(key: u128) -> Self { | |
let mut key = key; | |
let mut sub_keys: Vec<u16> = Vec::new(); | |
for i in 0..((3 + 1) * 6) { | |
sub_keys.push(((key >> (112 - 16 * (i % 8))) & 0xffff) as u16); | |
if i % 8 == 7 { | |
key = (key << 25) | (key >> 103); | |
} | |
} | |
let mut keys: Vec<Vec<u16>> = Vec::new(); | |
for i in 0..4 { | |
let round_keys = sub_keys[i * 6..(i + 1) * 6].to_vec(); | |
keys.push(round_keys); | |
} | |
IDEA { keys, rounds: 3 } | |
} | |
fn encrypt(&self, plaintxt: u64) -> u64 { | |
let mut x1 = (plaintxt >> 48) as u16; | |
let mut x2 = (plaintxt >> 32) as u16; | |
let mut x3 = (plaintxt >> 16) as u16; | |
let mut x4 = plaintxt as u16; | |
for i in 0..self.rounds { | |
let round_keys = &self.keys[i]; | |
let (k1, k2, k3, k4, k5, k6) = ( | |
round_keys[0], | |
round_keys[1], | |
round_keys[2], | |
round_keys[3], | |
round_keys[4], | |
round_keys[5], | |
); | |
(x1, x2, x3, x4) = (_mul(x1, k1), _add(x2, k2), _add(x3, k3), _mul(x4, k4)); | |
let t0 = _mul(k5, x1 ^ x3); | |
let t1 = _mul(k6, _add(t0, x2 ^ x4)); | |
let t2 = _add(t0, t1); | |
(x1, x2, x3, x4) = (x1 ^ t1, x3 ^ t1, x2 ^ t2, x4 ^ t2); | |
} | |
let round_keys = &self.keys[self.rounds]; | |
let (k1, k2, k3, k4, _k5, _k6) = ( | |
round_keys[0], | |
round_keys[1], | |
round_keys[2], | |
round_keys[3], | |
round_keys[4], | |
round_keys[5], | |
); | |
let (y1, y2, y3, y4) = (_mul(x1, k1), _add(x3, k2), _add(x2, k3), _mul(x4, k4)); | |
return ((y1 as u64) << 48) | ((y2 as u64) << 32) | ((y3 as u64) << 16) | y4 as u64; | |
} | |
} | |
fn main() -> Result<()> { | |
let args: Vec<String> = std::env::args().collect(); | |
if args.len() != 3 { | |
eprintln!("Usage: {} <input_path> <output_path>", args[0]); | |
std::process::exit(1); | |
} | |
let input_path = &args[1]; // /home/y011d4/googlectf/crypto/idea/data3 | |
let output_path = &args[2]; // /home/y011d4/googlectf/crypto/idea/output3.json | |
println!("input_path: {}", input_path); | |
println!("output_path: {}", output_path); | |
let data_str = std::fs::read_to_string(input_path).expect("Unable to read file"); | |
let data: Input = serde_json::from_str(&data_str)?; | |
let k16 = data.k16; | |
let k17 = data.k17; | |
let k18 = data.k18; | |
let k19 = data.k19; | |
let k20 = data.k20; | |
let k21 = data.k21; | |
(0..2_usize.pow(32)).into_par_iter().for_each(|x| { | |
let k0 = ((k20 & 0b11) << 14) | ((k21 & 0b1111111111111100) >> 2); | |
let k1 = ((k21 & 0b11) << 14) | (((x & 0b11111111111111000000000000000000) >> 18) as u16); | |
let k2 = ((x & 0b111111111111111100) >> 2) as u16; | |
let k3 = (((x & 0b11) << 14) as u16) | ((k16 & 0b1111111111111100) >> 2); | |
let k4 = ((k16 & 0b11) << 14) | ((k17 & 0b1111111111111100) >> 2); | |
let k5 = ((k17 & 0b11) << 14) | ((k18 & 0b1111111111111100) >> 2); | |
let k6 = ((k18 & 0b11) << 14) | ((k19 & 0b1111111111111100) >> 2); | |
let k7 = ((k19 & 0b11) << 14) | ((k20 & 0b1111111111111100) >> 2); | |
let key: u128 = (k0 as u128) << 112 | |
| (k1 as u128) << 96 | |
| (k2 as u128) << 80 | |
| (k3 as u128) << 64 | |
| (k4 as u128) << 48 | |
| (k5 as u128) << 32 | |
| (k6 as u128) << 16 | |
| (k7 as u128); | |
let idea = IDEA::new(key); | |
let plaintxt = data.pt; | |
let ciphertxt = data.ct; | |
if idea.encrypt(plaintxt) == ciphertxt { | |
println!("key: {}", key); | |
let output = Output { key }; | |
let output_str = serde_json::to_string(&output).unwrap(); | |
std::fs::write(output_path, output_str).expect("Unable to write file"); | |
} | |
}); | |
Ok(()) | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// The first subkey group is the top six words of the key; the last
    /// group is taken after the key has been rotated.
    #[test]
    fn test_new() {
        let key: u128 = 0x11112222333344445555666677778888;
        let idea = IDEA::new(key);
        assert_eq!(
            idea.keys[0],
            vec![0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666]
        );
        assert_eq!(
            idea.keys[3],
            vec![0x9999, 0xddde, 0x2220, 0x4444, 0x8888, 0xcccd]
        );
    }

    /// Known-answer test for a single block.
    #[test]
    fn test_encrypt() {
        let key: u128 = 0x11112222333344445555666677778888;
        let idea = IDEA::new(key);
        let plaintxt: u64 = 0x123456789abcdef0;
        let ciphertxt = idea.encrypt(plaintxt);
        assert_eq!(ciphertxt, 0x39ae247c0d1b0e);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"encoding/json" | |
"fmt" | |
"os" | |
) | |
// _mul multiplies two 16-bit words modulo 2^16 + 1 (IDEA-style).
//
// An operand of 0 stands for 2^16. The product is reduced modulo 65537 and
// then modulo 65536, so a product of 2^16 is returned as 0.
//
// The arithmetic is done in uint64: with both operands equal to 0 the
// product is 2^32, which would wrap to 0 in uint32 and yield 0 instead of 1.
func _mul(x, y uint16) uint16 {
	const word = 1 << 16
	xx := uint64(x)
	yy := uint64(y)
	if x == 0 {
		xx = word
	}
	if y == 0 {
		yy = word
	}
	z := xx * yy % (word + 1)
	return uint16(z % word)
}
// egcd is the recursive extended Euclidean algorithm.
//
// It returns (g, x, y) with g = gcd(a, b) and a*x + b*y == g.
func egcd(a, b int) (int, int, int) {
	if a == 0 {
		// gcd(0, b) = b = 0*0 + b*1.
		return b, 0, 1
	}
	g, x, y := egcd(b%a, a)
	return g, y - (b/a)*x, x
}
func _inv(v uint16) uint16 { | |
m := 65537 | |
g, x, _ := egcd(int(v), m) | |
if g != 1 { | |
return 0 | |
} | |
return uint16((m + x) % m) | |
} | |
// _add adds two 16-bit words modulo 2^16.
// uint16 arithmetic already wraps, so no explicit mask is needed.
func _add(x, y uint16) uint16 {
	return x + y
}
// _sub subtracts y from x modulo 2^16.
// uint16 arithmetic already wraps, so no explicit mask is needed.
func _sub(x, y uint16) uint16 {
	return x - y
}
func main() { | |
var data struct { | |
Pairs [][][]uint16 `json:"pairs"` | |
K18 int `json:"k18"` | |
K19 int `json:"k19"` | |
K20 int `json:"k20"` | |
K21 int `json:"k21"` | |
} | |
err := json.Unmarshal([]byte(os.Args[1]), &data) | |
if err != nil { | |
fmt.Println(err) | |
return | |
} | |
k18 := uint16(data.K18) | |
k19 := uint16(data.K19) | |
k20 := uint16(data.K20) | |
k21 := uint16(data.K21) | |
k18inv := _inv(k18) | |
k21inv := _inv(k21) | |
for k16 := 0; k16 < 65536; k16++ { | |
for k17 := 0; k17 < 65536; k17++ { | |
ok := true | |
for _, ct := range data.Pairs { | |
x1 := _mul(ct[0][0], k18inv) | |
x1_ := _mul(ct[1][0], k18inv) | |
x2 := _sub(ct[0][1], k19) | |
x2_ := _sub(ct[1][1], k19) | |
q := _sub(ct[0][1], k19) ^ _mul(ct[0][3], k21inv) | |
q_ := _sub(ct[1][1], k19) ^ _mul(ct[1][3], k21inv) | |
p := _mul(ct[0][0], k18inv) ^ _sub(ct[0][2], k20) | |
s := _mul(p, uint16(k16)) | |
t := _mul(_add(q, s), uint16(k17)) | |
t_ := _mul(_add(q_, s), uint16(k17)) | |
u := _add(s, t) | |
u_ := _add(s, t_) | |
if !((x1^t == x1_^t_) && (x2^u == x2_^u_)) { | |
ok = false | |
break | |
} | |
} | |
if ok { | |
fmt.Printf("%d %d\n", k16, k17) | |
} | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import time | |
from typing import Deque | |
from ptrlib import Socket | |
from itertools import combinations | |
from collections import Counter, deque | |
def _mul(x, y): | |
if x == 0: | |
x = 2**16 | |
if y == 0: | |
y = 2**16 | |
z = x * y % (2**16 + 1) | |
return z % 2**16 | |
def _add(x, y): | |
return (x + y) & 0xffff | |
def _sub(x, y): | |
return (x - y) & 0xffff | |
def _inv(x): | |
if x == 0: | |
return pow(2**16, -1, 2**16 + 1) | |
return pow(x, -1, 2**16 + 1) | |
def decompose(v):
    """Split a 64-bit block into four 16-bit words, most significant first."""
    x1 = (v >> 48) & 0xffff
    x2 = (v >> 32) & 0xffff
    x3 = (v >> 16) & 0xffff
    x4 = v & 0xffff
    return x1, x2, x3, x4
def search_k18(pairs, pairs2):
    """Return every k18 consistent with the first N pairs.

    For each index i < N, the ciphertexts of pairs[i] and pairs2[i] are
    compared: a candidate k18 survives only if the lowest bit of
    _mul(y1, k18inv) ^ _mul(y1_, k18inv) equals the lowest bit of y3 ^ y3_
    for all N pairs. N is read from module scope.
    """
    # The ciphertext words do not depend on k18; decompose them once.
    observed = []
    for i in range(N):
        _, ct = pairs[i]
        _, ct_ = pairs2[i]
        y1, _, y3, _ = decompose(ct)
        y1_, _, y3_, _ = decompose(ct_)
        observed.append((y1, y1_, (y3 ^ y3_) & 1))

    k18_cands = []
    for k18 in range(2**16):
        k18inv = _inv(k18)
        for y1, y1_, parity in observed:
            tdiff = _mul(y1, k18inv) ^ _mul(y1_, k18inv)
            if tdiff & 1 != parity:
                break
        else:
            k18_cands.append(k18)
    return k18_cands
def search_k20(pairs, pairs2, k18):
    """Return every k20 consistent with the first N pairs for a given k18.

    A candidate k20 survives only if, for every index i < N,
    _mul(y1, k18inv) ^ _mul(y1_, k18inv) equals
    _sub(y3, k20) ^ _sub(y3_, k20). N is read from module scope.
    """
    k18inv = _inv(k18)
    # The left-hand difference does not depend on k20; compute it once.
    observed = []
    for i in range(N):
        _, ct = pairs[i]
        _, ct_ = pairs2[i]
        y1, _, y3, _ = decompose(ct)
        y1_, _, y3_, _ = decompose(ct_)
        observed.append((_mul(y1, k18inv) ^ _mul(y1_, k18inv), y3, y3_))

    candidates = []
    for k20 in range(2**16):
        for tdiff, y3, y3_ in observed:
            if tdiff != (_sub(y3, k20) ^ _sub(y3_, k20)):
                break
        else:
            candidates.append(k20)
    return candidates
def check_is_good_pair(pair, pair3, k18inv, k20):
    """Return True when the two ciphertexts show one of the accepted differences.

    pair and pair3 are (plaintext, ciphertext) tuples. The check compares
    _mul(y1, k18inv) ^ _mul(y1_, k18inv) against
    _sub(y3, k20) ^ _sub(y3_, k20 ^ 2) and accepts when their XOR is
    0x100 or 0x300.
    """
    _, ct = pair
    _, ct_ = pair3
    y1, _, y3, _ = decompose(ct)
    y1_, _, y3_, _ = decompose(ct_)
    ydiff = _mul(y1, k18inv) ^ _mul(y1_, k18inv)
    # With bit 127 masked, k20 itself carries a difference, so account for it.
    ydiff2 = _sub(y3, k20) ^ _sub(y3_, k20 ^ 2)
    # A wider set (0x100, 0x300, 0x700, ..., 0xff00) was considered; only the
    # two smallest patterns are accepted.
    return (ydiff ^ ydiff2) in [0x100, 0x300]
def search_k21(good_pairs, k18, k20):
    """Return the set of k21 candidates found by matching against k19.

    For every 2-combination of good_pairs, a table maps the tuple of
    differences _mul(y4, k21inv) ^ _mul(y4_, k21inv) to k21 (a later k21
    with the same tuple overwrites an earlier one). Every k19 is then tried:
    when the tuple of differences _sub(y2, k19) ^ _sub(y2_, k19) is present
    in the table, the stored k21 is added to the result.

    k18 and k20 are accepted for interface compatibility but not used.
    """
    candidates = set()
    # Pick two good pairs at a time, hoping both of them are truly good.
    for comb in combinations(good_pairs, 2):
        # The ciphertext words are independent of k21/k19; decompose once.
        words = [(decompose(ct), decompose(ct_)) for _, ct, _, ct_ in comb]

        table = {}
        for k21 in range(2**16):
            k21inv = _inv(k21)
            key = tuple(_mul(w[3], k21inv) ^ _mul(w_[3], k21inv) for w, w_ in words)
            table[key] = k21

        for k19 in range(2**16):
            key = tuple(_sub(w[1], k19) ^ _sub(w_[1], k19) for w, w_ in words)
            if key in table:
                candidates.add(table[key])
    return candidates
def search_k19(pairs, k21):
    """Return the k19 values that match the most pairs for a given k21.

    Each element of pairs is (pt, ct, pt_, ct_). A k19 scores one point per
    pair for which _mul(y4, k21inv) ^ _mul(y4_, k21inv) equals
    _sub(y2, k19) ^ _sub(y2_, k19). All k19 sharing the top score are
    returned; the list is empty when nothing matches at all.
    """
    k21inv = _inv(k21)
    # The _mul difference does not depend on k19; compute it once per pair.
    observed = []
    for _, ct, _, ct_ in pairs:
        _, y2, _, y4 = decompose(ct)
        _, y2_, _, y4_ = decompose(ct_)
        observed.append((y2, y2_, _mul(y4, k21inv) ^ _mul(y4_, k21inv)))

    counter = Counter()
    for k19 in range(2**16):
        hits = sum(
            1
            for y2, y2_, udiff in observed
            if udiff == (_sub(y2, k19) ^ _sub(y2_, k19))
        )
        # Only record k19 values that matched, as the score table is sparse.
        if hits:
            counter[k19] += hits

    common = counter.most_common()
    if not common:
        return []
    best = common[0][1]
    return [k for k, v in common if v == best]
# FIFO of plaintexts that have been sent but whose replies are still unread.
q: Deque[int] = deque()
# Current server connection; None until reconnect() has run.
sock: Socket = None
def reconnect():
    """Drop any existing connection, reset the timer and queue, and reconnect.

    Sets the module-level `sock` and `start`, and clears `q`.
    """
    global sock, start
    if sock is not None:
        sock.close()
    start = time.time()
    q.clear()
    sock = Socket("idea.2024.ctfcompetition.com", 1337)
    import socket
    raw = sock._sock
    raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    # TCP_QUICKACK is only defined on Linux; skip it where it is unavailable.
    if hasattr(socket, "TCP_QUICKACK"):
        raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK, 1)
def elapsed(msg):
    """Print msg with the seconds elapsed since the module-level `start`."""
    print(msg, 'elapsed:', time.time() - start)
# Driver: repeat the whole attack against a fresh connection until a key is
# accepted or the user interrupts.
while True:
    try:
        print("\n\nstart new search")
        reconnect()

        def encrypt_send(pt: int):
            """Queue an encryption request for pt without waiting for the reply."""
            sock.sendline("1")
            sock.sendline(hex(pt))
            q.append(pt)

        def encrypt_recv():
            """Read the next reply and pair it with the oldest pending plaintext."""
            sock.recvuntil("text: ")
            pt = q.popleft()
            return pt, int(sock.recvline().strip())

        def mask(mask: int):
            """Send the server's option 2 with the given mask."""
            sock.sendline("2")
            sock.sendlineafter("mask: ", hex(mask))

        # step1. find k18, k20
        N = 20  # number of pairs used in step 1
        M = 400  # number of pairs used in step 2
        base = random.randrange(0, 2**48)
        pairs = []
        for i in range(M):
            pt = (random.randrange(0, 2**16) << 48) | base
            encrypt_send(pt)
        for i in range(M):
            pairs.append(encrypt_recv())
        elapsed("recv first wave")

        mask(2**103)
        pairs2 = []
        for i in range(N):
            pt, ct = pairs[i]
            pt_ = pt + 2**(16 + 16 + 7)
            encrypt_send(pt_)
        for i in range(N):
            pairs2.append(encrypt_recv())
        elapsed("recv second wave")

        k18_cands = search_k18(pairs, pairs2)
        elapsed("k18 found")
        if len(k18_cands) == 0:
            # Try the "-" pattern instead.
            for i in range(N):
                pt, ct = pairs[i]
                pt_ = pt - 2**(16 + 16 + 7)
                encrypt_send(pt_)
            for i in range(N):
                pairs2[i] = encrypt_recv()
            k18_cands = search_k18(pairs, pairs2)
        assert len(k18_cands) > 0, "k18 not found"

        for k18 in k18_cands:
            k20_cands = search_k20(pairs, pairs2, k18)
            if len(k20_cands) > 0:
                # Hard-coded guess: take the second candidate when there is
                # one, otherwise fall back to the only candidate instead of
                # raising IndexError.
                k20 = k20_cands[1] if len(k20_cands) > 1 else k20_cands[0]
                break
        assert len(k20_cands) > 0, "k20 not found"
        elapsed("k20 found")
        print("k18:", k18)
        print("k20:", k20_cands)

        # step2. find k21, k19
        mask((2**103) ^ (2**127))
        good_pairs = []  # pairs where the k0 diff is cancelled; false positives exist
        k18inv = _inv(k18)
        for _ in range(M):
            pt_ = (random.randrange(0, 2**16) << 48) | base
            encrypt_send(pt_)
        for _ in range(M):
            pt_, ct_ = encrypt_recv()
            for i in range(M):
                pt, ct = pairs[i]
                if check_is_good_pair((pt, ct), (pt_, ct_), k18inv, k20):
                    good_pairs.append([pt, ct, pt_, ct_])
                    # For a given pt only one pt_ is good, so stop here.
                    break
            if len(good_pairs) >= 10:
                break
        q.clear()  # replies may be left unread because of the break above
        assert len(good_pairs) >= 4, "good_pairs not found"
        elapsed("good pair found")

        k21_cands = search_k21(good_pairs[:5], k18, k20)
        assert len(k21_cands) > 0, "k21 not found"
        print("k21:", k21_cands)
        elapsed("k21 found")

        pairs_for_k16 = []
        for i in range(N):
            _, ct = pairs[i]
            _, ct_ = pairs2[i]
            y1, y2, y3, y4 = decompose(ct)
            y1_, y2_, y3_, y4_ = decompose(ct_)
            pairs_for_k16.append([[y1, y2, y3, y4], [y1_, y2_, y3_, y4_]])

        import json
        args = []
        for k21 in k21_cands:
            k19_cands = search_k19(good_pairs, k21)
            print("k19:", k19_cands)
            for k19 in k19_cands:
                args.append(json.dumps({
                    "pairs": pairs_for_k16,
                    "k18": k18,
                    "k19": k19,
                    "k20": k20,
                    "k21": k21,
                }))
        print('len:', len(args))
        elapsed("start processing")

        import os
        import subprocess
        import concurrent.futures
        from multiprocessing import Pool

        def solve_k16k17(arg):
            """Run the external k16/k17 solver; return (arg, stdout) if it printed anything."""
            output = subprocess.run(["./go_solver/k16k17", arg], capture_output=True)
            k16k17 = output.stdout.decode().strip()
            if k16k17 != "":
                return arg, k16k17

        found_k = False
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = {executor.submit(solve_k16k17, arg): arg for arg in args}
            for future in concurrent.futures.as_completed(futures):
                try:
                    res = future.result()
                    print(res)
                    if res is None:
                        continue
                    arg, k16k17 = res
                    found_k = True
                    elapsed("k16k17 found")
                    break
                except Exception as exc:
                    print(f"Task {futures[future]} generated an exception: {exc}")
            # Intended to stop the other workers.
            # NOTE(review): leaving the `with` block still waits for running
            # tasks to finish — confirm whether that is acceptable.
            executor.shutdown(wait=False)
        elapsed("end")
        if not found_k:
            print(":(")
            continue

        print(arg)
        parsed = json.loads(arg)
        k19 = parsed["k19"]
        k21 = parsed["k21"]
        # The solver prints one "k16 k17" line per surviving candidate; use
        # the first line so that several candidates do not break the unpack.
        k16, k17 = map(int, k16k17.splitlines()[0].split())
        print("k16k17:", k16k17)
        print("k19:", k19)
        print("k21:", k21)

        pt, ct = pairs[0]
        ALL_KEY_PATH = "exploration3/target/release/exploration3"
        INPUT3_PATH = "./input3.json"
        OUTPUT3_PATH = "./output3.json"
        with open(INPUT3_PATH, "w") as fp:
            json.dump(
                {
                    # These are expected to have been recovered already.
                    "k16": k16,
                    "k17": k17,
                    "k18": k18,
                    "k19": k19,
                    "k20": k20,
                    "k21": k21,
                    "pt": pt,
                    "ct": ct,  # one pair with encrypt(pt) == ct (both ints)
                },
                fp,
            )

        # Remove any result left over from an earlier attempt so that a
        # failed search cannot be mistaken for a success.
        try:
            os.remove(OUTPUT3_PATH)
        except FileNotFoundError:
            pass

        elapsed("start all key")
        subprocess.run([ALL_KEY_PATH, INPUT3_PATH, OUTPUT3_PATH])
        elapsed("done")
        try:
            with open(OUTPUT3_PATH, "r") as fp:
                output = json.load(fp)
        except FileNotFoundError:
            print("Fail...")
            continue
        print(output)

        key = hex(output["key"])[2:]
        sock.sendlineafter("Get balance", "3")
        sock.sendlineafter("key_guess:", key)
        sock.interactive()
        break
    except KeyboardInterrupt:
        break
    except Exception as e:
        print(e)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment