Skip to content

Instantly share code, notes, and snippets.

@mkow
Created January 15, 2022 22:55
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save mkow/b43883ef4c3e25d6001e8c5f4f63430d to your computer and use it in GitHub Desktop.
Save mkow/b43883ef4c3e25d6001e8c5f4f63430d to your computer and use it in GitHub Desktop.
Solver for CRC Recursive Challenge from Dragon CTF 2021
// Solver for CRC Recursive Challenge from Dragon CTF 2021
// Author: Michał "Redford" Kowalczyk
//
// Could probably be much simpler, but this code was evolving for too long
// without a rewrite ¯\_(ツ)_/¯
#include <cassert>
#include <cerrno>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>

#include <algorithm>
#include <iostream>
#include <thread>
#include <vector>
using std::swap;
using std::pair;
using std::make_pair;
// Polynomials are stored without bit reversal: the least significant bit is the x^0
// coefficient (x & 1 is the LSB).
using poly_t = unsigned __int128;
//using poly_t = uint64_t;
// Generator polynomial without its implicit x^BITS term (crc() shifts that term out and
// XORs MOD in whenever the top bit is set).
//constexpr poly_t MOD = 0xad93d23594c935a9; // <- picked for 64-bit chall
//constexpr poly_t BITS = 64;
// 0x4caa61b0d7fe5fa54189d46709eaba2d, picked for the 128-bit challenge; assembled from two
// 64-bit halves because C++ has no 128-bit integer literals.
constexpr poly_t MOD = ((poly_t)0x4caa61b0d7fe5fa5 << 64) + 0x4189d46709eaba2d;
constexpr poly_t BITS = 128;
static_assert(BITS % 8 == 0 && BITS <= sizeof(poly_t) * 8,
              "BITS must be a whole number of bytes that fits in poly_t");
static_assert(BITS / 4 < sizeof(size_t) * 8,
              "the `letters` bitmask needs one bit per hex digit of the checksum in a size_t");
//const char prefix[] = "My crc64 is 0x";
const char prefix[] = "My crc128 is 0x";
const char suffix[] = "! Cool, isn\'t it?";
constexpr size_t prefix_len = sizeof(prefix) - 1;
constexpr size_t suffix_len = sizeof(suffix) - 1;
// All-ones mask of BITS bits. Written with two shifts because '(1 << BITS) - 1' is UB when
// BITS equals the width of poly_t.
constexpr poly_t MASK = ((poly_t)1 << (BITS - 1) << 1) - 1;

// Every hex digit of the checksum appears twice in the CRC equation: as an ASCII character
// inside the message ("left") and as a nibble of the checksum value itself ("right"). Each
// side is a per-set constant (fixup) XOR a linear combination of 4 per-digit variables.
// Set 0 covers '0'-'9', set 1 covers 'a'-'f'; poly_gauss() filters out variable patterns
// that don't map to a real character of the set. Read-only, shared by all solver threads.
constexpr size_t SETS_CNT = 2;
constexpr poly_t set_fixup_left[SETS_CNT] = {
    0x30, // 0 -> [0, 0, 1, 1, 0, 0, 0, 0]
    0x60, // 1 -> [0, 1, 1, 0, 0, 0, 0, 0]
};
constexpr poly_t set_fixup_right[SETS_CNT] = {
    0x0, // 0 -> [0, 0, 0, 0]
    0x9, // 1 -> [1, 0, 0, 1]
};
constexpr poly_t set_var_coeff_left[SETS_CNT][4] = { // [set][var%4] -> bits if 1
    {0x08,  // [0, 0, 0, 0, 1, 0, 0, 0]
     0x04,  // [0, 0, 0, 0, 0, 1, 0, 0]
     0x02,  // [0, 0, 0, 0, 0, 0, 1, 0]
     0x01}, // [0, 0, 0, 0, 0, 0, 0, 1]
    {0x00,  // [0, 0, 0, 0, 0, 0, 0, 0]
     0x04,  // [0, 0, 0, 0, 0, 1, 0, 0]
     0x02,  // [0, 0, 0, 0, 0, 0, 1, 0]
     0x01}  // [0, 0, 0, 0, 0, 0, 0, 1]
};
constexpr poly_t set_var_coeff_right[SETS_CNT][4] = { // [set][var%4] -> bits if 1
    {0x08,  // [1, 0, 0, 0]
     0x04,  // [0, 1, 0, 0]
     0x02,  // [0, 0, 1, 0]
     0x01}, // [0, 0, 0, 1]
    {0x04,  // [0, 1, 0, 0]
     0x00,  // [0, 0, 0, 0]
     0x02,  // [0, 0, 1, 0]
     0x03}  // [0, 0, 1, 1]
};
// Computes the non-reflected (MSB-first) CRC of buf[0..buf_size) with generator MOD,
// starting from register value `crc` (all ones by default) and applying no final XOR.
poly_t crc(const uint8_t* buf, size_t buf_size, poly_t crc = MASK) {
    const poly_t top_bit = (poly_t)1 << (BITS - 1);
    for (const uint8_t* p = buf; p != buf + buf_size; ++p) {
        // Feed the next byte into the top of the register, then clock it out bit by bit.
        crc ^= (poly_t)*p << (BITS - 8);
        for (int round = 8; round > 0; --round) {
            const bool carry = (crc & top_bit) != 0;
            crc <<= 1;
            if (carry)
                crc ^= MOD;
        }
    }
    return crc & MASK;
}
// Debug helper: returns true iff `checksum` is self-consistent, i.e. the CRC of the message
// "<prefix><checksum as BITS/4 zero-padded lowercase hex digits><suffix>" equals `checksum`.
bool verify(poly_t checksum) {
    static const char hex_digits[] = "0123456789abcdef";
    constexpr size_t digits = BITS / 4;
    // Exactly the message bytes; no NUL terminator is needed because the length is explicit.
    char input[prefix_len + digits + suffix_len];
    memcpy(input, prefix, prefix_len);
    // Format the checksum by hand, most significant nibble first: printf's "%llx" cannot take
    // a 128-bit poly_t (passing one through varargs is undefined behavior).
    for (size_t i = 0; i < digits; i++) {
        const unsigned nibble = (unsigned)((checksum >> (BITS - 4 - 4 * i)) & 0xF);
        input[prefix_len + i] = hex_digits[nibble];
    }
    memcpy(input + prefix_len + digits, suffix, suffix_len);
    return crc((const uint8_t*)input, sizeof(input)) == checksum;
}
// Returns the polynomial encoded by `buf` reduced modulo the degree-BITS generator whose low
// BITS coefficients are `mod` (the x^BITS term is implicit). The encoding is big-endian: the
// MSB of buf[0] is the highest-degree coefficient, the LSB of buf[buf_size - 1] is x^0.
//
// The leading bytes go through a CRC register with zero init, which yields
// head(x) * x^BITS mod the generator. The trailing BITS/8 bytes already have degree < BITS
// and are XORed in directly. Buffers of at most BITS/8 bytes are already reduced and come
// back unchanged (masked to BITS bits).
poly_t poly_bytes_mod(uint8_t* buf, size_t buf_size, poly_t mod) {
    constexpr size_t tail_len = BITS / 8;
    // Guard against the size_t underflow of `buf_size - tail_len` for short buffers.
    const size_t head_len = buf_size > tail_len ? buf_size - tail_len : 0;
    const poly_t top_bit = (poly_t)1 << (BITS - 1);
    poly_t rem = 0;
    for (size_t i = 0; i < head_len; i++) {
        rem ^= (poly_t)buf[i] << (BITS - 8);
        for (int j = 0; j < 8; j++) {
            if (rem & top_bit)
                rem = (rem << 1) ^ mod; // honor the caller's modulus, not the global MOD
            else
                rem = rem << 1;
        }
    }
    // Byte k of the tail lands at bit offset 8 * (distance from the last byte).
    for (size_t i = head_len; i < buf_size; i++)
        rem ^= (poly_t)buf[i] << (8 * (buf_size - 1 - i));
    return rem & MASK;
}
// Toggles bit `pos` of the big-endian byte buffer `buf`, where bit 0 is the least
// significant bit of the last byte (buf[buf_size - 1]). No bounds checking is done.
void poly_bytes_add_bit(uint8_t* buf, size_t buf_size, size_t pos) {
    const size_t byte_index = buf_size - 1 - pos / 8;
    const unsigned bit_mask = 1u << (pos % 8);
    buf[byte_index] ^= bit_mask;
}
// Be careful: modifies `vars`!
//__declspec(noinline)
// TODO: hardcode vars count? - lol, makes the code slower...
void poly_gauss(pair<poly_t, poly_t>* vars /*<weight, var_bit>*/,
size_t vars_cnt, poly_t expected_sum, size_t letters) {
size_t fixed = 0;
for (size_t col = 0; col < BITS; col++) {
bool found = false;
for (size_t row = fixed; row < vars_cnt; row++) {
if (vars[row].first & ((poly_t)1 << col)) {
swap(vars[fixed], vars[row]);
found = true;
break;
}
}
if (!found)
continue; // can't fully control this bit
pair<poly_t, poly_t> cur = vars[fixed];
for (size_t row = fixed + 1; row < vars_cnt; row++) {
if (vars[row].first & ((poly_t)1 << col)) {
vars[row].first ^= cur.first;
vars[row].second ^= cur.second;
}
}
fixed++;
}
poly_t res_sum = 0;
poly_t res_vars = 0;
size_t row = 0;
for (size_t col = 0; col < BITS; col++) {
if ((expected_sum ^ res_sum) & ((poly_t)1 << col)) {
if (vars[row].first & ((poly_t)1 << col)) {
res_sum ^= vars[row].first;
res_vars ^= vars[row].second;
} else {
return;
}
}
if (vars[row].first & ((poly_t)1 << col)) {
row += 1;
if (row >= vars_cnt)
break;
}
}
assert(res_sum == expected_sum); // otherwise we'd have returned in the loop above
assert(fixed > 0);
poly_t limit;
if (vars_cnt - fixed == 0)
limit = 1;
else
limit = ((poly_t)1 << (vars_cnt - fixed - 1)) << 1; // may overflow to 0, but without UB
for (poly_t i = 0;;) {
poly_t solution = res_vars;
for (poly_t bit = 0; bit < vars_cnt - fixed; bit++)
if (i & ((poly_t)1 << bit))
solution ^= vars[vars_cnt - 1 - bit].second;
// verify poly solution (it may not match the bruteforced ranges
bool ok = true;
for (poly_t nibble_pos = 0; nibble_pos < BITS / 4; nibble_pos++) {
int sol_vars = (solution >> (BITS - 4 - 4 * nibble_pos)) & 0xF;
if (letters & ((size_t)1 << nibble_pos)) {
// a-f
// the other bit endianness
// if (sol_vars != 0b1000 &&
// sol_vars != 0b0100 &&
// sol_vars != 0b1101 &&
// sol_vars != 0b0011 &&
// sol_vars != 0b1011 &&
// sol_vars != 0b0111) {
// ok = false;
// break;
// }
if (sol_vars != 0b0001 &&
sol_vars != 0b0010 &&
sol_vars != 0b1011 &&
sol_vars != 0b1100 &&
sol_vars != 0b1101 &&
sol_vars != 0b1110) {
ok = false;
break;
}
} else {
// 0-9
if (sol_vars >= 10) {
ok = false;
break;
}
// the other bit endianness
// if (sol_vars != 0b0000 &&
// sol_vars != 0b1000 &&
// sol_vars != 0b0100 &&
// sol_vars != 0b1100 &&
// sol_vars != 0b0010 &&
// sol_vars != 0b1010 &&
// sol_vars != 0b0110 &&
// sol_vars != 0b1110 &&
// sol_vars != 0b0001 &&
// sol_vars != 0b1001) {
// ok = false;
// break;
// }
}
}
if (ok) {
poly_t crc = 0;
for (poly_t nibble_pos = 0; nibble_pos < BITS / 4; nibble_pos++) {
size_t set_ind = !!(letters & ((size_t)1 << nibble_pos));
poly_t nibble = set_fixup_right[set_ind];
for (int j = 0; j < 4; j++)
if (solution & ((poly_t)1 << (BITS - 1 - nibble_pos * 4 - j)))
nibble ^= set_var_coeff_right[set_ind][j];
crc ^= nibble << (BITS - 4 * (nibble_pos + 1));
}
if (BITS <= 64) // otherwise ull is too small
assert(verify(crc));
if (BITS <= 64) {
char fmt[32];
sprintf(fmt, "Found: 0x%%0%ullx\n", (int)BITS / 4);
printf(fmt, crc);
} else {
char fmt[32];
sprintf(fmt, "Found: 0x%%0%ullx%%0%ullx\n", (int)BITS / 4 / 2, (int)BITS / 4 / 2);
printf(fmt, (uint64_t)(crc >> 64), uint64_t(crc));
}
}
if (i == limit) // doing it this way to prevent UB on 1 << 64
break;
i++;
}
}
// Filled in by main() before any solver thread is started; only read afterwards.
poly_t set_weight[BITS / 4][SETS_CNT] = { 0 }; // [nibble_pos][set] -> contribution of the set's constant bits to C
poly_t var_weight[BITS][SETS_CNT] = { 0 }; // [var][set] -> variable weight
poly_t C_base; // constant term: prefix, suffix and CRC init, with every hex digit zeroed

// Brute-forces the character-set choice for each hex digit of the checksum. Bit i of
// `letters` picks set 1 ('a'-'f') or set 0 ('0'-'9') for digit i. For every `letters` in
// [start, end) it builds the linear system (BITS variables, 4 per digit) and hands it to
// poly_gauss(), which prints the matching checksums. Prints a progress percentage every
// 2^21 iterations.
void solver(size_t start, size_t end) { // `end` is exclusive
    // Denominator for the progress percentage. It is guarded so a one-element range doesn't
    // compute 0.0 / 0.0 (converting the resulting NaN to int is undefined behavior).
    const double progress_den = end - start > 1 ? (double)(end - start - 1) : 1.0;
    for (size_t letters = start; letters < end; letters++) { // this bitshift won't overflow for the bitsizes we use
        if (((letters - start) & 0x1FFFFF) == 0)
            printf("%d%%...\n", int((letters - start) / progress_den * 100));
        poly_t C = C_base;
        for (size_t nibble_pos = 0; nibble_pos < BITS / 4; nibble_pos++)
            C ^= set_weight[nibble_pos][!!(letters & ((size_t)1 << nibble_pos))];
        pair<poly_t, poly_t> vars[BITS]; // <weight, var_bit>
        size_t vars_cnt = 0;
        for (size_t var = 0; var < BITS; var++) {
            const size_t set = !!(letters & ((size_t)1 << (var / 4)));
            vars[vars_cnt++] = make_pair(var_weight[var][set], (poly_t)1 << (BITS - 1 - var));
        }
        poly_gauss(vars, vars_cnt, C, letters);
    }
}
int main(int argc, char* argv[]) {
// preprocess base C value
uint8_t tmp[prefix_len + BITS / 4 + suffix_len + BITS / 8] = { 0 };
memcpy(tmp, prefix, prefix_len);
memcpy(tmp + prefix_len + BITS / 4, suffix, suffix_len);
for (size_t i = 0; i < BITS / 8; i++)
tmp[i] ^= 0xFF;
C_base = poly_bytes_mod(tmp, sizeof(tmp), MOD);
// preprocess set_weight
for (size_t nibble_pos = 0; nibble_pos < BITS / 4; nibble_pos++)
for (int set = 0; set < SETS_CNT; set++) {
uint8_t tmp[BITS / 4 + suffix_len + BITS / 8] = { 0 };
tmp[nibble_pos] ^= set_fixup_left[set];
if (nibble_pos % 2 == 0)
tmp[BITS / 4 + suffix_len + nibble_pos / 2] ^= (set_fixup_right[set] << 4);
else
tmp[BITS / 4 + suffix_len + nibble_pos / 2] ^= set_fixup_right[set];
set_weight[nibble_pos][set] = poly_bytes_mod(tmp, sizeof(tmp), MOD);
}
// preprocess var_weight
for (size_t var = 0; var < BITS; var++)
for (int set = 0; set < SETS_CNT; set++) {
uint8_t tmp[BITS / 4 + suffix_len + BITS / 8] = { 0 };
tmp[var / 4] ^= set_var_coeff_left[set][var % 4];
if ((var / 4) % 2 == 0)
tmp[BITS / 4 + suffix_len + var / 4 / 2] ^= (set_var_coeff_right[set][var % 4] << 4);
else
tmp[BITS / 4 + suffix_len + var / 4 / 2] ^= set_var_coeff_right[set][var % 4];
var_weight[var][set] = poly_bytes_mod(tmp, sizeof(tmp), MOD);
}
size_t threads_cnt;
if (argc <= 1) {
threads_cnt = 1;
} else {
threads_cnt = atoi(argv[1]);
}
printf("Running main loop with %zu threads...\n", threads_cnt);
std::vector<std::thread> threads;
size_t end = (size_t)1 << BITS / 4;
size_t step = (end + threads_cnt - 1) / threads_cnt;
for (size_t i = 0; i < threads_cnt; i++)
threads.emplace_back(solver, step * i, std::min(step * (i + 1), end));
for (auto& thread : threads)
thread.join();
puts("Done!");
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment