Skip to content

Instantly share code, notes, and snippets.

@theoremoon
Created June 27, 2024 16:53
Show Gist options
  • Save theoremoon/341ea3bc0287a02a19bcd89ee021f9a0 to your computer and use it in GitHub Desktop.
Save theoremoon/341ea3bc0287a02a19bcd89ee021f9a0 to your computer and use it in GitHub Desktop.
Solution scripts for IDEA (solve.py was co-written with keymoon; exploration3.rs was written by y011d4).
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use serde_json::Result;
/// IDEA multiplication: the product of `x` and `y` modulo 2^16 + 1.
///
/// The word `0` stands for 2^16, and a result of 2^16 is encoded back as `0`
/// by the final truncation to 16 bits.
fn _mul(x: u16, y: u16) -> u16 {
    const MODULUS: u64 = (1 << 16) + 1;
    let widen = |v: u16| -> u64 {
        if v == 0 {
            1 << 16
        } else {
            u64::from(v)
        }
    };
    // Truncating to u16 maps 2^16 back to 0 (same as `% 2^16`).
    (widen(x) * widen(y) % MODULUS) as u16
}
/// IDEA addition: `x + y` modulo 2^16.
///
/// Uses `wrapping_add` because plain `+` on `u16` panics on overflow in debug
/// builds (e.g. under `cargo test`), while IDEA requires wrap-around.
fn _add(x: u16, y: u16) -> u16 {
    x.wrapping_add(y)
}
/// Computes `base^exponent mod modulus` by square-and-multiply.
///
/// Arithmetic is done in `u64`, so `modulus` must stay below 2^32 to avoid
/// overflow; the result is truncated to `u16`.
fn _modpow(base: u16, exponent: usize, modulus: usize) -> u16 {
    let m = modulus as u64;
    let mut acc: u64 = 1;
    let mut square = u64::from(base) % m;
    let mut remaining = exponent as u64;
    while remaining != 0 {
        if remaining & 1 == 1 {
            acc = acc * square % m;
        }
        square = square * square % m;
        remaining >>= 1;
    }
    acc as u16
}
/// Multiplicative inverse of `x` modulo 2^16 + 1, computed via Fermat's little
/// theorem (`x^(p-2)` with p = 65537).
///
/// `0` encodes 2^16 ≡ -1, which is its own inverse, so `0` maps to `0`.
fn _inv(x: u16) -> u16 {
    match x {
        0 => 0,
        v => _modpow(v, 65535, 2_usize.pow(16) + 1),
    }
}
/// Splits a 64-bit block into its four 16-bit words, most significant first.
fn decomposition(x: u64) -> (u16, u16, u16, u16) {
    let [w0, w1, w2, w3] = [48u32, 32, 16, 0].map(|shift| (x >> shift) as u16);
    (w0, w1, w2, w3)
}
/// JSON input for the brute force in `main`.
///
/// `k16`..`k21` are previously recovered 16-bit subkeys (they fix 96 bits of
/// the master key in `main`). `pt`/`ct` are one known plaintext/ciphertext
/// pair; a candidate key is accepted when it encrypts `pt` to `ct`.
/// Field names double as the JSON keys, so they must not be renamed.
#[derive(Serialize, Deserialize)]
struct Input {
k16: u16,
k17: u16,
k18: u16,
k19: u16,
k20: u16,
k21: u16,
pt: u64,
ct: u64,
}
/// JSON output written by `main`: the recovered 128-bit master key.
/// Serialized as `{"key": <integer>}`.
#[derive(Serialize, Deserialize)]
struct Output {
key: u128,
}
/// Reduced-round IDEA block cipher: three full rounds plus the output
/// transformation.
///
/// `keys` holds `rounds + 1` groups of six 16-bit subkeys. Only the first four
/// subkeys of the last group are used, by the output transformation.
struct IDEA {
    rounds: usize,
    keys: Vec<Vec<u16>>,
}

impl IDEA {
    /// Expands a 128-bit master key with the IDEA key schedule.
    ///
    /// Subkeys are consecutive 16-bit words of a key register, most
    /// significant first; after every eighth word the register is rotated
    /// left by 25 bits.
    fn new(key: u128) -> Self {
        const ROUNDS: usize = 3;
        let total = (ROUNDS + 1) * 6;
        let mut register = key;
        let mut sub_keys: Vec<u16> = Vec::with_capacity(total);
        for idx in 0..total {
            let word = idx % 8;
            // Truncation to u16 keeps only the selected 16-bit word.
            sub_keys.push((register >> (112 - 16 * word)) as u16);
            if word == 7 {
                register = register.rotate_left(25);
            }
        }
        let keys = sub_keys.chunks(6).map(|group| group.to_vec()).collect();
        IDEA {
            keys,
            rounds: ROUNDS,
        }
    }

    /// Encrypts a single 64-bit block.
    fn encrypt(&self, plaintxt: u64) -> u64 {
        let [mut a, mut b, mut c, mut d] = [48u32, 32, 16, 0].map(|s| (plaintxt >> s) as u16);
        for sk in &self.keys[..self.rounds] {
            // Key-mixing layer.
            a = _mul(a, sk[0]);
            b = _add(b, sk[1]);
            c = _add(c, sk[2]);
            d = _mul(d, sk[3]);
            // Multiply-add structure.
            let t0 = _mul(sk[4], a ^ c);
            let t1 = _mul(sk[5], _add(t0, b ^ d));
            let t2 = _add(t0, t1);
            // Mix in the MA outputs; the two middle words trade places.
            (a, b, c, d) = (a ^ t1, c ^ t1, b ^ t2, d ^ t2);
        }
        // Output transformation; it takes the middle words in swapped order.
        let fin = &self.keys[self.rounds];
        let out = [
            _mul(a, fin[0]),
            _add(c, fin[1]),
            _add(b, fin[2]),
            _mul(d, fin[3]),
        ];
        out.iter().fold(0u64, |acc, &w| (acc << 16) | u64::from(w))
    }
}
fn main() -> Result<()> {
let args: Vec<String> = std::env::args().collect();
if args.len() != 3 {
eprintln!("Usage: {} <input_path> <output_path>", args[0]);
std::process::exit(1);
}
let input_path = &args[1]; // /home/y011d4/googlectf/crypto/idea/data3
let output_path = &args[2]; // /home/y011d4/googlectf/crypto/idea/output3.json
println!("input_path: {}", input_path);
println!("output_path: {}", output_path);
let data_str = std::fs::read_to_string(input_path).expect("Unable to read file");
let data: Input = serde_json::from_str(&data_str)?;
let k16 = data.k16;
let k17 = data.k17;
let k18 = data.k18;
let k19 = data.k19;
let k20 = data.k20;
let k21 = data.k21;
(0..2_usize.pow(32)).into_par_iter().for_each(|x| {
let k0 = ((k20 & 0b11) << 14) | ((k21 & 0b1111111111111100) >> 2);
let k1 = ((k21 & 0b11) << 14) | (((x & 0b11111111111111000000000000000000) >> 18) as u16);
let k2 = ((x & 0b111111111111111100) >> 2) as u16;
let k3 = (((x & 0b11) << 14) as u16) | ((k16 & 0b1111111111111100) >> 2);
let k4 = ((k16 & 0b11) << 14) | ((k17 & 0b1111111111111100) >> 2);
let k5 = ((k17 & 0b11) << 14) | ((k18 & 0b1111111111111100) >> 2);
let k6 = ((k18 & 0b11) << 14) | ((k19 & 0b1111111111111100) >> 2);
let k7 = ((k19 & 0b11) << 14) | ((k20 & 0b1111111111111100) >> 2);
let key: u128 = (k0 as u128) << 112
| (k1 as u128) << 96
| (k2 as u128) << 80
| (k3 as u128) << 64
| (k4 as u128) << 48
| (k5 as u128) << 32
| (k6 as u128) << 16
| (k7 as u128);
let idea = IDEA::new(key);
let plaintxt = data.pt;
let ciphertxt = data.ct;
if idea.encrypt(plaintxt) == ciphertxt {
println!("key: {}", key);
let output = Output { key };
let output_str = serde_json::to_string(&output).unwrap();
std::fs::write(output_path, output_str).expect("Unable to write file");
}
});
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Master key shared by both tests.
    const KEY: u128 = 0x1111_2222_3333_4444_5555_6666_7777_8888;

    /// Checks the first and last subkey groups produced by the key schedule.
    #[test]
    fn test_new() {
        let cipher = IDEA::new(KEY);
        let expected: [(usize, [u16; 6]); 2] = [
            (0, [0x1111, 0x2222, 0x3333, 0x4444, 0x5555, 0x6666]),
            (3, [0x9999, 0xddde, 0x2220, 0x4444, 0x8888, 0xcccd]),
        ];
        for (round, want) in expected {
            assert_eq!(cipher.keys[round], want.to_vec());
        }
    }

    /// Checks one known-answer encryption.
    #[test]
    fn test_encrypt() {
        let cipher = IDEA::new(KEY);
        assert_eq!(cipher.encrypt(0x1234_5678_9abc_def0), 0x39ae247c0d1b0e);
    }
}
package main
import (
"encoding/json"
"fmt"
"os"
)
// _mul is IDEA multiplication: x*y modulo 2^16+1, where the word 0 stands
// for 2^16 and a result of 2^16 is encoded back as 0.
//
// The product is computed in 64 bits because 2^16 * 2^16 = 2^32 does not fit
// in a uint32. The previous uint32 version returned 0 instead of 1 for
// _mul(0, 0).
func _mul(x, y uint16) uint16 {
	xx := uint64(x)
	yy := uint64(y)
	if x == 0 {
		xx = 1 << 16
	}
	if y == 0 {
		yy = 1 << 16
	}
	// Converting to uint16 truncates, mapping 2^16 back to 0.
	return uint16(xx * yy % (1<<16 + 1))
}
func egcd(a, b int) (int, int, int) {
if a == 0 {
return b, 0, 1
}
g, x, y := egcd(b%a, a)
return g, y - (b/a)*x, x
}
// _inv returns the multiplicative inverse of v modulo 2^16+1.
//
// For v == 0 the gcd with 65537 is not 1, so 0 is returned. In IDEA's
// encoding 0 stands for 2^16 ≡ -1, which is its own inverse, so 0 is also
// the correct answer.
func _inv(v uint16) uint16 {
	const modulus = 65537
	g, s, _ := egcd(int(v), modulus)
	if g == 1 {
		// s may be negative; shift it into [0, modulus).
		return uint16((modulus + s) % modulus)
	}
	return 0
}
// _add is IDEA addition modulo 2^16; uint16 arithmetic already wraps.
func _add(x, y uint16) uint16 {
	sum := x + y
	return sum
}
// _sub is subtraction modulo 2^16, the inverse of _add; uint16 arithmetic
// already wraps.
func _sub(x, y uint16) uint16 {
	diff := x - y
	return diff
}
// main searches every (k16, k17) in [0, 2^16)^2 for values consistent with
// the given ciphertext pairs.
//
// os.Args[1] must be a JSON object:
//
//	{"pairs": [[[y1, y2, y3, y4], [y1', y2', y3', y4']], ...],
//	 "k18": ..., "k19": ..., "k20": ..., "k21": ...}
//
// Each entry of "pairs" holds the four 16-bit words of two ciphertexts. The
// words are first unmasked with k18..k21. A candidate (k16, k17) is kept only
// if, for every pair, the XOR differences of x1 and x2 survive the
// multiply-add computation below. Each surviving candidate is printed to
// stdout as "k16 k17". Errors go to stderr with a nonzero exit status,
// because the caller treats any stdout output as a result.
func main() {
	if len(os.Args) < 2 {
		fmt.Fprintln(os.Stderr, "usage: k16k17 '<json>'")
		os.Exit(2)
	}
	var data struct {
		Pairs [][][]uint16 `json:"pairs"`
		K18   int          `json:"k18"`
		K19   int          `json:"k19"`
		K20   int          `json:"k20"`
		K21   int          `json:"k21"`
	}
	if err := json.Unmarshal([]byte(os.Args[1]), &data); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	k18inv := _inv(uint16(data.K18))
	k19 := uint16(data.K19)
	k20 := uint16(data.K20)
	k21inv := _inv(uint16(data.K21))

	// Everything except s, t and u depends only on the ciphertexts and
	// k18..k21. Compute it once per pair instead of 2^32 times.
	type pairTerms struct {
		x1, x1_, x2, x2_, p, q, q_ uint16
	}
	terms := make([]pairTerms, len(data.Pairs))
	for i, ct := range data.Pairs {
		x1 := _mul(ct[0][0], k18inv)
		x2 := _sub(ct[0][1], k19)
		x2_ := _sub(ct[1][1], k19)
		terms[i] = pairTerms{
			x1:  x1,
			x1_: _mul(ct[1][0], k18inv),
			x2:  x2,
			x2_: x2_,
			// p (and hence s) is derived from the first ciphertext only and
			// shared by both halves of the pair.
			p:  x1 ^ _sub(ct[0][2], k20),
			q:  x2 ^ _mul(ct[0][3], k21inv),
			q_: x2_ ^ _mul(ct[1][3], k21inv),
		}
	}

	for k16 := 0; k16 < 65536; k16++ {
	candidates:
		for k17 := 0; k17 < 65536; k17++ {
			for _, tm := range terms {
				s := _mul(tm.p, uint16(k16))
				t := _mul(_add(tm.q, s), uint16(k17))
				t_ := _mul(_add(tm.q_, s), uint16(k17))
				u := _add(s, t)
				u_ := _add(s, t_)
				if tm.x1^t != tm.x1_^t_ || tm.x2^u != tm.x2_^u_ {
					continue candidates
				}
			}
			fmt.Printf("%d %d\n", k16, k17)
		}
	}
}
import random
import time
from typing import Deque
from ptrlib import Socket
from itertools import combinations
from collections import Counter, deque
def _mul(x, y):
if x == 0:
x = 2**16
if y == 0:
y = 2**16
z = x * y % (2**16 + 1)
return z % 2**16
def _add(x, y):
return (x + y) & 0xffff
def _sub(x, y):
return (x - y) & 0xffff
def _inv(x):
if x == 0:
return pow(2**16, -1, 2**16 + 1)
return pow(x, -1, 2**16 + 1)
def decompose(v):
    """Split a 64-bit block into its four 16-bit words, most significant first."""
    return tuple((v >> shift) & 0xffff for shift in (48, 32, 16, 0))
def search_k18(pairs, pairs2):
    """Return every k18 consistent with the low bit of the ciphertext differences.

    ``pairs`` and ``pairs2`` are lists of ``(plaintext, ciphertext)``; entry i
    of each forms one pair, and only as many pairs as the shorter list holds
    are used. A guess k18 survives when, for every pair, the low bit of
    ``y1 * k18^-1`` XOR ``y1' * k18^-1`` equals the low bit of ``y3 ^ y3'``.
    The low bit of the y3 difference is unchanged by subtracting the still
    unknown k20, so it can be tested before k20 is known.

    Iterates the pairs directly instead of relying on the module-level ``N``.
    """
    # Per-pair values that do not depend on the k18 guess, computed once.
    prepared = []
    for (_, ct), (_, ct_) in zip(pairs, pairs2):
        y1, _, y3, _ = decompose(ct)
        y1_, _, y3_, _ = decompose(ct_)
        prepared.append((y1, y1_, (y3 ^ y3_) & 1))
    k18_cands = []
    for k18 in range(2**16):
        k18inv = _inv(k18)
        if all(
            ((_mul(y1, k18inv) ^ _mul(y1_, k18inv)) & 1) == low_bit
            for y1, y1_, low_bit in prepared
        ):
            k18_cands.append(k18)
    return k18_cands
def search_k20(pairs, pairs2, k18):
    """Return every k20 consistent with the given k18 guess.

    For each pair (entry i of ``pairs`` with entry i of ``pairs2``), the full
    16-bit difference ``y1 * k18^-1`` XOR ``y1' * k18^-1`` must equal
    ``(y3 - k20) ^ (y3' - k20)``. Only as many pairs as the shorter list holds
    are used; the module-level ``N`` is no longer needed.
    """
    k18inv = _inv(k18)
    # The y1-side difference does not depend on k20, so compute it once per pair.
    prepared = []
    for (_, ct), (_, ct_) in zip(pairs, pairs2):
        y1, _, y3, _ = decompose(ct)
        y1_, _, y3_, _ = decompose(ct_)
        prepared.append((_mul(y1, k18inv) ^ _mul(y1_, k18inv), y3, y3_))
    return [
        k20
        for k20 in range(2**16)
        if all((_sub(y3, k20) ^ _sub(y3_, k20)) == tdiff for tdiff, y3, y3_ in prepared)
    ]
def check_is_good_pair(pair, pair3, k18inv, k20):
    """Heuristically decide whether two ``(plaintext, ciphertext)`` entries form a good pair.

    A "good" pair is one where the key difference in k0 was cancelled; false
    positives are possible. The y1-side difference (unmasked with ``k18inv``)
    is XORed with the y3-side difference (unmasked with k20). The pair is
    accepted when the result is exactly 0x100 or 0x300.
    """
    (_, ct), (_, ct_) = pair, pair3
    a1, _, a3, _ = decompose(ct)
    b1, _, b3, _ = decompose(ct_)
    left = _mul(a1, k18inv) ^ _mul(b1, k18inv)
    # With key bit 127 masked, k20 itself differs, so the second ciphertext is
    # unmasked with k20 ^ 2.
    right = _sub(a3, k20) ^ _sub(b3, k20 ^ 2)
    return (left ^ right) in (0x100, 0x300)
def search_k21(good_pairs, k18, k20):
    """Return the set of k21 candidates found by meet-in-the-middle over good pairs.

    ``good_pairs`` entries are ``[pt, ct, pt_, ct_]``. For every 2-combination
    (hoping both pairs are truly good), each k21 guess is indexed by the tuple
    of y4 differences after multiplying by ``k21^-1``. The index is then probed
    with the y2 differences after subtracting each k19 guess. Every k21 whose
    tuple matches some k19 is a candidate.

    The index maps each tuple to a *list* of k21 values. A plain dict
    assignment would silently drop candidates whose tuples collide.
    ``k18`` and ``k20`` are unused but kept for interface compatibility.
    """
    candidates = set()
    for comb in combinations(good_pairs, 2):
        # Decompose once per combination instead of once per key guess.
        words = [(decompose(ct), decompose(ct_)) for _, ct, _, ct_ in comb]
        table = {}
        for k21 in range(2**16):
            k21inv = _inv(k21)
            key = tuple(_mul(y[3], k21inv) ^ _mul(y_[3], k21inv) for y, y_ in words)
            table.setdefault(key, []).append(k21)
        for k19 in range(2**16):
            key = tuple(_sub(y[1], k19) ^ _sub(y_[1], k19) for y, y_ in words)
            if key in table:
                candidates.update(table[key])
    return candidates
def search_k19(pairs, k21):
    """Return the k19 values that agree with the most pairs for a given k21.

    ``pairs`` entries are ``[pt, ct, pt_, ct_]``. A pair votes for k19 when
    ``(y2 - k19) ^ (y2' - k19)`` equals ``y4 * k21^-1`` XOR ``y4' * k21^-1``.
    All k19 tied for the highest vote count are returned, in ascending order.
    If no k19 receives any vote, an empty list is returned (the previous
    version raised IndexError).
    """
    k21inv = _inv(k21)
    # The y4-side difference does not depend on k19, so compute it once per pair.
    prepared = []
    for _, ct, _, ct_ in pairs:
        _, y2, _, y4 = decompose(ct)
        _, y2_, _, y4_ = decompose(ct_)
        prepared.append((_mul(y4, k21inv) ^ _mul(y4_, k21inv), y2, y2_))
    counter = Counter()
    for k19 in range(2**16):
        for udiff, y2, y2_ in prepared:
            if (_sub(y2, k19) ^ _sub(y2_, k19)) == udiff:
                counter[k19] += 1
    if not counter:
        return []
    best = max(counter.values())
    return [k for k, v in counter.most_common() if v == best]
# Plaintexts whose encryption requests were sent but whose responses have not
# been read yet; consumed in FIFO order by encrypt_recv().
q: Deque[int] = deque()
# Current connection to the challenge server (replaced by reconnect()).
sock: Socket = None


def reconnect():
    """Open a fresh connection to the challenge server.

    Closes the previous socket (if any), restarts the timer used by
    ``elapsed`` and clears the pending-request queue ``q``.
    """
    global sock, start
    if sock is not None:
        sock.close()
    start = time.time()
    q.clear()
    sock = Socket("idea.2024.ctfcompetition.com", 1337)
    import socket
    # Disable Nagle's algorithm (presumably to cut latency on the many small
    # requests the solver sends).
    sock._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    # TCP_QUICKACK only exists on Linux; skip it elsewhere instead of raising
    # AttributeError on every reconnect.
    if hasattr(socket, "TCP_QUICKACK"):
        sock._sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_QUICKACK, 1)
def elapsed(msg):
    """Print ``msg`` followed by the seconds since the last ``reconnect()``."""
    seconds = time.time() - start
    print(msg, 'elapsed:', seconds)
# Driver: each iteration reconnects and attempts a full key recovery.
# The loop ends after a key is submitted or on Ctrl-C; any other exception is
# printed and a new attempt starts.
while True:
    try:
        print("\n\nstart new search")
        reconnect()

        def encrypt_send(pt: int):
            """Send an encryption request for ``pt``; the response is read later by encrypt_recv()."""
            sock.sendline("1")
            sock.sendline(hex(pt))
            q.append(pt)

        def encrypt_recv():
            """Read the next pending encryption response; return ``(plaintext, ciphertext)``."""
            sock.recvuntil("text: ")
            pt = q.popleft()
            return pt, int(sock.recvline().strip())

        def mask(mask: int):
            """Send ``mask`` via menu option 2 (presumably a mask applied to the secret key — TODO confirm)."""
            sock.sendline("2")
            sock.sendlineafter("mask: ", hex(mask))

        # Step 1: find k18 and k20.
        N = 20  # number of pairs used in step 1
        M = 400  # number of pairs used in step 2
        base = random.randrange(0, 2**48)
        pairs = []
        for i in range(M):
            pt = (random.randrange(0, 2**16) << 48) | base
            encrypt_send(pt)
        for i in range(M):
            pairs.append(encrypt_recv())
        elapsed("recv first wave")

        mask(2**103)
        pairs2 = []
        for i in range(N):
            pt, ct = pairs[i]
            pt_ = pt + 2**(16 + 16 + 7)
            encrypt_send(pt_)
        for i in range(N):
            pairs2.append(encrypt_recv())
        elapsed("recv second wave")

        k18_cands = search_k18(pairs, pairs2)
        elapsed("k18 found")
        if len(k18_cands) == 0:
            # Retry with the "-" pattern (subtract the plaintext difference instead).
            for i in range(N):
                pt, ct = pairs[i]
                pt_ = pt - 2**(16 + 16 + 7)
                encrypt_send(pt_)
            for i in range(N):
                pairs2[i] = encrypt_recv()
            k18_cands = search_k18(pairs, pairs2)
        assert len(k18_cands) > 0, "k18 not found"
        for k18 in k18_cands:
            k20_cands = search_k20(pairs, pairs2, k18)
            if len(k20_cands) > 0:
                k20 = k20_cands[1]  # fixed, arbitrary choice among the candidates
                break
        assert len(k20_cands) > 0, "k20 not found"
        elapsed("k20 found")
        print("k18:", k18)
        print("k20:", k20_cands)

        # Step 2: find k21 and k19.
        mask((2**103) ^ (2**127))
        good_pairs = []  # pairs whose k0 difference was cancelled; may contain false positives
        k18inv = _inv(k18)
        for _ in range(M):
            pt_ = (random.randrange(0, 2**16) << 48) | base
            encrypt_send(pt_)
        for _ in range(M):
            pt_, ct_ = encrypt_recv()
            for i in range(M):
                pt, ct = pairs[i]
                if check_is_good_pair((pt, ct), (pt_, ct_), k18inv, k20):
                    good_pairs.append([pt, ct, pt_, ct_])
                    # Each pt has only one good pt_, so stop searching for this pt_.
                    break
            if len(good_pairs) >= 10:
                break
        q.clear()  # we may have stopped reading early; drop the unread requests
        assert len(good_pairs) >= 4, "good_pairs not found"
        elapsed("good pair found")
        k21_cands = search_k21(good_pairs[:5], k18, k20)
        assert len(k21_cands) > 0, "k21 not found"
        print("k21:", k21_cands)
        elapsed("k21 found")

        pairs_for_k16 = []
        for i in range(N):
            _, ct = pairs[i]
            _, ct_ = pairs2[i]
            y1, y2, y3, y4 = decompose(ct)
            y1_, y2_, y3_, y4_ = decompose(ct_)
            pairs_for_k16.append([[y1, y2, y3, y4], [y1_, y2_, y3_, y4_]])

        import json
        args = []
        for k21 in k21_cands:
            k19_cands = search_k19(good_pairs, k21)
            print("k19:", k19_cands)
            for k19 in k19_cands:
                args.append(json.dumps({
                    "pairs": pairs_for_k16,
                    "k18": k18,
                    "k19": k19,
                    "k20": k20,
                    "k21": k21,
                }))
        print('len:', len(args))
        elapsed("start processing")

        import os
        import subprocess
        import concurrent.futures

        def solve_k16k17(arg):
            """Run the Go k16/k17 solver on one JSON argument; return (arg, stdout) on a hit, else None."""
            output = subprocess.run(["./go_solver/k16k17", arg], capture_output=True)
            k16k17 = output.stdout.decode().strip()
            if k16k17 != "":
                return arg, k16k17

        found_k = False
        executor = concurrent.futures.ProcessPoolExecutor()
        futures = {}
        try:
            futures = {executor.submit(solve_k16k17, arg): arg for arg in args}
            for future in concurrent.futures.as_completed(futures):
                try:
                    res = future.result()
                    print(res)
                    if res is None:
                        continue
                    arg, k16k17 = res
                    found_k = True
                    elapsed("k16k17 found")
                    break
                except Exception as exc:
                    print(f"Task {futures[future]} generated an exception: {exc}")
        finally:
            # Leaving a `with ProcessPoolExecutor()` block waits for every task,
            # which defeated the early break. Instead, cancel tasks that have
            # not started and return without waiting (running ones finish in
            # the background).
            for pending in futures:
                pending.cancel()
            executor.shutdown(wait=False)
        elapsed("end")
        if not found_k:
            print(":(")
            continue

        print(arg)
        parsed = json.loads(arg)
        k19 = parsed["k19"]
        k21 = parsed["k21"]
        # The solver may print several candidates, one "k16 k17" per line; use the first.
        k16, k17 = map(int, k16k17.splitlines()[0].split())
        print("k16k17:", k16k17)
        print("k19:", k19)
        print("k21:", k21)
        pt, ct = pairs[0]
        ALL_KEY_PATH = "exploration3/target/release/exploration3"
        INPUT3_PATH = "./input3.json"
        OUTPUT3_PATH = "./output3.json"
        with open(INPUT3_PATH, "w") as fp:
            json.dump(
                {
                    # These subkeys are expected to be recovered already.
                    "k16": k16,
                    "k17": k17,
                    "k18": k18,
                    "k19": k19,
                    "k20": k20,
                    "k21": k21,
                    "pt": pt,
                    "ct": ct,  # one pair with encrypt(pt) == ct (both ints)
                },
                fp,
            )
        # Remove output left over from a previous attempt. Otherwise a failed
        # search would re-submit an old key instead of reporting "Fail...".
        try:
            os.remove(OUTPUT3_PATH)
        except FileNotFoundError:
            pass
        elapsed("start all key")
        subprocess.run([ALL_KEY_PATH, INPUT3_PATH, OUTPUT3_PATH])
        elapsed("done")
        try:
            with open(OUTPUT3_PATH, "r") as fp:
                output = json.load(fp)
        except FileNotFoundError:
            print("Fail...")
            continue
        print(output)
        key = hex(output["key"])[2:]
        sock.sendlineafter("Get balance", "3")
        sock.sendlineafter("key_guess:", key)
        sock.interactive()
        break
    except KeyboardInterrupt:
        break
    except Exception as e:
        print(e)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment