Skip to content

Instantly share code, notes, and snippets.

@joriki
Created May 13, 2016 14:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save joriki/e15ed8baae37103b2258e7640211975f to your computer and use it in GitHub Desktop.
Save joriki/e15ed8baae37103b2258e7640211975f to your computer and use it in GitHub Desktop.
Decide whether rerolling the highest value on a die can improve the expected value; see http://math.stackexchange.com/questions/1118263.
import java.math.BigInteger;
import java.util.Arrays;
/**
 * Exact computation, for n dice, of the expected value v [n] of the dice game discussed in
 * math.stackexchange.com question 1118263 (can rerolling the highest value on a die improve
 * the expected value?).
 *
 * <p>Game as modelled by {@link #recurse}: roll n dice, keep the n - reroll highest of them for
 * some 0 <= reroll < n (so at least one die is always kept), score the kept dice and continue
 * with the rerolled ones, which are worth v [reroll]; the best choice of reroll is taken.
 * All arithmetic is exact, using {@link SimpleRational} values of the form x / 6^e
 * (base = nsides).
 *
 * <p>Phase 1 enumerates every outcome explicitly for n = 1, 2, ... until the break condition
 * in {@link #main} holds. Phase 2 then switches to a recurrence that only distinguishes sixes
 * from non-sixes, and prints v [n] - v [n - 1] - 6 for each further n.
 */
public class Question1118263 {
// face values of one die; the last entry is the highest face ("six")
final static int [] values = {1,2,3,4,5,6};
final static int nsides = values.length;
// upper bound on the number of dice handled (array sizes, binomial memo table)
final static int ndice = 1000;
// NOTE(review): not referenced anywhere in this file
static boolean nonHighest;
static SimpleRational [] v = new SimpleRational [ndice]; // game values
// w [m]: sum (not average) over the 5^m ordered outcomes of m dice showing no six of the best
// value when any number of them, including all, may be rerolled
static SimpleRational [] w = new SimpleRational [ndice]; // values of games preceded by free extra roll without sixes or requirement to fix, scaled by 5^n
// powers [i] [j] = i^j for 0 <= i <= nsides, 0 <= j < ndice (with 0^0 = 1)
static BigInteger [] [] powers = new BigInteger [nsides + 1] [ndice];
// number of dice currently being evaluated; shared with recurse
static int n;
public static void main (String [] args) throws Exception {
// precompute powers [i] [j] = i^j
for (int i = 0;i <= nsides;i++) {
powers [i] [0] = BigInteger.ONE;
BigInteger factor = BigInteger.valueOf (i);
for (int j = 1;j < ndice;j++)
powers [i] [j] = powers [i] [j - 1].multiply (factor);
}
SimpleRational zero = SimpleRational.zero (nsides);
Arrays.fill (v,zero);
Arrays.fill (w,zero);
BigInteger highDie = BigInteger.valueOf (values [nsides - 1]);
// calculate with explicit dice counts until w [n] = v [n]
// Phase 1: recurse accumulates unnormalised sums into v [n] and w [n - 1].
for (n = 1;n < ndice;n++) {
recurse (0,n,new int [nsides]);
// v [n] was summed over all 6^n ordered outcomes; turn the sum into an expectation
v [n] = v [n].divideByPower (n);
// recurse summed w [n - 1] over outcomes with exactly one six. Each ordered outcome of the
// other n - 1 dice occurs n times (once per position of the six), hence the division by n.
// Subtracting 6 * 5^(n-1) removes the kept six's score from each of the 5^(n-1) terms.
w [n - 1] = w [n - 1].divide (n).subtract (powers [nsides - 1] [n - 1].multiply (highDie));
// Each term of w [n - 1] is at least v [n - 1], because keeping only the six is one of the
// options. Equality therefore means that rerolling all n - 1 non-six dice is optimal
// (possibly tied) in every outcome. n > 1 is required because the condition holds
// trivially for n = 1 (w [0] = v [0] = 0).
if (n > 1 && w [n - 1].equals (v [n - 1].multiply (powers [nsides - 1] [n - 1])))
break;
}
// Discard the explicitly enumerated v [n] of the last phase-1 iteration; phase 2 recomputes it.
// NOTE(review): if the break above never fires, n == ndice here and this line throws
// ArrayIndexOutOfBoundsException.
v [n] = zero;
// now apply the recurrence relation, counting only sixes
// Phase 2 groups the outcomes by the number nlow of dice that are not sixes. It presumably
// relies on the argument in the linked question that the phase-1 break condition keeps holding.
for (;n < ndice;n++) {
// nlow < n (at least one six): every six is kept (6 each) and the nlow other dice contribute
// w [nlow], already summed over their 5^nlow outcomes; C(n, nlow) placements of the non-sixes
for (int nlow = 0;nlow < n;nlow++)
v [n] = v [n].add (w [nlow].add (powers [nsides - 1] [nlow].multiply (BigInteger.valueOf ((n - nlow) * values [nsides - 1]))).multiply (binomial (n,nlow)));
// nlow == n (no six): keep the highest die and reroll the other n - 1, worth v [n - 1] in
// each of the 5^n outcomes ...
v [n] = v [n].add (v [n - 1].multiply (powers [nsides - 1] [n]));
// ... plus the highest die itself: (k+1)^n - k^n of those outcomes have maximum values [k]
for (int k = 0;k < nsides - 1;k++)
v [n] = v [n].add (BigInteger.valueOf (values [k]).multiply (powers [k + 1] [n].subtract (powers [k] [n])));
// sum over 6^n outcomes -> expectation
v [n] = v [n].divideByPower (n);
// w [n] assumes rerolling all n non-six dice is optimal: v [n] summed over 5^n outcomes
w [n] = v [n].multiply (powers [nsides - 1] [n]);
// print v [n] - (v [n - 1] + 6): rolling all n dice afresh versus keeping one six and
// rerolling the other n - 1
System.out.println (" " + n + " : " + v [n].subtract (v [n - 1]).subtract (values [nsides - 1]).doubleValue ());
}
}
/**
 * Enumerates every distribution of the current n dice over the faces (face counts in
 * {@code counts}). For each complete distribution it adds the multinomially weighted best
 * value to v [n], which the caller divides by 6^n afterwards. If exactly one die shows the
 * highest face, it also adds that value to w [n - 1].
 *
 * @param depth  index of the face whose count is chosen at this level
 * @param left   number of dice not yet assigned to a face
 * @param counts per-face dice counts; mutated in place
 */
static void recurse (int depth,int left,int [] counts) {
if (depth == nsides - 1) {
// the highest face takes all remaining dice
counts [depth] = left;
// total of all n dice
int sum = 0;
for (int i = 0;i < counts.length;i++)
sum += counts [i] * values [i];
// cursor over the dice in ascending order of value: face index, dice left on that face,
// and that face's value
int index = 0;
int count = 0;
int value = 0;
// Before step reroll, the lowest `reroll` dice have been removed from sum. The candidate
// value thus keeps the n - reroll highest dice and rerolls the rest (worth v [reroll]).
// reroll < n keeps at least one die.
SimpleRational max = SimpleRational.zero (nsides);
for (int reroll = 0;reroll < n;reroll++) {
SimpleRational newValue = v [reroll].add (sum);
if (newValue.compareTo (max) > 0)
max = newValue;
// advance to the next face that still has dice, i.e. find the lowest remaining die
while (count == 0) {
value = values [index];
count = counts [index++];
}
count--;
sum -= value;
}
// number of ordered outcomes with these face counts: n! / (counts [0]! ... counts [nsides - 1]!)
BigInteger multinomial = BigInteger.ONE;
int upper = n;
for (int i = 0;i < counts.length;i++) {
multinomial = multinomial.multiply (binomial (upper,counts [i]));
upper -= counts [i];
}
SimpleRational term = max.multiply (multinomial);
v [n] = v [n].add (term);
// exactly one six: contributes to w [n - 1], which main rescales afterwards
if (counts [nsides - 1] == 1)
w [n - 1] = w [n - 1].add (term);
}
else
// choose how many dice show face `depth`, then distribute the rest over the higher faces
for (counts [depth] = 0;counts [depth] <= left;counts [depth]++)
recurse (depth + 1,left - counts [depth],counts);
}
// memo table for binomial; both arguments must be < ndice
static BigInteger [] [] binomials = new BigInteger [ndice] [ndice];
/**
 * Returns the binomial coefficient C(n, k), memoised in {@code binomials}; 0 when k > n.
 * Uses C(n, i + 1) = C(n, i) * (n - i) / (i + 1); multiplying before dividing keeps every
 * division exact. Assumes 0 <= k and n < ndice. The parameter n shadows the static field n.
 */
static BigInteger binomial (int n,int k) {
if (k > n)
return BigInteger.ZERO;
if (binomials [n] [k] == null) {
binomials [n] [k] = BigInteger.ONE;
for (int i = 0;i < k;i++)
binomials [n] [k] = binomials [n] [k].multiply (BigInteger.valueOf (n - i)).divide (BigInteger.valueOf (i + 1));
}
return binomials [n] [k];
}
}
/**
 * An exact rational number represented as {@code numerator / base^exponent}.
 *
 * <p>No method of this class modifies an instance; every arithmetic operation returns a new
 * object. Operations that combine two {@code SimpleRational}s ({@link #add(SimpleRational)},
 * {@link #subtract(SimpleRational)}, {@link #multiply(SimpleRational)}, {@link #compareTo})
 * require both operands to share the same base. They throw {@link IllegalArgumentException}
 * otherwise.
 *
 * <p>The representation is not normalised, so one value can have several representations
 * (for example 6/6^1 and 1/6^0). {@link #equals} and {@link #hashCode} compare values, not
 * representations, and also accept operands with different bases; both assume a non-zero base.
 */
class SimpleRational implements Comparable<SimpleRational> {
    BigInteger numerator;
    BigInteger base;
    int exponent;

    /** Creates {@code numerator / base^exponent} from primitive values. */
    public SimpleRational (long numerator,long base,int exponent) {
        this (BigInteger.valueOf (numerator),base,exponent);
    }

    /** Creates {@code numerator / base^exponent} with a primitive base. */
    public SimpleRational (BigInteger numerator,long base,int exponent) {
        this (numerator,BigInteger.valueOf (base),exponent);
    }

    /** Creates {@code numerator / base^exponent}; no validation or normalisation is done. */
    public SimpleRational (BigInteger numerator,BigInteger base,int exponent) {
        this.numerator = numerator;
        this.base = base;
        this.exponent = exponent;
    }

    /** Returns {@code -this}. */
    public SimpleRational negate () {
        return new SimpleRational (numerator.negate (),base,exponent);
    }

    /** Returns {@code this + n}. */
    public SimpleRational add (long n) {
        return add (BigInteger.valueOf (n));
    }

    /** Returns {@code this - n}. */
    public SimpleRational subtract (long n) {
        return subtract (BigInteger.valueOf (n));
    }

    /** Returns {@code this - n}. */
    public SimpleRational subtract (BigInteger n) {
        return add (n.negate ());
    }

    /**
     * Returns {@code this + n}, keeping this exponent.
     *
     * @throws ArithmeticException if the exponent is negative
     */
    public SimpleRational add (BigInteger n) {
        // n == n * base^exponent / base^exponent
        return new SimpleRational (numerator.add (n.multiply (base.pow (exponent))),base,exponent);
    }

    /**
     * Returns {@code this + r}, expressed with the larger of the two exponents.
     *
     * @throws IllegalArgumentException if the bases differ
     */
    public SimpleRational add (SimpleRational r) {
        int d = exponent - r.exponent;
        if (d < 0)
            return r.add (this); // let the operand with the larger exponent do the rescaling
        requireSameBase (r);
        // rescale r to this exponent: r.numerator * base^d / base^exponent
        return new SimpleRational (numerator.add (r.numerator.multiply (base.pow (d))),base,exponent);
    }

    /**
     * Returns {@code this - r}.
     *
     * @throws IllegalArgumentException if the bases differ
     */
    public SimpleRational subtract (SimpleRational r) {
        return add (r.negate ());
    }

    /** Returns {@code this * n}. */
    public SimpleRational multiply (long n) {
        return multiply (BigInteger.valueOf (n));
    }

    /** Returns {@code this * n}. */
    public SimpleRational multiply (BigInteger n) {
        return new SimpleRational (numerator.multiply (n),base,exponent);
    }

    /**
     * Returns {@code this * r}; the exponents add up.
     *
     * @throws IllegalArgumentException if the bases differ
     */
    public SimpleRational multiply (SimpleRational r) {
        requireSameBase (r);
        return new SimpleRational (numerator.multiply (r.numerator),base,exponent + r.exponent);
    }

    /**
     * Returns {@code this / n}, dividing the numerator exactly.
     *
     * @throws IllegalArgumentException if the numerator is not a multiple of {@code n}
     * @throws ArithmeticException if {@code n} is zero
     */
    public SimpleRational divide (long n) {
        return divide (BigInteger.valueOf (n));
    }

    /**
     * Returns {@code this / n}, dividing the numerator exactly.
     *
     * @throws IllegalArgumentException if the numerator is not a multiple of {@code n}
     * @throws ArithmeticException if {@code n} is zero
     */
    public SimpleRational divide (BigInteger n) {
        BigInteger [] quotientAndRemainder = numerator.divideAndRemainder (n);
        if (quotientAndRemainder [1].signum () != 0)
            throw new IllegalArgumentException ("numerator is not divisible by " + n);
        return new SimpleRational (quotientAndRemainder [0],base,exponent);
    }

    /** Returns {@code this / base^power}. */
    public SimpleRational divideByPower (int power) {
        return new SimpleRational (numerator,base,exponent + power);
    }

    /**
     * Returns the natural logarithm of this value: -Infinity for zero, NaN for negative values.
     * Only the top 64 bits of the numerator are converted to double, and base^exponent is never
     * materialised, so arbitrarily large representations stay finite.
     */
    public double logarithm () {
        int shift = Math.max (0,numerator.bitLength () - 64);
        return Math.log (numerator.shiftRight (shift).doubleValue ()) + shift * Math.log (2) - exponent * Math.log (base.doubleValue ());
    }

    /**
     * Returns a double approximation of this value, computed through {@link #logarithm} (the
     * exp/log round trip may cost a few ulps of precision). Zero maps to 0.0.
     */
    public double doubleValue () {
        return numerator.signum () < 0 ? -negate ().doubleValue () : Math.exp (logarithm ());
    }

    /** Returns zero represented as {@code 0 / base^0}. */
    public static SimpleRational zero (long base) {
        return new SimpleRational (0,base,0);
    }

    /**
     * Compares the values of this and {@code other}.
     *
     * @return a negative number, zero or a positive number as this is less than, equal to or
     *         greater than {@code other}
     * @throws IllegalArgumentException if the bases differ
     */
    @Override
    public int compareTo (SimpleRational other) {
        // for a positive base the sign of the difference is the sign of its numerator
        return subtract (other).numerator.signum ();
    }

    /**
     * Returns whether {@code o} is a {@code SimpleRational} with the same value. Returns false
     * (instead of throwing) for {@code null}, other types, and values that differ; operands
     * with different bases are compared as fractions in lowest terms.
     */
    @Override
    public boolean equals (Object o) {
        if (this == o)
            return true;
        if (!(o instanceof SimpleRational))
            return false;
        SimpleRational r = (SimpleRational) o;
        if (base.equals (r.base))
            return compareTo (r) == 0;
        BigInteger [] mine = reduced ();
        BigInteger [] theirs = r.reduced ();
        return mine [0].equals (theirs [0]) && mine [1].equals (theirs [1]);
    }

    /** Returns a hash of the value (not the representation), consistent with {@link #equals}. */
    @Override
    public int hashCode () {
        BigInteger [] r = reduced ();
        return 31 * r [0].hashCode () + r [1].hashCode ();
    }

    /** Returns {numerator, denominator} of this value in lowest terms, denominator >= 0. */
    private BigInteger [] reduced () {
        BigInteger num = numerator;
        BigInteger den = BigInteger.ONE;
        if (exponent >= 0)
            den = base.pow (exponent);
        else
            num = num.multiply (base.pow (-exponent));
        if (den.signum () < 0) {
            num = num.negate ();
            den = den.negate ();
        }
        BigInteger g = num.gcd (den);
        if (g.signum () != 0) {
            num = num.divide (g);
            den = den.divide (g);
        }
        return new BigInteger [] {num,den};
    }

    /** Throws IllegalArgumentException unless {@code r} uses the same base as this. */
    private void requireSameBase (SimpleRational r) {
        if (!r.base.equals (base))
            throw new IllegalArgumentException ("bases differ: " + base + " vs " + r.base);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment