Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Improve detection of lower bits.
package square;
import java.math.BigInteger;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
/**
* Attempt to guess SquareRoot.n using a timing attack on BigInteger.divide().
* Assuming that the timing is relatively stable, finds the answer reasonably
* quickly. Conveniently also O(log(N)).
*/
/**
 * Attempts to guess SquareRoot.n using a timing attack on BigInteger.divide().
 * Assuming that the timing is relatively stable, finds the answer reasonably
 * quickly. Conveniently also O(log(N)).
 * <p>
 * Strategy: the target's bits are recovered from the most significant bit
 * downwards. For each scan, candidates of the form {@code known | (2^j - 1)}
 * are timed for decreasing {@code j}. The first candidate that is markedly
 * slower than its predecessor marks a set bit. The recovered value is then
 * square-rooted and submitted (together with its two neighbours) to
 * {@code SquareRoot.answer}.
 */
public class SquareRootSolver {
    /**
     * The factor difference between two successive timings for the scan to
     * treat them as the point where the candidate goes from >n to <n.
     * <p>
     * NOTE(review): {@link #findNextBit} compares with integer division
     * ({@code time / lastTime > DIFFERENCE}). The effective threshold is
     * therefore a slowdown of at least {@code DIFFERENCE + 1} times, not "more
     * than {@code DIFFERENCE}" times. This is left as-is because it is a tuned
     * heuristic.
     */
    private static final int DIFFERENCE = 2;
    /**
     * The number of calls per timed run used to determine how long it takes
     * to divide. This should be increased for computers faster than a 2.3GHz
     * i7 MBP.
     */
    private static final int TRIALS = 100;
    /**
     * The number of calls per timed run used when the scan starts below bit
     * 16, where the timing difference is harder to detect.
     */
    private static final int TRIALS_SMALL = 10000;
    /**
     * The number of timed runs per measurement; their spread is checked by
     * {@link #similar}.
     */
    private static final int RUNS = 5;
    /**
     * The number of untimed warmup calls made before each measurement.
     */
    private static final int WARMUP = 25;

    /**
     * Best-effort request for a full garbage collection. It is called before
     * every scan so that a GC pause is less likely to land inside a
     * measurement. It waits up to 2 seconds for a sentinel object's finalizer
     * to run, as evidence that a collection actually happened.
     * Adapted from Caliper.
     */
    public static void forceGc() {
        System.gc();
        System.runFinalization();
        final CountDownLatch latch = new CountDownLatch(1);
        // Unreachable sentinel: its finalizer only runs once a GC has collected it.
        new Object() {
            @Override
            protected void finalize() {
                latch.countDown();
            }
        };
        System.gc();
        System.runFinalization();
        try {
            latch.await(2, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers; the GC is only best-effort.
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Computes the integer square root of {@code x}: the largest {@code y}
     * with {@code y * y <= x}, using Newton's method on integers.
     * Adapted from
     * http://stackoverflow.com/questions/4407839/how-can-i-find-the-square-root-of-a-java-biginteger
     *
     * @param x the non-negative value to take the square root of
     * @return floor(sqrt(x))
     * @throws IllegalArgumentException if {@code x} is negative
     */
    public static BigInteger squareRoot(BigInteger x) throws IllegalArgumentException {
        if (x.signum() < 0) {
            throw new IllegalArgumentException(
                    "Cannot take the square root of a negative number: " + x);
        }
        if (x.equals(BigInteger.ZERO) || x.equals(BigInteger.ONE)) {
            return x;
        }
        BigInteger two = BigInteger.valueOf(2L);
        // For x >= 2, x/2 is never below the root, so Newton steps descend
        // monotonically until y <= x/y, at which point y == floor(sqrt(x)).
        BigInteger y = x.divide(two);
        while (y.compareTo(x.divide(y)) > 0) {
            y = x.divide(y).add(y).divide(two);
        }
        return y;
    }

    /**
     * Measures how long {@code SquareRoot.answer(t)} takes. It performs
     * {@link #RUNS} timed runs of {@code trials} calls each, and repeats the
     * whole measurement until the runs agree according to {@link #similar}.
     * This may loop for a long time if timings never stabilize. It would be
     * better to use Real Statistics(tm) here, but this works well enough.
     *
     * @param t the candidate value passed to {@code SquareRoot.answer}
     * @param small if true, uses {@link #TRIALS_SMALL} calls per run instead
     *        of {@link #TRIALS}
     * @return the total elapsed nanoseconds across all runs of the accepted
     *         measurement
     */
    private static long test(BigInteger t, boolean small) {
        long[] times = new long[RUNS];
        long total;
        // Warm up the JVM a bit.
        for (int i = 0; i < WARMUP; i++) {
            SquareRoot.answer(t);
        }
        // Run more trials if we're closer to the end of the number, as these
        // bits tend to be tougher to get with the timing attack.
        final int trials = small ? TRIALS_SMALL : TRIALS;
        do {
            total = 0;
            // Better solution would be to use statistics, or possibly
            // oversample and throw out outliers. This works, though.
            for (int run = 0; run < RUNS; run++) {
                long start = System.nanoTime();
                for (int i = 0; i < trials; i++) {
                    SquareRoot.answer(t);
                }
                times[run] = System.nanoTime() - start;
                total += times[run];
            }
        } while (!similar(times, total));
        return total;
    }

    /**
     * Hacky outlier check used instead of digging out the statistics
     * textbook.
     *
     * @param times per-run elapsed times (must be non-empty)
     * @param total the sum of {@code times}
     * @return true if every run is within 50% of the (integer) average
     */
    private static boolean similar(long[] times, long total) {
        long avg = total / times.length;
        double difference = avg * 0.5;
        for (long time : times) {
            if (Math.abs(time - avg) > difference) {
                return false;
            }
        }
        return true;
    }

    /**
     * Scans downwards from just below {@code n} for the next set bit of the
     * target. Dividing n by a number < n is more than twice as slow as
     * dividing by a number > n, and this method exploits that.
     * <p>
     * Candidates {@code known | (2^j - 1)} are timed for decreasing {@code j},
     * ending with {@code known} itself. A set bit is reported when a candidate
     * is slower than the previous one by the {@link #DIFFERENCE} ratio test.
     *
     * @param known the bits recovered so far
     * @param n a power of two: the lowest bit recovered so far, or the initial
     *        upper bound
     * @return the index of the next set bit, or -1 if no slowdown was seen
     */
    private static int findNextBit(BigInteger known, BigInteger n) {
        // All-ones mask covering every bit position below n.
        n = n.subtract(BigInteger.ONE);
        // MAX_VALUE ensures the first probe can never satisfy the ratio test.
        long lastTime = Long.MAX_VALUE;
        int i = n.bitLength();
        // Decided once per scan: switching between TRIALS and TRIALS_SMALL
        // mid-scan would change the measured time by TRIALS_SMALL / TRIALS and
        // trip the ratio test spuriously.
        boolean small = i < 16;
        BigInteger t;
        while (true) {
            t = known.or(n);
            long time = test(t, small);
            // Integer division: true only when time >= (DIFFERENCE + 1) * lastTime.
            if (time / lastTime > DIFFERENCE) {
                return i;
            }
            i--;
            lastTime = time;
            if (n.signum() == 0) {
                break;
            }
            n = n.shiftRight(1);
        }
        return -1;
    }

    /**
     * Runs {@link #findNextBit} repeatedly, doing a forced GC before each
     * scan, until it returns the same answer enough times in a row. This
     * reduces the probability of accepting a wrong answer. A bit index must
     * be returned by 3 consecutive scans; "no further bit" (-1) must be
     * returned by 11 consecutive scans. Any disagreement restarts the count.
     *
     * @param known the bits recovered so far
     * @param lastKnown the lowest bit recovered so far (a power of two)
     * @return the agreed-upon next bit index, or -1
     */
    private static int findLikelyNextBit(BigInteger known, BigInteger lastKnown) {
        int count = 0;
        // Sentinel that no scan can return, so the first result always resets.
        int answer = Integer.MAX_VALUE;
        while (true) {
            forceGc();
            int result = findNextBit(known, lastKnown);
            if (result == answer) {
                count++;
                // If we didn't find the next bit, let's be much more sure about
                // it. When the timing attack fails, it usually fails to find
                // the next bit.
                if (answer == -1 && count == 10) {
                    break;
                }
                // Otherwise we need fewer agreements in the answers.
                if (answer != -1 && count == 2) {
                    break;
                }
            } else {
                answer = result;
                count = 0;
            }
        }
        return answer;
    }

    /**
     * Recovers the target bit by bit, printing progress every 10 bits. It then
     * submits the integer square root of the result, plus its two neighbours,
     * to {@code SquareRoot.answer}.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        BigInteger known = BigInteger.ZERO;
        // Start the first scan just above the largest expected bit position.
        BigInteger lastKnown = BigInteger.ONE.shiftLeft(SquareRoot.BITS * 2 + 1);
        int i = 0;
        while (true) {
            if (i++ % 10 == 0) {
                System.out.println("K: " + known.toString(16));
            }
            int next = findLikelyNextBit(known, lastKnown);
            if (next == -1) {
                break;
            }
            lastKnown = BigInteger.ONE.shiftLeft(next);
            known = known.or(lastKnown);
        }
        System.out.println("K: " + known.toString(16));
        // Now that we have our answer, start with the square root.
        BigInteger root = squareRoot(known);
        System.out.println("R: " + root.toString(16));
        SquareRoot.answer(root);
        // Sometimes we don't catch the low bit(s) properly, so cast a slightly
        // wider net.
        SquareRoot.answer(root.add(BigInteger.ONE));
        SquareRoot.answer(root.subtract(BigInteger.ONE));
    }
}

Wow! I'm just trying to soak all this in. Very impressive!

Owner

mmastrac commented Mar 25, 2013

Thanks! Most of this code is just various ways of running the timing test over and over to ensure that we eliminate false positives caused by the various sources of jitter in the system (i.e., GC, multitasking overhead, JIT warmup, etc.).

The core loop basically looks for 1 bits by ORing the known bits with the test bits (1s in all the unknown positions), then shifting the unknown part right by one each time through the loop. When we find a timing change (2x slower), we've gone from a candidate larger than the target to one smaller than it, so we know we've found one more bit.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment