-
-
Save vpd/2f932bbd3b25e724be83 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package square; | |
import java.math.BigInteger; | |
import java.util.*; | |
/** | |
* SquareRootSolver for puzzle http://corner.squareup.com/2013/03/puzzle-square-root.html | |
* | |
* Basic idea is to use binary search. | |
* Since we have only an SquareRoot.answer() that returns nothing we have to make up | |
* some indirect measurement to guess was has happened inside of this call. | |
* | |
* Measurement is performed by calling answer() with x^x values and looking on its execution time. | |
* We rely on following: | |
* - BigInteger.divide(...) method at the beginning checks whether dividend is less than divisor | |
* - with actual division performed its running time increases by at least 2 times (about 10 times indeed) | |
* - JIT is turned off with Compiler.disable() | |
* - execution time is averaged from [0.05, 0.95] confidence interval obtained from 100 samples | |
* | |
* Once we can compare number by times, we can do binary search and get result in about 10000 iterations | |
* | |
* Commands to run it on unix (was tested under mac os x and ubuntu) | |
cd $directoryWithSources | |
rm -rf classes | |
mkdir classes | |
javac -sourcepath . -d classes square/*.java | |
java -cp classes -Djava.security.manager square.SquareRootSolver | |
* Will print: | |
* Square root! | |
* and some stats | |
* | |
* @author vpd | |
*/ | |
public class SquareRootSolver { | |
public static final BigInteger TWO = new BigInteger("2"); | |
private SquareRootSolver() { | |
} | |
public static void main(final String[] args) throws Exception { | |
Compiler.disable(); // we have to disable JIT to be able to measure calls complexity | |
Wrapper.getInstance().runIterativeSearch(); | |
} | |
private static class Wrapper { | |
static IterativeSquareRootSolver getInstance() { | |
return new IterativeSquareRootSolver(); | |
} | |
} | |
private static class IterativeSquareRootSolver { | |
public static final int PROGRESS_REPORTING_STEP = 1000; | |
private Interval current; | |
private int currentIteration; | |
private int numberOfStepBacks; | |
public IterativeSquareRootSolver() { | |
current = new Interval(BigInteger.ONE, getMaxValue()); | |
} | |
public void runIterativeSearch() { | |
Interval previousInterval = current; | |
for (currentIteration = 0; continueIteration(); currentIteration++) { | |
previousInterval = doIteration(previousInterval); | |
reportProgress(); | |
} | |
makeSureAnswerIsFound(); | |
reportFinish(); | |
} | |
private boolean continueIteration() { | |
return current.length().compareTo(BigInteger.TEN) > 0; | |
} | |
private Interval doIteration(final Interval previousMeasurement) { | |
final Measurement measurement = makeMeasurement(); | |
if (canUseMeasurement(measurement)) { | |
adjustForward(measurement); | |
return measurement.getInterval(); | |
} | |
else { | |
adjustBackward(previousMeasurement); | |
return previousMeasurement; | |
} | |
} | |
private void makeSureAnswerIsFound() { | |
for (BigInteger i = current.getStart(); i.compareTo(current.getEnd()) <= 0; i = i.add(BigInteger.ONE)) { | |
SquareRoot.answer(i); | |
} | |
} | |
private boolean canUseMeasurement(final Measurement newMeasurement) { | |
return newMeasurement.shouldMoveStart() || newMeasurement.shouldMoveEnd(); | |
} | |
private void adjustForward(final Measurement measurement) { | |
final BigInteger middle = measurement.getInterval().getMiddle(); | |
if (measurement.shouldMoveStart()) { | |
current = new Interval(middle, current.getEnd()); | |
} | |
else if (measurement.shouldMoveEnd()) { | |
current = new Interval(current.getStart(), middle); | |
} | |
} | |
private void adjustBackward(final Interval previousInterval) { | |
current = previousInterval; | |
numberOfStepBacks++; | |
} | |
private void reportProgress() { | |
if (currentIteration == 0) { | |
return; | |
} | |
if (currentIteration % PROGRESS_REPORTING_STEP == 0) { | |
final String prefix = "[step " + currentIteration + " / " + numberOfStepBacks + " step-backs]"; | |
final int bitFieldLength = current.length().toByteArray().length * 8; | |
System.out.println(prefix + " search distance bit-field length is " + bitFieldLength); | |
} | |
} | |
private void reportFinish() { | |
System.out.println("Finished after " + currentIteration + " iterations"); | |
} | |
private Measurement makeMeasurement() { | |
final Measurement measurement = new Measurement(current); | |
if (measurement.isSuspicious()) { | |
// redo (measured times can be affected by system load changes) | |
return new Measurement(current); | |
} | |
return measurement; | |
} | |
private static BigInteger getMaxValue() { | |
final byte[] bytes = new byte[SquareRoot.BITS / 8 + 1]; | |
Arrays.fill(bytes, Byte.MAX_VALUE); | |
return new BigInteger(bytes); | |
} | |
} | |
private static class Interval { | |
private final BigInteger start; | |
private final BigInteger end; | |
private Interval(final BigInteger start, final BigInteger end) { | |
this.start = start; | |
this.end = end; | |
} | |
public BigInteger getStart() { | |
return start; | |
} | |
public BigInteger getEnd() { | |
return end; | |
} | |
public BigInteger getMiddle() { | |
return end.add(start).divide(TWO); | |
} | |
public BigInteger length() { | |
return end.subtract(start); | |
} | |
} | |
private static class Measurement { | |
public static final int NUMBER_OF_SAMPLES = 100; | |
public static final double CONFIDENCE_LEVEL = 0.05; | |
private final Interval interval; | |
private final long startPointTime; | |
private final long endPointTime; | |
private final long middlePointTime; | |
private Measurement(final Interval interval) { | |
this.interval = interval; | |
startPointTime = measureAverageAnswerTimeNs(interval.getStart()); | |
endPointTime = measureAverageAnswerTimeNs(interval.getEnd()); | |
middlePointTime = measureAverageAnswerTimeNs(interval.getMiddle()); | |
} | |
private static long measureAverageAnswerTimeNs(final BigInteger integer) { | |
return measureAverageAnswerTimeNs(integer.pow(2), NUMBER_OF_SAMPLES, CONFIDENCE_LEVEL); | |
} | |
private static long measureAverageAnswerTimeNs(final BigInteger testedValue, final int numberOfSamples, final double confidenceLevel) { | |
final List<Long> samples = generateSampleTimesNs(testedValue, numberOfSamples); | |
return averageWithinConfidenceInterval(samples, confidenceLevel); | |
} | |
private static List<Long> generateSampleTimesNs(final BigInteger value, final int numberOfSamples) { | |
final List<Long> samples = new ArrayList<Long>(numberOfSamples); | |
for (int i = 0; i < numberOfSamples; i++) { | |
samples.add(measureAnswerTimeNs(value)); | |
} | |
Collections.sort(samples); | |
return samples; | |
} | |
private static long averageWithinConfidenceInterval( | |
final List<Long> samples, | |
final double confidenceLevel | |
) { | |
final int numberOfSamples = samples.size(); | |
final int startSampleIndex = (int) (confidenceLevel * numberOfSamples); | |
final double endSampleIndex = (int) ((1 - confidenceLevel) * numberOfSamples); | |
long sum = 0; | |
for (int i = startSampleIndex; i < endSampleIndex; i++) { | |
sum += samples.get(i); | |
} | |
return sum / numberOfSamples; | |
} | |
private static long measureAnswerTimeNs(final BigInteger value) { | |
final long start = System.nanoTime(); | |
SquareRoot.answer(value); | |
return System.nanoTime() - start; | |
} | |
public boolean isStartGreaterThanEnd() { | |
return isSignificantlyFaster(startPointTime, endPointTime, 2); | |
} | |
public boolean isStartGreaterThanMiddle() { | |
return isSignificantlyFaster(startPointTime, middlePointTime, 2); | |
} | |
public boolean isMiddleGreaterThanEnd() { | |
return isSignificantlyFaster(middlePointTime, endPointTime, 2); | |
} | |
public boolean shouldMoveStart() { | |
return isSignificantlyFaster(endPointTime, middlePointTime, 2); | |
} | |
public boolean shouldMoveEnd() { | |
return isSignificantlyFaster(middlePointTime, startPointTime, 2); | |
} | |
private boolean isSuspicious() { | |
return | |
(isStartGreaterThanEnd() || isStartGreaterThanMiddle() || isMiddleGreaterThanEnd()) || | |
(!shouldMoveStart() && !shouldMoveEnd()); | |
} | |
public Interval getInterval() { | |
return interval; | |
} | |
private static boolean isSignificantlyFaster(final long faster, final long slower, final int level) { | |
return faster < slower / level; | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment