Create a gist now

Instantly share code, notes, and snippets.

@adamh-basis /Solution.java Secret
Last active Dec 15, 2015

What would you like to do?
Response to "Java Puzzle: Square Root"
import java.lang.reflect.Field;
import java.math.BigInteger;
import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.Arrays;
import square.SquareRoot;
/**
 * Solves the "Java Puzzle: Square Root" with a timing side channel.
 * <p>
 * Essentially, this is a side-channel attack that takes advantage of the fact
 * that there's a fast path in division when the dividend &lt;= divisor. It
 * does a binary search, using a "fast" calculation as an indication that the
 * guess is too high (or exact), and a "slow" calculation as an indication that
 * the guess is too low.
 * <p>
 * This recovers the square, but the root still has to be found, so an integer
 * Newton iteration is used for that part.
 * <p>
 * All in all, the hardest part by far was determining a simple test statistic
 * that gives a reasonably high probability of correctly distinguishing a
 * "fast" answer from a "slow" answer. This test was robust on two distinct
 * platforms.
 * <p>
 * {@code setAccessible} is used only for optional verification; the attack
 * still runs correctly with a security manager installed.
 * <p>
 * I'm interested to know if there was another way to do this; this seems like
 * a not very numerical approach to the problem. Fun problems like this are
 * hard to find time to solve; this one was just right in terms of complexity.
 *
 * @author adamh@basis.com
 */
public class Solution {
    public static final BigInteger TWO = BigInteger.valueOf(2);
    /** Calls to {@code SquareRoot.answer} per timed sample. */
    public static final int INNER = 100;
    /** Timed samples per probe; they are sorted so the median can be read. */
    public static final int OUTER = 10;
    /** Number of calibration rounds run before the search starts. */
    public static final int WARMUP = 1000;

    /** Counts consecutive inconclusive probes in {@link #search}. */
    static int times = 0;

    /** {@code SquareRoot.n}, read reflectively for verification; null until read. */
    static BigInteger N;

    /** Set once reflective access has failed, so it is not retried (and re-logged) on every probe. */
    private static boolean nUnavailable = false;

    public static void main(String[] argv) throws Throwable {
        // Per the puzzle the root has at most SquareRoot.BITS bits, so n has at
        // most 2 * BITS bits. 2^(2 * BITS) has 2 * BITS + 1 bits and is
        // therefore definitely bigger than n.
        BigInteger cand = BigInteger.ONE.shiftLeft(SquareRoot.BITS * 2);
        long[] values = new long[OUTER];

        // Prime the JIT and gather timing statistics for the "fast" case
        // (candidate >= n), so later probes have a baseline to compare against.
        BigInteger fast = cand;
        // Long.MAX_VALUE is presumably far below n, so this exercises the other
        // path for JIT warm-up; its timings are discarded.
        BigInteger slow = BigInteger.valueOf(Long.MAX_VALUE);
        double sum = 0;
        double sumSq = 0;
        long samples = 0;
        for (int i = 0; i < WARMUP; ++i) {
            runTest(fast, values);
            for (long v : values) {
                samples++;
                sum += v;
                // Square in double so an outlier timing cannot overflow a long.
                sumSq += (double) v * v;
            }
            runTest(slow, values);
            fast = fast.add(BigInteger.ONE);
            slow = slow.add(BigInteger.ONE);
        }
        double mean = sum / samples;
        double svar = (sumSq - samples * mean * mean) / (samples - 1);
        double sdev = Math.sqrt(svar);
        // If the median is less than 1 sigma above the mean, we'll say our
        // number is >= n. If it is more than 2 sigma above the mean, we'll say
        // our number is smaller than n. In between it's hard to tell, so the
        // same candidate is measured again.
        double[] maxdelta = { sdev, 2 * sdev };
        // Now, armed with maths, we can find the number.
        search(values, mean, maxdelta, cand);
    }

    /**
     * Binary-searches for n using timing alone, then passes the result to
     * {@link #findSquareRoot}.
     * <p>
     * Assuming every timing decision is correct, {@code lo <= n <= hi} holds
     * throughout, and the loop ends with {@code lo == hi == n}.
     *
     * @param values   scratch buffer for per-probe timings (expected length OUTER)
     * @param mean     mean fast-path sample time from calibration
     * @param maxdelta {1 sigma, 2 sigma} thresholds above {@code mean}
     * @param cand     starting candidate, known to be larger than n
     * @throws Error if the verification oracle contradicts a timing decision,
     *               or after more than 100 consecutive inconclusive probes
     */
    private static void search(long[] values,
                               double mean,
                               double[] maxdelta,
                               BigInteger cand) throws Throwable {
        BigInteger hi = cand.multiply(TWO);
        BigInteger lo = BigInteger.ZERO;
        for (int j = 0; hi.compareTo(lo) != 0; ++j) {
            runTest(cand, values);
            Boolean low = checkCand(cand);
            // values is sorted by runTest, so this is the (upper) median sample.
            long med = values[values.length / 2];
            // maxdelta[0] is 1 sigma: close to the fast-path mean => cand >= n.
            boolean hiHit = med < mean + maxdelta[0];
            // maxdelta[1] is 2 sigma: well above the fast-path mean => cand < n.
            boolean loHit = med > mean + maxdelta[1];
            if (hiHit) {
                hi = cand;
                if (low != null && low) {
                    throw new Error(String.format("Error, assumed cand >= n: j=%d, med=%.2f, mean=%.2f, delta=%.2f%n", j, (double) med, mean, (med - mean) / mean));
                }
                times = 0;
            } else if (loHit) {
                // Too far, back it up.
                lo = cand.add(BigInteger.ONE);
                if (low != null && !low) {
                    throw new Error(String.format("Error, assumed cand < n: j=%d, med=%.2f, mean=%.2f, delta=%.2f%n", j, (double) med, mean, (med - mean) / mean));
                }
                times = 0;
            } else {
                // Inconclusive: lo/hi are unchanged, so the same candidate is
                // measured again, but give up after too many misses.
                times++;
                if (times > 100) {
                    throw new Error("failing, it's not going to work");
                }
                System.out.printf("Outside sigma=1 range, actual=%s: hiavg=%.2f, med=%.2f%n",
                        (low == null) ? "??" : (low ? "cand < n" : "cand >= n"),
                        mean,
                        (double) med);
            }
            // A misjudged comparison can push lo past hi; clamp so the loop still ends.
            if (hi.compareTo(lo) < 0) {
                hi = lo;
            }
            // Binary search: probe the midpoint of [lo, hi].
            cand = lo.add(hi.subtract(lo).shiftRight(1));
        }
        findSquareRoot(cand);
    }

    /**
     * Times OUTER samples of INNER calls to {@code SquareRoot.answer(cand)},
     * storing the elapsed nanoseconds per sample in {@code values}, sorted
     * ascending.
     */
    private static void runTest(BigInteger cand, long[] values) {
        for (int i = 0; i < OUTER; ++i) {
            long start = System.nanoTime();
            for (int k = 0; k < INNER; ++k) {
                SquareRoot.answer(cand);
            }
            values[i] = System.nanoTime() - start;
        }
        Arrays.sort(values);
    }

    /**
     * Computes the integer square root of the recovered n and submits it.
     * <p>
     * The integer Newton iteration always terminates. If the search misjudged
     * a comparison and {@code square} is not a perfect square, a warning is
     * printed and the floor of its root is submitted anyway (presumably it will
     * not be accepted), instead of looping forever.
     */
    private static void findSquareRoot(BigInteger square) throws Throwable {
        BigInteger root = isqrt(square);
        if (!root.multiply(root).equals(square)) {
            System.out.println("Warning: search result is not a perfect square; a timing decision was probably wrong");
        }
        // This should be it!
        SquareRoot.answer(root);
    }

    /**
     * Returns floor(sqrt(x)) using integer Newton iteration.
     *
     * @throws IllegalArgumentException if {@code x} is negative
     */
    private static BigInteger isqrt(BigInteger x) {
        if (x.signum() < 0) {
            throw new IllegalArgumentException("negative: " + x);
        }
        if (x.signum() == 0) {
            return BigInteger.ZERO;
        }
        // x < 2^bitLength, so 2^ceil(bitLength / 2) > sqrt(x). Starting above the
        // root, the iterates decrease strictly until they reach floor(sqrt(x)).
        BigInteger r = BigInteger.ONE.shiftLeft((x.bitLength() + 1) / 2);
        while (true) {
            BigInteger next = r.add(x.divide(r)).shiftRight(1);
            if (next.compareTo(r) >= 0) {
                return r;
            }
            r = next;
        }
    }

    /**
     * Verification oracle. It compares the candidate with the real
     * {@code SquareRoot.n}, read reflectively, and is used only to detect
     * wrong timing decisions.
     *
     * @return TRUE if n &gt; candidate (candidate too low), FALSE if
     *         candidate &gt;= n, or null when a security manager is installed
     *         or reflective access failed
     */
    private static Boolean checkCand(BigInteger candidate) {
        if (nUnavailable || System.getSecurityManager() != null) {
            return null;
        }
        try {
            if (N == null) {
                Field f = SquareRoot.class.getDeclaredField("n");
                f.setAccessible(true);
                N = (BigInteger) f.get(null);
                System.out.println("n.bitLength()=" + N.bitLength());
            }
            if (N.equals(candidate)) {
                System.out.println("cand == n!");
            }
            return N.compareTo(candidate) > 0;
        } catch (Exception e) {
            // Best effort only: report the failure once, then run without verification.
            nUnavailable = true;
            e.printStackTrace();
            return null;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment