Instantly share code, notes, and snippets.

Last active December 15, 2015 07:49
Show Gist options
• Save adamh-basis/f1d887e21d7237565198 to your computer and use it in GitHub Desktop.
Response to "Java Puzzle: Square Root"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters
 import java.lang.reflect.Field; import java.math.BigInteger; import java.math.BigDecimal; import java.math.MathContext; import java.math.RoundingMode; import java.util.Arrays; import square.SquareRoot; /** * Essentially, this is just a side channel attack taking advantage of the fact * that there's a fast path in division when the dividend <= divisor. It * does a binary search, using a "fast" calculation as an indication that the * guess is too high, and a "slow" calculation as an indication that the guess * is too low. *

* This allows guessing the square, but I still have to find the root, so it * uses Newton's method for that part. *

* All in all, the hardest part by far was trying to determine a simple test * statistic to use to ensure there's a reasonably high probability of * correctly distinguishing between a "fast" answer and a "slow" answer. This * test was robust on two distinct platforms. *

* This does take advantage of setAccessible only for verification, it will run * correctly with a security manager, however. *

* I'm interested to know if there was another way to do this; this seems like * a not very numerical approach to the problem. Fun problems like this are * hard to find time to solve, this one was just right in terms of complexity. * * @author adamh@basis.com */ public class Solution { public static final BigInteger TWO = BigInteger.valueOf(2); public static final int INNER = 100; public static final int OUTER = 10; public static final int WARMUP = 1000; public static void main(String[] argv) throws Throwable { // Fill up a byte[] that will be bigger than N. sqrt(N) is from // 0 .. 10000 bits, so N^2 could be 20000 bits. I make my array big // enough to hold 20001 bits, and set the msb. byte[] bits = new byte[(SquareRoot.BITS * 2 + 1 + 7) / 8]; int msbBits = (SquareRoot.BITS * 2 + 1) & 7; long[] values = new long[OUTER]; bits[0] = (byte)(0x100 >>> (8 - msbBits + 1)); // Start with a number that is definitely bigger than 'n'. BigInteger cand = new BigInteger(bits); // Prime the JIT and get some stats so we have somewhat // consistent results BigInteger t1 = cand; BigInteger t2 = BigInteger.valueOf(Long.MAX_VALUE); double sum = 0; double sumSq = 0; long n = 0; for (int i = 0; i < 1000; ++i) { runTest(t1, values); for (int j = 0; j < values.length; ++j) { n++; sum += values[j]; sumSq += values[j] * values[j]; } runTest(t2, values); t1 = t1.add(BigInteger.ONE); t2 = t2.add(BigInteger.ONE); } double mean = sum / n; double svar = (sumSq - n * mean * mean) / (n - 1); double sdev = Math.sqrt(svar); // If we're no more than 1 sigma above the mean, we'll say our number is // bigger than N. If our number is more than 2 sigma above the mean, // we'll say our number is smaller than N. Other than that, it's hard // to tell, so we'll try again. double[] maxdelta = new double[] { sdev, 2*sdev }; // Now, armed with maths, we can find the number search(values, mean, maxdelta, cand); } static int times = 0; private static void search(long[] values, double mean, double[] maxdelta, BigInteger cand) throws Throwable { BigInteger hi = cand.multiply(TWO); BigInteger lo = BigInteger.ZERO; for (int j = 0; hi.subtract(lo).signum() != 0; ++j) { runTest(cand, values); Boolean low = checkCand(cand); long med = values[values.length / 2]; // maxdelta[0] is 1 sigma boolean hiHit = med < mean + maxdelta[0]; // maxdelta[1] is 2 sigma boolean loHit = med > mean + maxdelta[1]; if (hiHit) { hi = cand; if (low != null && low) { throw new Error(String.format("Error, assumed cand >= n: j=%d, med=%.2f, mean=%.2f, delta=%.2f%n", j, (double)med, mean, (med - mean) / mean)); } times = 0; } else if (loHit) { // Too far, back it up lo = cand.add(BigInteger.ONE); if (low != null && !low) { throw new Error(String.format("Error, assumed cand < n: j=%d, med=%.2f, mean=%.2f, delta=%.2f%n", j, (double)med, mean, (med - mean) / mean)); } times = 0; } else { // It's possible we won't hit it, in that case, we'll fail. times++; if (times > 100) { throw new Error("failing, it's not going to work"); } System.out.printf("Outside sigma=1 range, actual=%s: hiavg=%.2f, med=%.2f%n", (low == null)? "??" : (low? "cand < n" : "cand >= n"), mean, (double)med); } // If we missed somehow, we don't want to go forever. BigInteger test = hi.subtract(lo); if (test.signum() < 0) { hi = lo; } // binary search, assuming that cand = lo.add(hi.subtract(lo).shiftRight(1)); } findSquareRoot(cand); } private static void runTest(BigInteger cand, long[] p_values) { long st, end; for (int i = 0; i < OUTER; ++i) { st = System.nanoTime(); for (int k = 0; k < INNER; ++k) { SquareRoot.answer(cand); } end = System.nanoTime(); p_values[i] = (end - st); } Arrays.sort(p_values); } private static void findSquareRoot(BigInteger p_cand) throws Throwable{ // Now we have n, we need to find the square root. // Guess cand >> (log2(cand) + 1), iterate w/ Newton's method BigDecimal cand = new BigDecimal(p_cand); BigDecimal r = new BigDecimal(p_cand.shiftRight(p_cand.bitLength() / 2 + 1)); BigDecimal np = r.multiply(r); while (! np.round(new MathContext(0)).equals(cand.round(new MathContext(0)))) { BigDecimal two_r = r.add(r); r = r.subtract(np.subtract(cand).divide(two_r, RoundingMode.HALF_EVEN)); np = r.multiply(r); } // This should be it! SquareRoot.answer(r.toBigInteger()); } static BigInteger N; private static Boolean checkCand(BigInteger p_cand) { try { if (System.getSecurityManager() == null) { if (N == null) { Field f = SquareRoot.class.getDeclaredField("n"); f.setAccessible(true); N = (BigInteger)f.get(null); System.out.println("n.bitLength()=" + N.bitLength()); } if (N.equals(p_cand)) { System.out.println("cand == n!"); } return N.compareTo(p_cand) > 0; } } catch (Exception e) { e.printStackTrace(); } return null; } }