@hellman /code.cpp
Last active Sep 4, 2017

TWCTF 2017 - Palindrome Pairs - Challenge Phase
#include <iostream>
#include <stdlib.h>
using namespace std;
#define REP(i,x) for(int i = 0; i < (int)x; i++)
#define M 8
int N;
string s[1000];
long q[M], p[M], hs[M][1000], hr[M][1000];
long mp(long a, long b, long c){
    long r=1;
    for(;b;b>>=1){
        if(b&1) r=r*a%c;
        a=a*a%c;
    }
    return r;
}
int main() {
    std::ios_base::sync_with_stdio(false);
    REP(i, M) {
        q[i]=rand();
        p[i]=rand();
    }
    cin>>N;
    REP(i,N) cin>>s[i];
    REP(i,N){
        REP(j,M){
            REP(k,s[i].size())hs[j][i]=(hs[j][i]*q[j]+s[i][k])%p[j];
            REP(k,s[i].size())hr[j][i]=(hr[j][i]*q[j]+s[i][s[i].size()-k-1])%p[j];
        }
    }
    long ans=0;
    REP(i,N){
        REP(j,N){
            bool o=true;
            REP(k,M){
                if(
                    (hs[k][i]*mp(q[k],s[j].size(),p[k])+hs[k][j]
                    -hr[k][j]*mp(q[k],s[i].size(),p[k])-hr[k][i])%p[k]){
                    o=false;break;
                }
            }
            if (o) ans++;
        }
    }
    cout<<ans<<endl;
}
'''
In this challenge we are given C++ code that counts the palindromic pairs in its input.
We need to find an input on which the program gives a wrong answer or crashes.
The program uses a Rabin-Karp style polynomial hash with 8 different random (base, modulus) pairs to check for palindromes.
Notably, it calls rand() without seeding. Therefore the bases and moduli are actually fixed
(and the moduli are plain rand() outputs, so they are not even guaranteed to be prime)!
This is what we are going to exploit!
We will build a string whose hash equals the hash of its reverse.
Consider an even-length string s of length L.
Consider the pairs of chars (s[i], s[-1-i]) for i < L/2.
Each such pair contributes (up to an overall sign, which depends on which end of the string gets the highest power of q)
(s[i] - s[-1-i]) * (q**i - q**(L-1-i))
to the hash difference H(s) - H(rev(s)), which we want to be zero modulo every p.
That is, only the difference (s[i] - s[-1-i]) matters.
Due to the charset requirement (lowercase letters), each difference lies between -25 and 25.
Therefore, we need a linear combination of the terms (q**i - q**(L-1-i))
with small coefficients that sums to zero.
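
The pairing argument above can be checked numerically. The following is a minimal Python 3 sketch, not part of the original solver: the helper names poly_hash and pair_sum are ours, and it uses only the first (q, p) pair from the solver script. It assumes the hash convention of the challenge program (h = h*q + c), under which the pair sum comes out as minus the hash difference; for the zero condition the sign does not matter.

```python
import random


def poly_hash(s, q, p):
    """Hash as in the challenge program: h = h*q + c, so s[0] gets the highest power of q."""
    h = 0
    for c in s:
        h = (h * q + ord(c)) % p
    return h


def pair_sum(s, q, p):
    """Sum of (s[i] - s[-1-i]) * (q**i - q**(L-1-i)) over the first half of s, mod p."""
    L = len(s)
    total = 0
    for i in range(L // 2):
        diff = ord(s[i]) - ord(s[-1 - i])
        total += diff * (pow(q, i, p) - pow(q, L - 1 - i, p))
    return total % p


q, p = 0x6b8b4567, 0x327b23c6  # first (q, p) pair used by the solver script
rng = random.Random(2017)      # fixed seed, only to make the demo repeatable
s = "".join(rng.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(100))

hash_diff = (poly_hash(s, q, p) - poly_hash(s[::-1], q, p)) % p
# With this convention pair_sum equals H(rev(s)) - H(s), so the two cancel mod p.
print((hash_diff + pair_sum(s, q, p)) % p == 0)
```

Since the identity holds for any string, making the pair sum vanish modulo every p is exactly the condition H(s) == H(rev(s)).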
LLL is good for this!
We build a matrix (rows are the basis of the lattice):
(8 terms for i=0)     | 1 0 0 ... 0
(8 terms for i=1)     | 0 1 0 ... 0
...
(8 terms for i=L/2-1) | 0 0 0 ... 1
p0 (in column 0)      | 0 0 0 ... 0
p1 (in column 1)      | 0 0 0 ... 0
...                   | 0 0 0 ... 0
p7 (in column 7)      | 0 0 0 ... 0
We want a vector in this lattice whose first 8 values are zero (hash difference is zero)
and whose remaining values are small (between -25 and 25).
Since the hash difference must be exactly zero (a small nonzero difference is useless to us),
we should really switch to the orthogonal lattice. However, a small trick lets us avoid that.
We multiply the left part of the matrix (the first 8 columns) by a large number,
thus telling the LLL algorithm that minimizing the hash differences has priority.
This works, and for L=100 we can already find lots of solutions!
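
To make the shape of the lattice concrete, here is a small pure-Python 3 sketch that builds the same basis rows as plain lists (no LLL, no Sage). The names build_basis and combine are illustrative only, and the coefficient vector is arbitrary rather than an actual solution. It only shows that an integer combination of the rows carries the coefficients in its right part and a multiple of MULTIPLIER in each of the first 8 coordinates.

```python
QPS = [
    (0x6b8b4567, 0x327b23c6), (0x643c9869, 0x66334873),
    (0x74b0dc51, 0x19495cff), (0x2ae8944a, 0x625558ec),
    (0x238e1f29, 0x46e87ccd), (0x3d1b58ba, 0x507ed7ab),
    (0x2eb141f2, 0x41b71efb), (0x79e2a9e3, 0x7545e146),
]
L = 100
N = L // 2
MULTIPLIER = 100


def build_basis():
    """Rows: N term rows (scaled terms | unit vector), then 8 modulus rows."""
    rows = []
    for i in range(N):
        terms = [MULTIPLIER * ((pow(q, i, p) - pow(q, L - 1 - i, p)) % p)
                 for q, p in QPS]
        unit = [1 if j == i else 0 for j in range(N)]
        rows.append(terms + unit)
    for k, (_, p) in enumerate(QPS):
        row = [0] * (len(QPS) + N)
        row[k] = MULTIPLIER * p  # allows reduction modulo p in column k
        rows.append(row)
    return rows


def combine(rows, weights):
    """Integer linear combination of the basis rows."""
    width = len(rows[0])
    return [sum(w * r[c] for w, r in zip(weights, rows)) for c in range(width)]


basis = build_basis()
coeffs = [(-1) ** i * (i % 26) for i in range(N)]  # arbitrary small coefficients
vec = combine(basis, coeffs + [0] * len(QPS))

print(vec[len(QPS):] == coeffs)                         # right part is the coefficient vector
print(all(x % MULTIPLIER == 0 for x in vec[:len(QPS)])) # left part is scaled by MULTIPLIER
```

Because every left coordinate is a multiple of MULTIPLIER, a nonzero hash difference in any of the 8 columns costs at least MULTIPLIER in that coordinate, which is what pushes LLL toward vectors whose left part is exactly zero.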
Pseudo-palindrome:
aalnjdcagaabafaasaaahiaaaaabaaaeaoeaaajnkjacfbjfaajnaaaaaaaaaadkfaaeahgiapdccoaaifcahgarafkamaaaaaaa
To make a challenge input we just split it in a random place:
$ cat >hack.in
2
aal njdcagaabafaasaaahiaaaaabaaaeaoeaaajnkjacfbjfaajnaaaaaaaaaadkfaaeahgiapdccoaaifcahgarafkamaaaaaaa
The challenge program returns 2, while the answer is 0!
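
For reference, here is a Python 3 sketch that mirrors the pair check of the C++ program next to an exact brute-force count. The function names are ours, and the (q, p) values are the ones listed in the solver script. We do not hard-code what the hash-based count prints; the numbers quoted above are the ones observed with the real program.

```python
QPS = [
    (0x6b8b4567, 0x327b23c6), (0x643c9869, 0x66334873),
    (0x74b0dc51, 0x19495cff), (0x2ae8944a, 0x625558ec),
    (0x238e1f29, 0x46e87ccd), (0x3d1b58ba, 0x507ed7ab),
    (0x2eb141f2, 0x41b71efb), (0x79e2a9e3, 0x7545e146),
]


def fwd_hash(s, q, p):
    """h = h*q + c modulo p, as in the challenge program."""
    h = 0
    for c in s:
        h = (h * q + ord(c)) % p
    return h


def hashed_count(words):
    """Mirror of the C++ double loop: ordered pairs (i, j), with i == j included."""
    count = 0
    for a in words:
        for b in words:
            ok = True
            for q, p in QPS:
                # hash of a+b versus hash of rev(a+b) = rev(b) + rev(a)
                left = fwd_hash(a, q, p) * pow(q, len(b), p) + fwd_hash(b, q, p)
                right = (fwd_hash(b[::-1], q, p) * pow(q, len(a), p)
                         + fwd_hash(a[::-1], q, p))
                if (left - right) % p:
                    ok = False
                    break
            count += ok
    return count


def exact_count(words):
    """Ground truth: compare each concatenation with its reverse directly."""
    return sum((a + b) == (a + b)[::-1] for a in words for b in words)


words = [
    "aal",
    "njdcagaabafaasaaahiaaaaabaaaeaoeaaajnkjacfbjfaajnaaaaaaaaaadkfaaeahgiapdccoaaifcahgarafkamaaaaaaa",
]
print("hash-based count:", hashed_count(words))
print("exact count:     ", exact_count(words))
```

A real palindrome always passes the hash check, so the hash-based count can never be below the exact one; any excess is a collision like the one constructed here.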
The flags:
TWCTF{Simple_Rolling_Hash_is_not_suitable_to_principal_part}
TWCTF{Sorry_for_participants-rolling_hash_is_weak}
(the challenge was fixed and another flag was added)
'''
from sage.all import *

# (q, p) pairs produced by the unseeded rand() calls in the challenge program
qps = [
    (0x6b8b4567, 0x327b23c6),
    (0x643c9869, 0x66334873),
    (0x74b0dc51, 0x19495cff),
    (0x2ae8944a, 0x625558ec),
    (0x238e1f29, 0x46e87ccd),
    (0x3d1b58ba, 0x507ed7ab),
    (0x2eb141f2, 0x41b71efb),
    (0x79e2a9e3, 0x7545e146),
]
qs = [q for q, p in qps]
ps = [p for q, p in qps]
M = len(qps)


def h(s):
    """polynomial hash modulo the 8 moduli"""
    v = [0] * M
    for c in s:
        v = [x * q + ord(c) for x, q in zip(v, qs)]
    return tuple(x % p for x, p in zip(v, ps))


L = 100
N = L // 2
MULTIPLIER = 100

# N x M matrix of the terms q**y - q**(L-1-y) mod p, one row per char pair
mv = matrix(ZZ, N, M)
for y in range(N):
    for x, (q, p) in enumerate(qps):
        mv[y, x] = (pow(q, y, p) - pow(q, L - 1 - y, p)) % p

m = matrix(ZZ, N + M, N + M)
# submatrix with terms
m.set_block(0, 0, MULTIPLIER * mv)
# modulo reductions
m.set_block(N, 0, MULTIPLIER * diagonal_matrix(ps))
# term coefficients
m.set_block(0, M, identity_matrix(N))
# 4th submatrix is zero

for row in m.LLL():
    print(row[:M], min(row[M:]), "~", max(row[M:]))
    delta = max(abs(v) for v in row[M:])
    # accept only rows with zero hash difference and coefficients within a..z range
    if set(row[:M]) == {0} and delta <= 25:
        print("Good!")
        s = [None] * L
        for i, v in enumerate(row[M:]):
            # choose the pair so that s[i] - s[-1-i] == v, with one char being 'a'
            a = ord('a')
            b = a + abs(v)
            if v > 0:
                a, b = b, a
            s[i] = a
            s[-1 - i] = b
        s = "".join(map(chr, s))
        print(s)
        print(h(s))
        print(h(s[::-1]))
        assert h(s) == h(s[::-1])
        break