Skip to content

Instantly share code, notes, and snippets.

@fredeil
Created April 20, 2018 12:07
Show Gist options
  • Save fredeil/7421ae2f28452558736e54e35b1e7e30 to your computer and use it in GitHub Desktop.
Save fredeil/7421ae2f28452558736e54e35b1e7e30 to your computer and use it in GitHub Desktop.
namespace ConsoleTester
{
/// <summary>
/// Console demo of Thompson sampling on a simulated multi-armed bandit.
/// Each machine pays out with a hidden probability taken from <c>means</c>; on every
/// trial one value is drawn from each machine's Beta(wins + 1, losses + 1) distribution
/// and the machine with the largest draw is played.
/// </summary>
class Program
{
    static void Main(string[] args)
    {
        // Hidden win probability of each machine. The machine count is derived from this
        // array so the two can never disagree.
        double[] means = new double[] { 0.3, 0.5, 0.7 };
        int n_machines = means.Length;
        double[] probs = new double[n_machines];
        int[] S = new int[n_machines]; // wins observed per machine
        int[] F = new int[n_machines]; // losses observed per machine
        Random rnd = new Random(4);          // simulates the payouts
        BetaSampler bs = new BetaSampler(2); // draws from the per-machine Beta distributions

        for (int trial = 0; trial < 10; ++trial)
        {
            Console.WriteLine("Trial " + trial);
            for (int i = 0; i < n_machines; ++i)
            {
                // +1 on both counts: Beta(1, 1) (uniform) before any observations.
                probs[i] = bs.Sample(S[i] + 1.0, F[i] + 1.0);
            }

            Console.Write("Sampling probs: ");
            for (int i = 0; i < n_machines; ++i)
            {
                Console.Write(probs[i].ToString("F4") + " ");
            }
            Console.WriteLine();

            int machine = ArgMax(probs);
            Console.Write("Playing machine " + machine);
            double p = rnd.NextDouble();
            if (p < means[machine])
            {
                Console.WriteLine(" -- win");
                ++S[machine];
            }
            else
            {
                Console.WriteLine(" -- lose");
                ++F[machine];
            }
        }

        Console.WriteLine("Final estimates of means: ");
        for (int i = 0; i < n_machines; ++i)
        {
            int played = S[i] + F[i];
            // A machine that was never selected has no observations; computing
            // S / (S + F) would be 0.0 / 0 and print "NaN".
            if (played == 0)
            {
                Console.WriteLine("n/a ");
            }
            else
            {
                double u = (double)S[i] / played;
                Console.WriteLine(u.ToString("F4") + " ");
            }
        }

        Console.WriteLine("Number times machine played:");
        for (int i = 0; i < n_machines; ++i)
        {
            int ct = S[i] + F[i];
            Console.WriteLine(ct + " ");
        }
        Console.WriteLine("End demo ");
        Console.ReadLine();
    }

    /// <summary>
    /// Returns the index of the first largest element of <paramref name="values"/>
    /// (0 when the array has a single element). Assumes the array is non-empty.
    /// </summary>
    private static int ArgMax(double[] values)
    {
        int best = 0;
        for (int i = 1; i < values.Length; ++i)
        {
            // Strict '>' keeps the earliest index on ties.
            if (values[i] > values[best])
            {
                best = i;
            }
        }
        return best;
    }
}
/// <summary>
/// Draws random variates from a Beta(a, b) distribution using Cheng's (1978)
/// rejection algorithm BA, driven by a seeded <see cref="Random"/>.
/// </summary>
public class BetaSampler
{
    /// <summary>Uniform source consumed by <see cref="Sample"/>; created from the constructor seed.</summary>
    public Random rnd;

    // ln(4): the constant term of the acceptance test in algorithm BA.
    private static readonly double Log4 = Math.Log(4.0);

    /// <summary>Creates a sampler whose uniform stream is seeded with <paramref name="seed"/>.</summary>
    /// <param name="seed">Seed for the underlying <see cref="Random"/>.</param>
    public BetaSampler(int seed)
    {
        rnd = new Random(seed);
    }

    /// <summary>Draws one value from Beta(<paramref name="a"/>, <paramref name="b"/>).</summary>
    /// <param name="a">First shape parameter; must be finite and &gt; 0.</param>
    /// <param name="b">Second shape parameter; must be finite and &gt; 0.</param>
    /// <returns>A sample in [0, 1].</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// <paramref name="a"/> or <paramref name="b"/> is non-positive, NaN or infinite.
    /// </exception>
    public double Sample(double a, double b)
    {
        // Without this check, a = 0 (or b = 0) makes the acceptance test NaN on every
        // iteration, so the rejection loop below would never terminate.
        if (!(a > 0.0) || double.IsInfinity(a))
        {
            throw new ArgumentOutOfRangeException(nameof(a), a, "Shape parameter must be a finite positive number.");
        }
        if (!(b > 0.0) || double.IsInfinity(b))
        {
            throw new ArgumentOutOfRangeException(nameof(b), b, "Shape parameter must be a finite positive number.");
        }

        double alpha = a + b;
        double smaller = Math.Min(a, b);
        double beta = smaller <= 1.0
            ? 1.0 / smaller
            : Math.Sqrt((alpha - 2.0) / (2.0 * a * b - alpha));
        double gamma = a + 1.0 / beta;

        while (true)
        {
            double u1 = rnd.NextDouble();
            double u2 = rnd.NextDouble();
            // NextDouble can return exactly 0; log(0) = -inf would make the test below
            // accept via a degenerate "-inf >= -inf" comparison, so redraw instead.
            if (u1 == 0.0)
            {
                continue;
            }

            double v = beta * Math.Log(u1 / (1.0 - u1));
            double w = a * Math.Exp(v);
            // If w overflows to +inf, the log term is -inf and the candidate is rejected.
            double lhs = alpha * Math.Log(alpha / (b + w)) + gamma * v - Log4;
            if (lhs >= Math.Log(u1 * u1 * u2))
            {
                return w / (b + w);
            }
        }
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment