-
-
Save mmastrac/f58353ca1c1140267740 to your computer and use it in GitHub Desktop.
Improve detection of lower bits.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package square; | |
import java.math.BigInteger; | |
import java.util.concurrent.CountDownLatch; | |
import java.util.concurrent.TimeUnit; | |
/** | |
* Attempt to guess SquareRoot.n using a timing attack on BigInteger.divide(). | |
* Assuming that the timing is relatively stable, finds the answer reasonably | |
* quickly. Conveniently also O(log(N)). | |
*/ | |
public class SquareRootSolver { | |
/** | |
* The factor difference between two successive numbers to be considered the | |
* point where we go from >n to <n. | |
*/ | |
private static final int DIFFERENCE = 2; | |
/** | |
* The number of trials to use to determine how long it takes to divide. | |
* This should be increased for computers faster than a 2.3GHz i7 MBP. | |
*/ | |
private static final int TRIALS = 100; | |
/** | |
* The number of trials to use to determine how long it takes to divide, | |
* when we're testing the smallest bits. | |
*/ | |
private static final int TRIALS_SMALL = 10000; | |
/** | |
* The number of trials to run. | |
*/ | |
private static final int RUNS = 5; | |
/** | |
* The number of warmup calls. | |
*/ | |
private static final int WARMUP = 25; | |
// Stolen from Caliper | |
public static void forceGc() { | |
System.gc(); | |
System.runFinalization(); | |
final CountDownLatch latch = new CountDownLatch(1); | |
new Object() { | |
@Override | |
protected void finalize() { | |
latch.countDown(); | |
} | |
}; | |
System.gc(); | |
System.runFinalization(); | |
try { | |
latch.await(2, TimeUnit.SECONDS); | |
} catch (InterruptedException e) { | |
Thread.currentThread().interrupt(); | |
} | |
} | |
// Stolen from | |
// http://stackoverflow.com/questions/4407839/how-can-i-find-the-square-root-of-a-java-biginteger | |
public static BigInteger squareRoot(BigInteger x) throws IllegalArgumentException { | |
if (x.equals(BigInteger.ZERO) || x.equals(BigInteger.ONE)) | |
return x; | |
BigInteger two = BigInteger.valueOf(2L); | |
BigInteger y; | |
for (y = x.divide(two); y.compareTo(x.divide(y)) > 0; y = ((x.divide(y)).add(y)).divide(two)) | |
; | |
return y; | |
} | |
/** | |
* Run a number of trials to get a good approximation of the time. It would | |
* be better to use Real Statistics(tm) here, but this works well enough. | |
*/ | |
private static long test(BigInteger t, boolean small) { | |
long[] times = new long[RUNS]; | |
long total; | |
// Warm up the JVM a bit | |
for (int i = 0; i < WARMUP; i++) { | |
SquareRoot.answer(t); | |
} | |
// Run more trials if we're closer to the end of the number, as these | |
// bits tend to be tougher to get with the timing attack. | |
final int trials = small ? TRIALS_SMALL : TRIALS; | |
do { | |
total = 0; | |
// Better solution would be to use statistics, or possibly | |
// oversample and throw out outliers. This works, though. | |
for (int run = 0; run < RUNS; run++) { | |
long start = System.nanoTime(); | |
for (int i = 0; i < trials; i++) { | |
SquareRoot.answer(t); | |
} | |
times[run] = System.nanoTime() - start; | |
total += times[run]; | |
} | |
} while (!similar(times, total)); | |
return total; | |
} | |
/** | |
* Instead of digging out the statistics textbook, here's a hacky way to | |
* determine if we have any outliers. | |
*/ | |
private static boolean similar(long[] times, long total) { | |
long avg = total / times.length; | |
double difference = avg * 0.5; | |
for (long time : times) { | |
if (Math.abs(time - avg) > difference) { | |
return false; | |
} | |
} | |
return true; | |
} | |
/** | |
* Dividing n by a number < n is more than twice as slow as dividing a | |
* number > n. Exploit this. | |
*/ | |
private static int findNextBit(BigInteger known, BigInteger n) { | |
n = n.subtract(BigInteger.ONE); | |
long lastTime = Long.MAX_VALUE; | |
int i = n.bitLength(); | |
boolean small = i < 16; | |
BigInteger t; | |
while (true) { | |
t = known.or(n); | |
long time = test(t, small); | |
if (time / lastTime > DIFFERENCE) { | |
return i; | |
} | |
i--; | |
lastTime = time; | |
if (n.signum() == 0) | |
break; | |
n = n.shiftRight(1); | |
} | |
return -1; | |
} | |
/** | |
* Run the findNextBit method a few times in a row to reduce the probability | |
* that we get the wrong answer. | |
*/ | |
private static int findLikelyNextBit(BigInteger known, BigInteger lastKnown) { | |
int count = 0; | |
int answer = Integer.MAX_VALUE; | |
while (true) { | |
forceGc(); | |
int result = findNextBit(known, lastKnown); | |
if (result == answer) { | |
count++; | |
// If we didn't find the next bit, let's be much more sure about | |
// it. When the timing attack fails, it usually fails to find | |
// the next bit. | |
if (answer == -1 && count == 10) | |
break; | |
// Otherwise we need fewer agreements in the answers. | |
if (answer != -1 && count == 2) | |
break; | |
} else { | |
answer = result; | |
count = 0; | |
} | |
} | |
return answer; | |
} | |
public static void main(String[] args) { | |
BigInteger known = BigInteger.ZERO; | |
BigInteger lastKnown = BigInteger.ONE.shiftLeft(SquareRoot.BITS * 2 + 1); | |
int i = 0; | |
while (true) { | |
if (i++ % 10 == 0) { | |
System.out.println("K: " + known.toString(16)); | |
} | |
int next = findLikelyNextBit(known, lastKnown); | |
if (next == -1) | |
break; | |
lastKnown = BigInteger.ONE.shiftLeft(next); | |
known = known.or(lastKnown); | |
} | |
System.out.println("K: " + known.toString(16)); | |
// Now that we have our answer, start with the square root. | |
BigInteger root = squareRoot(known); | |
System.out.println("R: " + root.toString(16)); | |
SquareRoot.answer(root); | |
// Sometimes we don't catch the low bit(s) properly, so cast a slightly | |
// wider net. | |
SquareRoot.answer(root.add(BigInteger.ONE)); | |
SquareRoot.answer(root.subtract(BigInteger.ONE)); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thanks! Most of this code is just various ways of running the timing test over and over to eliminate false positives caused by the various sources of jitter in the system (e.g., GC, multitasking overhead, JIT warmup).
The core loop looks for 1 bits by ORing the known bits with a test mask (1s in all the unknown positions), then shifting the unknown part right by one each time through the loop. When we see a timing change (about 2x slower), the candidate has gone from larger than the secret to smaller than it, so we know we've found one more bit.