Create a gist now

Instantly share code, notes, and snippets.

anonymous /JavaPuzzle.java Secret
Created Mar 22, 2013

What would you like to do?
Solver for http://corner.squareup.com/2013/03/puzzle-square-root.html using a timing attack: answer(N) takes much longer when N is less than SquareRoot.n, so binary search for the smallest N that is 'fast'. That N is probably SquareRoot.n itself; searching over squares N = r*r therefore yields r = sqrt(SquareRoot.n), the value answer() is looking for.
import java.math.BigInteger;
import java.security.SecureRandom;
public class JavaPuzzle {
    /** Number of SquareRoot.answer calls per timing sample. */
    private static final int REPS = 1000;

    /** Untimed calls per reference input before measuring, so the JIT has compiled the hot paths. */
    private static final int WARMUPS = 1000;

    /** The slow reference input must take at least this many times as long as the fast one. */
    private static final int MIN_SLOWDOWN = 5;

    /** Classifies a candidate root as lying below the secret square root ("slow") or not. */
    interface SlowTest {
        /**
         * @param root candidate square root
         * @return true if {@code root} is judged to be below sqrt(SquareRoot.n)
         */
        boolean isSlow(BigInteger root);
    }

    /**
     * Recovers sqrt(SquareRoot.n) by timing SquareRoot.answer.
     *
     * <p>Timing attack: BigInteger division special-cases a dividend that is <= the divisor,
     * which makes a/b an order of magnitude faster in that case. SquareRoot.answer(x) computes
     * n / x, so answer(r * r) is fast exactly when r * r >= n, i.e. when r >= sqrt(n). A binary
     * search for the smallest r whose answer(r * r) is fast therefore finds sqrt(n), which is
     * the value answer() is waiting for.
     */
    public static void main(String[] args) throws Exception {
        SecureRandom random = new SecureRandom();
        // Reference roots. sqrt(n) is a random non-negative value below 2^BITS, so bigRoot
        // (top bit forced at BITS + 9) always squares to more than n. smallRoot (top bit forced
        // at BITS - 11) squares to less than n unless sqrt(n) happens to be unusually small
        // (at most ~1 in 1000); in that case the calibration check below fails loudly.
        // Forcing the top bit also keeps both roots nonzero (a zero root would make answer()
        // divide by zero) and guarantees smallRoot < bigRoot.
        BigInteger smallRoot = new BigInteger(SquareRoot.BITS - 10, random)
                .setBit(SquareRoot.BITS - 11);
        BigInteger bigRoot = new BigInteger(SquareRoot.BITS + 10, random)
                .setBit(SquareRoot.BITS + 9);

        for (int i = 0; i < WARMUPS; i++) {
            SquareRoot.answer(smallRoot);
            SquareRoot.answer(bigRoot);
            SquareRoot.answer(smallRoot.pow(2));
            SquareRoot.answer(bigRoot.pow(2));
        }

        // Calibrate: bigRoot^2 > n takes the fast path, smallRoot^2 < n the slow one.
        long shortTime = time(bigRoot.pow(2), REPS);
        long longTime = time(smallRoot.pow(2), REPS);
        if (longTime < shortTime * MIN_SLOWDOWN) {
            throw new IllegalStateException("Not enough difference between short and long: "
                    + shortTime + ", " + longTime);
        }
        System.out.printf("Long: %s, short: %s%n", longTime, shortTime);

        // A single cut-off for every slow/fast decision: the geometric mean of the two
        // reference timings, i.e. halfway between them on a log scale (at least
        // sqrt(MIN_SLOWDOWN) times shortTime). Using different cut-offs in the bracketing and
        // bisection phases could break the invariant "lower bound slow, upper bound fast".
        final long slowThreshold = (long) Math.sqrt((double) shortTime * longTime);
        SlowTest slowTest = new SlowTest() {
            @Override
            public boolean isSlow(BigInteger root) {
                return time(root.pow(2), REPS) > slowThreshold;
            }
        };

        // Bracket sqrt(n): start at smallRoot (measured slow above) and keep doubling until
        // the square is fast.
        BigInteger low = smallRoot;
        BigInteger high = low.shiftLeft(1);
        while (slowTest.isSlow(high)) {
            low = high;
            high = high.shiftLeft(1);
        }
        System.out.printf("Have bounds: small = %s (%s bits), big = %s (%s bits), diff = %s bits%n",
                time(low.pow(2), REPS), low.bitLength(),
                time(high.pow(2), REPS), high.bitLength(),
                high.subtract(low).bitLength());

        BigInteger root = smallestFastRoot(low, high, slowTest);
        System.out.printf("Finished; suspected root is %s%n", root);
        System.out.printf("Original short/long times were (%s, %s)%n", shortTime, longTime);
        System.out.printf("Time for suspected root^2 is %s%n", time(root.pow(2), REPS));
        System.out.printf("Time for suspected (root-10)^2 is %s%n",
                time(root.subtract(BigInteger.TEN).pow(2), REPS));
        // Prints "Square root!" if the attack recovered the right value.
        SquareRoot.answer(root);
    }

    /**
     * Binary-searches for the smallest root that {@code slowTest} does not classify as slow.
     *
     * <p>Assumes the classification is monotone (slow below some cut-off, fast from it upward),
     * that {@code low} is slow and that {@code high} is fast; the result then lies in
     * (low, high]. Each step prints which way the guess missed. A long run of the same
     * direction probably means a noisy measurement sent the search the wrong way earlier.
     *
     * @param low a root known to be slow
     * @param high a root known to be fast; must be greater than {@code low}
     * @param slowTest classifier for candidate roots
     * @return the smallest root in (low, high] classified as fast
     * @throws IllegalArgumentException if {@code low >= high}
     */
    static BigInteger smallestFastRoot(BigInteger low, BigInteger high, SlowTest slowTest) {
        if (low.compareTo(high) >= 0) {
            throw new IllegalArgumentException("low must be < high: " + low + ", " + high);
        }
        while (high.subtract(low).compareTo(BigInteger.ONE) > 0) {
            // high - low >= 2 here, so mid is strictly between the bounds.
            BigInteger mid = low.add(high).shiftRight(1);
            String whichMiss;
            if (slowTest.isSlow(mid)) {
                low = mid;
                whichMiss = "low";
            } else {
                high = mid;
                whichMiss = "high";
            }
            System.out.printf("Difference is %s bits; last guess was too %s%n",
                    high.subtract(low).bitLength(), whichMiss);
        }
        return high;
    }

    /**
     * Returns the elapsed time, in nanoseconds, of {@code reps} calls to
     * {@code SquareRoot.answer(b)}.
     */
    private static long time(BigInteger b, int reps) {
        // Request a GC up front so a collection is less likely to land inside the measurement.
        System.gc();
        long start = System.nanoTime();
        for (int i = 0; i < reps; i++) {
            SquareRoot.answer(b);
        }
        return System.nanoTime() - start;
    }

    /** Local stand-in for the puzzle's SquareRoot class, holding a random perfect square n. */
    static class SquareRoot {
        public static final int BITS = 10000;
        // The secret: the square of a uniformly random non-negative integer below 2^BITS.
        private static BigInteger n = new BigInteger(BITS, new SecureRandom()).pow(2);

        /** Prints "Square root!" when n / root (integer division) equals root. */
        public static void answer(BigInteger root) {
            if (n.divide(root).equals(root)) {
                // The goal is to reach this line
                System.out.println("Square root!");
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment