Skip to content

Instantly share code, notes, and snippets.

@hughpyle
Last active December 15, 2015 04:50
Show Gist options
  • Save hughpyle/d0172eb66b96baa574e9 to your computer and use it in GitHub Desktop.
Save hughpyle/d0172eb66b96baa574e9 to your computer and use it in GitHub Desktop.
Solution to BigInteger square root puzzle http://corner.squareup.com/2013/03/puzzle-square-root.html
/*
Solution to biginteger square root puzzle http://corner.squareup.com/2013/03/puzzle-square-root.html
Uses a timing attack on BigInteger.divide(), which has a very large performance difference
depending on whether the divisor is larger or smaller than the dividend ("this").
Times multiple iterations to try to avoid jitter. Not at all optimized for speed :-)
*/
package square;
import java.math.BigInteger;
import java.security.SecureRandom;
/**
 * Searches for a hidden square by timing calls to {@code SquareRoot.answer()}.
 *
 * <p>The search keeps a candidate {@link #n} between two reference values,
 * {@link #nLower} and {@link #nUpper}. Each round times {@code answer()} on all three,
 * decides which reference the candidate's timing resembles, and bisects towards the
 * other one. Running mean and variance of the two reference timings are tracked so the
 * decision is only taken when the candidate lies within one standard deviation of one
 * of them.
 *
 * <p>{@code SquareRoot} is declared outside this file; only its {@code BITS} constant
 * and its {@code answer(BigInteger)} method are used here.
 */
public class SquareRootTest {
    static boolean CSV = true; /* Output comma-separated values to graph later */
    static boolean DEBUG = false; /* Debug the decision tree */
    static boolean VERBOSE = false; /* Output the values */

    static SquareRoot sqr = new SquareRoot();

    /* Current candidate for the square, and the upper/lower timing reference values */
    static BigInteger n, nUpper, nLower;
    static BigInteger TWO = BigInteger.valueOf(2L);

    /* Number of timing samples accumulated since the last initBounds() */
    static long count=0;

    /* Running mean (M) and sum of squared deviations (S) of the lower (L) and upper (U) timings */
    static double oldML, oldMU, oldSL, oldSU;
    static double newML, newMU, newSL, newSU;

    /* Number of answer() calls per timing sample; raised when a sample measures as zero */
    static int FUDGEITER1=1;
    /* Controls how many extra samples are min-filtered per test(); raised when the bounds overlap */
    static int FUDGEITER2=1;

    /**
     * Integer square root by Newton iteration.
     *
     * @param x the value to take the root of; negative input is not handled
     * @return {@code x} itself when it equals 0 or 1, otherwise the value the iteration
     *         settles on (the first {@code y} with {@code y <= x / y})
     */
    static BigInteger sqrt(BigInteger x) {
        // square roots of 0 and 1 are trivial and
        // y == 0 will cause a divide-by-zero exception.
        // Compare by value: a BigInteger equal to 0 or 1 is not guaranteed to be the
        // same object as BigInteger.ZERO / BigInteger.ONE (e.g. new BigInteger("1")),
        // so an identity test (==) would let it through to the division below.
        if (x.equals(BigInteger.ZERO) || x.equals(BigInteger.ONE)) {
            return x;
        }
        // starting with y = x / 2 avoids magnitude issues with x squared
        BigInteger y = x.divide(TWO);
        while (y.compareTo(x.divide(y)) > 0) {
            y = x.divide(y).add(y).divide(TWO);
        }
        return y;
    }

    /**
     * Times {@code FUDGEITER1} consecutive calls of {@code sqr.answer(bi)}.
     *
     * @param bi the value handed to {@code answer()}
     * @return elapsed time in nanoseconds, as measured by {@link System#nanoTime()}
     */
    static long tryroot( BigInteger bi )
    {
        long t = System.nanoTime();
        for( int i=0; i<FUDGEITER1; i++ ) sqr.answer( bi );
        return System.nanoTime() - t;
    }

    /**
     * Resets the sample counter and restores the widest possible reference values:
     * {@code nLower} becomes 1 and {@code nUpper} becomes the square of a number made
     * of {@code SquareRoot.BITS} decimal '9' digits.
     */
    static void initBounds()
    {
        count = 0;
        // Make a little big integer. Little is slow.
        nLower = BigInteger.ONE;
        // Make a big integer. Big is fast.
        char[] val = new char[ SquareRoot.BITS ];
        for( int i = 0; i<SquareRoot.BITS; i++ ) val[i]='9';
        nUpper = new BigInteger( new String(val) ).pow(2);
    }

    /**
     * Takes one timing measurement of the candidate and of both references, folds the
     * reference timings into the running statistics, and classifies the candidate.
     *
     * @return {@code -1} if the candidate's timing is within one standard deviation of
     *         the lower reference's mean, {@code 1} if it is within one standard
     *         deviation of the upper reference's mean, and {@code 0} when no decision
     *         can be made (timing too small to measure, fewer than 50 samples so far,
     *         the two references are statistically indistinguishable, or the candidate
     *         matches neither)
     */
    static int test()
    {
        // Measure timing for the current point & bounds. Measure a few times to trim outliers.
        long t = tryroot( n );
        long tL = tryroot( nLower );
        long tU = tryroot( nUpper );
        // Keep the minimum of the initial sample plus FUDGEITER2+1 further samples
        for( int i=0; i<=FUDGEITER2; i++ )
        {
            t = Math.min( t, tryroot( n ) );
            tL = Math.min( tL, tryroot( nLower ) );
            tU = Math.min( tU, tryroot( nUpper ) );
        }
        if( t<=0 )
        {
            // Our machine is too fast to measure well, slow it down some
            FUDGEITER1++;
            if( DEBUG ) System.out.println("FUDGE=" + FUDGEITER1);
            return 0;
        }

        // Accumulate for variance (Knuth TAOCP vol 2, 3ed p232).
        // Keep doing this for all the bounds, we'll converge even if the wrong path is taken!
        count++;
        if( count==1 )
        {
            newML = tL;
            oldML = tL;
            oldSL = 0.0;
            newSL = 0.0; // clear any sum left over from before an initBounds() restart
            newMU = tU;
            oldMU = tU;
            oldSU = 0.0;
            newSU = 0.0; // clear any sum left over from before an initBounds() restart
        }
        else
        {
            newML = oldML + (tL-oldML)/count;
            newSL = oldSL + (tL-oldML)*(tL-newML);
            oldML = newML;
            oldSL = newSL;
            newMU = oldMU + (tU-oldMU)/count;
            newSU = oldSU + (tU-oldMU)*(tU-newMU);
            oldMU = newMU;
            oldSU = newSU;
        }

        // Throw away the first results
        if( count<50 ) return 0;

        // Sample variances of the lower and upper reference timings
        double vL = newSL/(count-1);
        double vU = newSU/(count-1);
        if( CSV )
        {
            System.out.println( count + ", " + newML + ", " + vL + ", " + newMU + ", " + vU + ", " + t ); /* mean, variance, mean, variance, testvalue */
        }

        // If the lower and upper are too close to compare, we probably went wrong somewhere
        if( (newMU-newML)*(newMU-newML) < (vL + vU) )
        {
            if( DEBUG ) System.out.println("Bounds overlap, mL=" + newML + ", mU=" + newMU );
            // Sample harder next time and restart the search from the widest bounds
            FUDGEITER2++;
            initBounds();
            return 0;
        }

        // If within 1SD of the lower or upper bound, that's a good result
        // (squared distance from the mean compared against the variance)
        if( (t-newML)*(t-newML) < vL )
        {
            if( DEBUG ) System.out.println("Within lower bound variance, mL=" + newML + ", t=" + t + ", mU=" + newMU );
            return -1;
        }
        if( (t-newMU)*(t-newMU) < vU )
        {
            if( DEBUG ) System.out.println("Within upper bound variance, mL=" + newML + ", t=" + t + ", mU=" + newMU );
            return 1;
        }
        if( DEBUG ) System.out.println("Outside variance, mL=" + newML + ", t=" + t + ", mU=" + newMU );
        return 0;
    }

    /**
     * Runs the search: starts from a random square, bisects between the reference
     * values for up to 100000 rounds according to {@link #test()}, then submits the
     * square root of the final candidate to {@code sqr.answer()}.
     *
     * @param args unused
     */
    public static void main(String[] args)
    {
        initBounds();
        // sqr.cheat();
        // Start somewhere
        // n = nUpper;
        n = new BigInteger( SquareRoot.BITS, new SecureRandom() ).pow(2);

        // Binary chop until they get slow.
        for( int j=0; j<100000; j++ )
        {
            // Repeat the measurement until it yields a decision (-1 or 1)
            int t = test();
            while( t==0 ) t = test();

            if( t==-1 )
            {
                // Candidate timed like the lower reference: it becomes the new lower
                // reference and the next candidate is halfway up to nUpper
                if( VERBOSE ) System.out.println( " bigger than " + n );
                BigInteger n2 = n.add(nUpper).divide(TWO);
                nLower = n;
                n = n2;
            }
            else
            {
                // Candidate timed like the upper reference: it becomes the new upper
                // reference and the next candidate is halfway down to nLower
                if( VERBOSE ) System.out.println( " smaller than " + n );
                BigInteger n2 = n.add(nLower).divide(TWO);
                nUpper = n;
                n = n2;
            }

            // The interval can no longer be halved: take the value just above nLower
            if( nLower.equals(n) )
            {
                n = n.add(BigInteger.ONE);
                if( DEBUG ) System.out.println( "Finished at " + j + ",\nThe square is " + n );
                System.out.println("Done!");
                break;
            }
            if( VERBOSE ) System.out.println( "try " + n );
            System.gc();
        }
        sqr.answer( sqrt(n) );
        System.out.println("End.");
    }
}
@hughpyle
Copy link
Author

OK, it's not particularly reliable yet. For 1000-digit values I'm hitting the right answer nearly all the time. For 10k-digit values my tests are only 50%, and it takes way longer to run than I'd like. I might spend time on this later, but currently happy with this for a proof of concept.

@hughpyle
Copy link
Author

Updated to stretch if the timing returns zero. More reliable and not necessarily any slower than before.

TIL the slow way: linux 2.4 has no sub-millisecond Java clock, so please use something more modern.

@hughpyle
Copy link
Author

Rev3 checks its convergence path and can output comma-separated values for charting. I think this should be very solid. Now to run some tests :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment