/*
 * Created May 13, 2016 14:23 as GitHub gist joriki/e15ed8baae37103b2258e7640211975f
 * (save it to your computer and use it in GitHub Desktop).
 *
 * Decide whether rerolling the highest value on a die can improve the expected value;
 * see http://math.stackexchange.com/questions/1118263.
 *
 * GitHub note: this file may contain bidirectional Unicode text that can be interpreted or
 * compiled differently than it appears. To review, open the file in an editor that reveals
 * hidden Unicode characters. (Learn more about bidirectional Unicode characters on GitHub.)
 */
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Objects;
public class Question1118263 { | |
final static int [] values = {1,2,3,4,5,6}; | |
final static int nsides = values.length; | |
final static int ndice = 1000; | |
static boolean nonHighest; | |
static SimpleRational [] v = new SimpleRational [ndice]; // game values | |
static SimpleRational [] w = new SimpleRational [ndice]; // values of games preceded by free extra roll without sixes or requirement to fix, scaled by 5^n | |
static BigInteger [] [] powers = new BigInteger [nsides + 1] [ndice]; | |
static int n; | |
public static void main (String [] args) throws Exception { | |
for (int i = 0;i <= nsides;i++) { | |
powers [i] [0] = BigInteger.ONE; | |
BigInteger factor = BigInteger.valueOf (i); | |
for (int j = 1;j < ndice;j++) | |
powers [i] [j] = powers [i] [j - 1].multiply (factor); | |
} | |
SimpleRational zero = SimpleRational.zero (nsides); | |
Arrays.fill (v,zero); | |
Arrays.fill (w,zero); | |
BigInteger highDie = BigInteger.valueOf (values [nsides - 1]); | |
// calculate with explicit dice counts until w [n] = v [n] | |
for (n = 1;n < ndice;n++) { | |
recurse (0,n,new int [nsides]); | |
v [n] = v [n].divideByPower (n); | |
w [n - 1] = w [n - 1].divide (n).subtract (powers [nsides - 1] [n - 1].multiply (highDie)); | |
if (n > 1 && w [n - 1].equals (v [n - 1].multiply (powers [nsides - 1] [n - 1]))) | |
break; | |
} | |
v [n] = zero; | |
// now apply the recurrence relation, counting only sixes | |
for (;n < ndice;n++) { | |
for (int nlow = 0;nlow < n;nlow++) | |
v [n] = v [n].add (w [nlow].add (powers [nsides - 1] [nlow].multiply (BigInteger.valueOf ((n - nlow) * values [nsides - 1]))).multiply (binomial (n,nlow))); | |
v [n] = v [n].add (v [n - 1].multiply (powers [nsides - 1] [n])); | |
for (int k = 0;k < nsides - 1;k++) | |
v [n] = v [n].add (BigInteger.valueOf (values [k]).multiply (powers [k + 1] [n].subtract (powers [k] [n]))); | |
v [n] = v [n].divideByPower (n); | |
w [n] = v [n].multiply (powers [nsides - 1] [n]); | |
System.out.println (" " + n + " : " + v [n].subtract (v [n - 1]).subtract (values [nsides - 1]).doubleValue ()); | |
} | |
} | |
static void recurse (int depth,int left,int [] counts) { | |
if (depth == nsides - 1) { | |
counts [depth] = left; | |
int sum = 0; | |
for (int i = 0;i < counts.length;i++) | |
sum += counts [i] * values [i]; | |
int index = 0; | |
int count = 0; | |
int value = 0; | |
SimpleRational max = SimpleRational.zero (nsides); | |
for (int reroll = 0;reroll < n;reroll++) { | |
SimpleRational newValue = v [reroll].add (sum); | |
if (newValue.compareTo (max) > 0) | |
max = newValue; | |
while (count == 0) { | |
value = values [index]; | |
count = counts [index++]; | |
} | |
count--; | |
sum -= value; | |
} | |
BigInteger multinomial = BigInteger.ONE; | |
int upper = n; | |
for (int i = 0;i < counts.length;i++) { | |
multinomial = multinomial.multiply (binomial (upper,counts [i])); | |
upper -= counts [i]; | |
} | |
SimpleRational term = max.multiply (multinomial); | |
v [n] = v [n].add (term); | |
if (counts [nsides - 1] == 1) | |
w [n - 1] = w [n - 1].add (term); | |
} | |
else | |
for (counts [depth] = 0;counts [depth] <= left;counts [depth]++) | |
recurse (depth + 1,left - counts [depth],counts); | |
} | |
static BigInteger [] [] binomials = new BigInteger [ndice] [ndice]; | |
static BigInteger binomial (int n,int k) { | |
if (k > n) | |
return BigInteger.ZERO; | |
if (binomials [n] [k] == null) { | |
binomials [n] [k] = BigInteger.ONE; | |
for (int i = 0;i < k;i++) | |
binomials [n] [k] = binomials [n] [k].multiply (BigInteger.valueOf (n - i)).divide (BigInteger.valueOf (i + 1)); | |
} | |
return binomials [n] [k]; | |
} | |
} | |
/**
 * An immutable number of the form {@code numerator / base^exponent}.
 *
 * <p>Arithmetic between two instances requires them to share the same {@code base}. Mixing bases
 * throws {@link IllegalArgumentException}, and so does {@link #compareTo} and therefore
 * {@link #equals}. Results are never normalized: a sum takes the larger of the two exponents and
 * a product takes their sum. {@link #doubleValue()} works through {@link #logarithm()}, so it can
 * handle numerators and denominators far outside the range of {@code double}.
 */
class SimpleRational implements Comparable<SimpleRational> {
    final BigInteger numerator;
    final BigInteger base;
    final int exponent;

    public SimpleRational (long numerator,long base,int exponent) {
        this (BigInteger.valueOf (numerator),base,exponent);
    }

    public SimpleRational (BigInteger numerator,long base,int exponent) {
        this (numerator,BigInteger.valueOf (base),exponent);
    }

    /** @throws NullPointerException if numerator or base is null */
    public SimpleRational (BigInteger numerator,BigInteger base,int exponent) {
        this.numerator = Objects.requireNonNull (numerator,"numerator");
        this.base = Objects.requireNonNull (base,"base");
        this.exponent = exponent;
    }

    public SimpleRational negate () {
        return new SimpleRational (numerator.negate (),base,exponent);
    }

    public SimpleRational add (long n) {
        return add (BigInteger.valueOf (n));
    }

    public SimpleRational subtract (long n) {
        return subtract (BigInteger.valueOf (n));
    }

    public SimpleRational subtract (BigInteger n) {
        return add (n.negate ());
    }

    /** Adds the integer n, keeping this instance's base and exponent. */
    public SimpleRational add (BigInteger n) {
        return new SimpleRational (numerator.add (n.multiply (base.pow (exponent))),base,exponent);
    }

    /**
     * Adds r; the result carries the larger of the two exponents.
     *
     * @throws IllegalArgumentException if the bases differ
     */
    public SimpleRational add (SimpleRational r) {
        int d = exponent - r.exponent;
        if (d < 0)
            return r.add (this);
        if (!r.base.equals (base))
            throw new IllegalArgumentException ("bases differ: " + base + " vs " + r.base);
        return new SimpleRational (numerator.add (r.numerator.multiply (base.pow (d))),base,exponent);
    }

    public SimpleRational subtract (SimpleRational r) {
        return add (r.negate ());
    }

    public SimpleRational multiply (long n) {
        return multiply (BigInteger.valueOf (n));
    }

    public SimpleRational multiply (BigInteger n) {
        return new SimpleRational (numerator.multiply (n),base,exponent);
    }

    /**
     * Multiplies by r; the result's exponent is the sum of both exponents.
     *
     * @throws IllegalArgumentException if the bases differ
     */
    public SimpleRational multiply (SimpleRational r) {
        if (!r.base.equals (base))
            throw new IllegalArgumentException ("bases differ: " + base + " vs " + r.base);
        return new SimpleRational (numerator.multiply (r.numerator),base,exponent + r.exponent);
    }

    public SimpleRational divide (long n) {
        return divide (BigInteger.valueOf (n));
    }

    /**
     * Divides the numerator by n exactly.
     *
     * @throws IllegalArgumentException if the numerator is not a multiple of n
     * @throws ArithmeticException if n is zero
     */
    public SimpleRational divide (BigInteger n) {
        BigInteger newNumerator = numerator.divide (n);
        if (!newNumerator.multiply (n).equals (numerator))
            throw new IllegalArgumentException ("numerator is not exactly divisible by " + n);
        return new SimpleRational (newNumerator,base,exponent);
    }

    /** Divides by base^power by raising the exponent; the numerator is unchanged. */
    public SimpleRational divideByPower (int power) {
        return new SimpleRational (numerator,base,exponent + power);
    }

    /**
     * Returns the natural logarithm of this value. The numerator is first truncated to its top
     * 64 bits so that the conversion to double cannot overflow. The result is NaN for negative
     * values and -Infinity for zero.
     */
    public double logarithm () {
        int shift = Math.max (0,numerator.bitLength () - 64);
        return Math.log (numerator.shiftRight (shift).doubleValue ()) + shift * Math.log (2) - exponent * Math.log (base.doubleValue ());
    }

    /** Returns this value as a double, computed in log space (0.0 for a zero numerator). */
    public double doubleValue () {
        return numerator.signum () < 0 ? -negate ().doubleValue () : Math.exp (logarithm ());
    }

    public static SimpleRational zero (long base) {
        return new SimpleRational (0,base,0);
    }

    /**
     * Compares by value.
     *
     * @throws IllegalArgumentException if the bases differ
     */
    @Override
    public int compareTo (SimpleRational other) {
        return subtract (other).numerator.signum ();
    }

    /**
     * Value equality, so 6/6^1 equals 1/6^0. Returns false for null or for objects that are not
     * SimpleRationals.
     *
     * @throws IllegalArgumentException if o is a SimpleRational with a different base
     */
    @Override
    public boolean equals (Object o) {
        if (this == o)
            return true;
        if (!(o instanceof SimpleRational))
            return false;
        return compareTo ((SimpleRational) o) == 0;
    }

    /**
     * Hashes a canonical representation of the value, consistent with {@link #equals}. A negative
     * exponent is first folded into the numerator. Then factors of base are cancelled while the
     * exponent is positive.
     */
    @Override
    public int hashCode () {
        BigInteger n = numerator;
        int e = exponent;
        if (e < 0) {
            n = n.multiply (base.pow (-e));
            e = 0;
        }
        // remainder() accepts negative divisors but not zero
        if (base.signum () != 0)
            while (e > 0 && n.remainder (base).signum () == 0) {
                n = n.divide (base);
                e--;
            }
        return Objects.hash (n,base,e);
    }

    /** Returns the raw representation, e.g. {@code 5/6^2}. */
    @Override
    public String toString () {
        return numerator + "/" + base + "^" + exponent;
    }
}
// Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.