Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
package square;
import java.math.BigInteger;
import java.util.*;
/**
 * SquareRootSolver for the puzzle http://corner.squareup.com/2013/03/puzzle-square-root.html
 *
 * <p>The basic idea is to use binary search.
 * Since {@code SquareRoot.answer()} returns nothing, we have to make up
 * some indirect measurement to guess what has happened inside that call.
 *
 * <p>The measurement is performed by calling {@code answer()} with x^2 values and looking at its
 * execution time. We rely on the following:
 * <ul>
 * <li>{@code BigInteger.divide(...)} checks at the very beginning whether the dividend is less than
 *     the divisor</li>
 * <li>when an actual division is performed, the running time increases by at least 2 times
 *     (about 10 times in practice)</li>
 * <li>the JIT must be turned off by running the JVM with {@code -Xint};
 *     {@code java.lang.Compiler.disable()} is ignored by HotSpot and was removed in Java 21,
 *     so this class no longer calls it</li>
 * <li>each execution time is a trimmed mean of 100 sorted samples: the fastest and the slowest 5%
 *     are discarded and the remaining samples are averaged</li>
 * </ul>
 *
 * <p>Once numbers can be compared by timing, a binary search gets the result in about 10000 iterations.
 *
 * <p>Commands to run it on unix (tested under Mac OS X and Ubuntu):
 * <pre>
 * cd $directoryWithSources
 * rm -rf classes
 * mkdir classes
 * javac -sourcepath . -d classes square/*.java
 * java -Xint -cp classes -Djava.security.manager square.SquareRootSolver
 * </pre>
 * It will print:
 * <pre>
 * Square root!
 * </pre>
 * followed by some stats.
 *
 * @author vpd
 */
public class SquareRootSolver {
    /** The constant 2, used to halve search intervals. */
    public static final BigInteger TWO = BigInteger.valueOf(2);

    private SquareRootSolver() {
    }

    /**
     * Runs the timing-based binary search, then calls {@code SquareRoot.answer} for every remaining
     * candidate.
     *
     * @param args ignored
     * @throws Exception declared for compatibility; the visible code throws no checked exception
     */
    public static void main(final String[] args) throws Exception {
        // The JIT would skew call timings, so run the JVM with -Xint. java.lang.Compiler.disable()
        // used to be called here, but HotSpot ignores it and the class no longer exists in Java 21+.
        Wrapper.getInstance().runIterativeSearch();
    }

    /** Creates the solver instance used by {@link #main(String[])}. */
    private static class Wrapper {
        static IterativeSquareRootSolver getInstance() {
            return new IterativeSquareRootSolver();
        }
    }

    /**
     * Binary search over candidate square roots, driven by {@link Measurement} timing comparisons.
     * A measurement that cannot decide which half to keep causes a step back to the previous interval.
     */
    private static class IterativeSquareRootSolver {
        /** Progress is printed every this many iterations. */
        public static final int PROGRESS_REPORTING_STEP = 1000;

        /** Interval currently believed to contain the square root. */
        private Interval current;
        private int currentIteration;
        /** How many times an undecidable measurement forced a step back. */
        private int numberOfStepBacks;

        public IterativeSquareRootSolver() {
            current = new Interval(BigInteger.ONE, getMaxValue());
        }

        /**
         * Narrows {@link #current} until its length is at most 10, then tries every remaining
         * candidate and prints the iteration count.
         */
        public void runIterativeSearch() {
            Interval previousInterval = current;
            for (currentIteration = 0; continueIteration(); currentIteration++) {
                previousInterval = doIteration(previousInterval);
                reportProgress();
            }
            makeSureAnswerIsFound();
            reportFinish();
        }

        private boolean continueIteration() {
            return current.length().compareTo(BigInteger.TEN) > 0;
        }

        /**
         * Performs one search step.
         *
         * @param previousMeasurement interval to fall back to if the new measurement is unusable
         * @return the interval to fall back to on the next step
         */
        private Interval doIteration(final Interval previousMeasurement) {
            final Measurement measurement = makeMeasurement();
            if (canUseMeasurement(measurement)) {
                adjustForward(measurement);
                // the measured (pre-halving) interval becomes the fallback for the next step
                return measurement.getInterval();
            }
            adjustBackward(previousMeasurement);
            return previousMeasurement;
        }

        /** Calls {@code SquareRoot.answer} for every candidate in the final, inclusive interval. */
        private void makeSureAnswerIsFound() {
            for (BigInteger i = current.getStart(); i.compareTo(current.getEnd()) <= 0; i = i.add(BigInteger.ONE)) {
                SquareRoot.answer(i);
            }
        }

        private boolean canUseMeasurement(final Measurement newMeasurement) {
            return newMeasurement.shouldMoveStart() || newMeasurement.shouldMoveEnd();
        }

        /** Keeps the half of the measured interval that the timings point to. */
        private void adjustForward(final Measurement measurement) {
            final BigInteger middle = measurement.getInterval().getMiddle();
            if (measurement.shouldMoveStart()) {
                current = new Interval(middle, current.getEnd());
            } else if (measurement.shouldMoveEnd()) {
                current = new Interval(current.getStart(), middle);
            }
        }

        /** Restores the previous interval after an undecidable measurement. */
        private void adjustBackward(final Interval previousInterval) {
            current = previousInterval;
            numberOfStepBacks++;
        }

        private void reportProgress() {
            if (currentIteration == 0 || currentIteration % PROGRESS_REPORTING_STEP != 0) {
                return;
            }
            final String prefix = "[step " + currentIteration + " / " + numberOfStepBacks + " step-backs]";
            // bitLength() is the exact number of significant bits; toByteArray().length * 8 rounded up
            // to whole bytes and also counted the sign bit
            final int bitFieldLength = current.length().bitLength();
            System.out.println(prefix + " search distance bit-field length is " + bitFieldLength);
        }

        private void reportFinish() {
            System.out.println("Finished after " + currentIteration + " iterations");
        }

        /** Measures {@link #current}, re-measuring once if the first result looks inconsistent. */
        private Measurement makeMeasurement() {
            final Measurement measurement = new Measurement(current);
            if (measurement.isSuspicious()) {
                // redo (measured times can be affected by system load changes)
                return new Measurement(current);
            }
            return measurement;
        }

        /**
         * Upper bound of the initial search interval: a positive number of
         * {@code SquareRoot.BITS / 8 + 1} bytes, each 0x7F.
         * NOTE(review): assumed to exceed the hidden number's square root — confirm against SquareRoot.
         */
        private static BigInteger getMaxValue() {
            final byte[] bytes = new byte[SquareRoot.BITS / 8 + 1];
            Arrays.fill(bytes, Byte.MAX_VALUE);
            return new BigInteger(bytes);
        }
    }

    /** Immutable closed interval [start, end] of candidate roots. */
    private static final class Interval {
        private final BigInteger start;
        private final BigInteger end;

        private Interval(final BigInteger start, final BigInteger end) {
            this.start = start;
            this.end = end;
        }

        public BigInteger getStart() {
            return start;
        }

        public BigInteger getEnd() {
            return end;
        }

        /** Floor of the arithmetic mean of start and end. */
        public BigInteger getMiddle() {
            return end.add(start).divide(TWO);
        }

        /** end - start. */
        public BigInteger length() {
            return end.subtract(start);
        }
    }

    /**
     * Trimmed-mean timings of {@code SquareRoot.answer(x^2)} for the start, end and middle x of an
     * interval. A significantly faster call is taken to mean that {@code BigInteger.divide} took its
     * early exit, i.e. that x^2 lies beyond the hidden number (see the class documentation).
     */
    private static class Measurement {
        public static final int NUMBER_OF_SAMPLES = 100;
        /** Fraction of the sorted samples discarded from EACH end before averaging. */
        public static final double CONFIDENCE_LEVEL = 0.05;
        /** A time is "significantly faster" when it is below the other time divided by this factor. */
        private static final int SIGNIFICANCE_FACTOR = 2;

        private final Interval interval;
        private final long startPointTime;
        private final long endPointTime;
        private final long middlePointTime;

        private Measurement(final Interval interval) {
            this.interval = interval;
            startPointTime = measureAverageAnswerTimeNs(interval.getStart());
            endPointTime = measureAverageAnswerTimeNs(interval.getEnd());
            middlePointTime = measureAverageAnswerTimeNs(interval.getMiddle());
        }

        /** Average time of {@code SquareRoot.answer(point^2)}. */
        private static long measureAverageAnswerTimeNs(final BigInteger point) {
            return measureAverageAnswerTimeNs(point.pow(2), NUMBER_OF_SAMPLES, CONFIDENCE_LEVEL);
        }

        private static long measureAverageAnswerTimeNs(final BigInteger testedValue, final int numberOfSamples, final double confidenceLevel) {
            final long[] samples = generateSampleTimesNs(testedValue, numberOfSamples);
            return averageWithinConfidenceInterval(samples, confidenceLevel);
        }

        /** Times {@code numberOfSamples} calls with {@code value} and returns the durations sorted ascending. */
        private static long[] generateSampleTimesNs(final BigInteger value, final int numberOfSamples) {
            // a primitive array avoids allocating boxed Longs between timed calls (less GC noise)
            final long[] samples = new long[numberOfSamples];
            for (int i = 0; i < numberOfSamples; i++) {
                samples[i] = measureAnswerTimeNs(value);
            }
            Arrays.sort(samples);
            return samples;
        }

        /**
         * Trimmed mean: averages the sorted samples whose index lies in
         * [confidenceLevel * n, (1 - confidenceLevel) * n).
         *
         * @param sortedSamples   samples sorted ascending
         * @param confidenceLevel fraction trimmed from each end
         * @return the mean of the retained samples; the median if none are retained; 0 for no samples
         */
        private static long averageWithinConfidenceInterval(
                final long[] sortedSamples,
                final double confidenceLevel
        ) {
            final int numberOfSamples = sortedSamples.length;
            final int startSampleIndex = (int) (confidenceLevel * numberOfSamples);
            final int endSampleIndex = (int) ((1 - confidenceLevel) * numberOfSamples);
            if (endSampleIndex <= startSampleIndex) {
                // empty window (no samples, or confidenceLevel >= 0.5): fall back to the median
                return numberOfSamples == 0 ? 0 : sortedSamples[numberOfSamples / 2];
            }
            long sum = 0;
            for (int i = startSampleIndex; i < endSampleIndex; i++) {
                sum += sortedSamples[i];
            }
            // divide by the number of samples actually summed, not by the total sample count
            return sum / (endSampleIndex - startSampleIndex);
        }

        private static long measureAnswerTimeNs(final BigInteger value) {
            final long start = System.nanoTime();
            SquareRoot.answer(value);
            return System.nanoTime() - start;
        }

        /** Start is significantly faster than end — inconsistent, since start is the smaller point. */
        public boolean isStartGreaterThanEnd() {
            return isSignificantlyFaster(startPointTime, endPointTime, SIGNIFICANCE_FACTOR);
        }

        /** Start is significantly faster than middle — inconsistent, since start is the smaller point. */
        public boolean isStartGreaterThanMiddle() {
            return isSignificantlyFaster(startPointTime, middlePointTime, SIGNIFICANCE_FACTOR);
        }

        /** Middle is significantly faster than end — inconsistent, since middle is the smaller point. */
        public boolean isMiddleGreaterThanEnd() {
            return isSignificantlyFaster(middlePointTime, endPointTime, SIGNIFICANCE_FACTOR);
        }

        /** End is fast but middle is slow: the root is at or above middle. */
        public boolean shouldMoveStart() {
            return isSignificantlyFaster(endPointTime, middlePointTime, SIGNIFICANCE_FACTOR);
        }

        /** Middle is fast but start is slow: the root is below middle. */
        public boolean shouldMoveEnd() {
            return isSignificantlyFaster(middlePointTime, startPointTime, SIGNIFICANCE_FACTOR);
        }

        /** True when the timings contradict the interval ordering or cannot pick a half. */
        private boolean isSuspicious() {
            return
                    (isStartGreaterThanEnd() || isStartGreaterThanMiddle() || isMiddleGreaterThanEnd()) ||
                    (!shouldMoveStart() && !shouldMoveEnd());
        }

        public Interval getInterval() {
            return interval;
        }

        private static boolean isSignificantlyFaster(final long faster, final long slower, final int level) {
            return faster < slower / level;
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment