Skip to content

Instantly share code, notes, and snippets.

@klgraham
Created March 25, 2016 19:10
Show Gist options
  • Save klgraham/c1bc8fb6accb97e5aa6f to your computer and use it in GitHub Desktop.
Save klgraham/c1bc8fb6accb97e5aa6f to your computer and use it in GitHub Desktop.
Probability monad in Java 8
/**
* How to use Java 8 collections to create the probability monad.
*
* Created by klgraham on 10/25/15.
*/
import java.util.*;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
public class Distributions
{
    /**
     * Demo entry point: samples from each distribution defined in this class
     * and prints the samples, or Monte Carlo estimates derived from them, to
     * standard output.
     *
     * @param args ignored
     */
    public static void main(String[] args)
    {
        System.out.println("Sampling 10 doubles");
        uniform.sample(10).forEach(System.out::println);

        System.out.println("\nSampling 10 mapped doubles");
        uniform.map(p -> p * 2).sample(10).forEach(System.out::println);

        System.out.println("\nSampling from uniform boolean distribution");
        tf(0.5).sample(10).forEach(System.out::println);

        System.out.println("\nSampling from uniform bernoulli distribution, with p = 0.7");
        bernoulli(0.7).sample(10).forEach(System.out::println);

        System.out.println("\nProbability of a uniform double being less than 0.5:");
        System.out.println(uniform.prob(u -> u < 0.5));

        System.out.println("\nSample 10 Uniform variables above 0.3:");
        uniform.given(u -> u > 0.3).sample(10).forEach(System.out::println);

        System.out.println("\nGet 3 lists of 3 uniform variables");
        uniform.repeat(3).sample(3).forEach(System.out::println);

        System.out.println("\n6-sided die");
        Distribution<Integer> die6 = discreteUniform(Arrays.asList(1, 2, 3, 4, 5, 6));
        System.out.println(die6.sample(10));
        System.out.println("Prob(3) = " + die6.prob(p -> p == 3));

        System.out.println("\nPair of 6-sided dice");
        Distribution<Integer> dice = die6.repeat(2).map(p -> p.get(0) + p.get(1));
        System.out.println("Prob(7) = " + dice.prob(p -> p == 7));
        System.out.println("Prob(11) = " + dice.prob(p -> p == 11));
        System.out.println("Prob(4) = " + dice.prob(p -> p == 4));

        System.out.println("\nPair of 6-sided dice, via flatmap");
        Distribution<Integer> dice1 = die6.flatMap(d1 -> die6.map(d2 -> d1 + d2));
        System.out.println("Prob(7) = " + dice1.prob(p -> p == 7));
        System.out.println("Prob(11) = " + dice1.prob(p -> p == 11));
        System.out.println("Prob(4) = " + dice1.prob(p -> p == 4));

        System.out.println("\nMonty Hall problem");
        // Both tuple components are boxed Integers: compare by value with equals(),
        // because == on boxed values compares object identity.
        System.out.println("Prob. that switching doors finds the prize: " +
                montyHall().prob(pair -> pair._1.equals(pair._2)));

        System.out.println("\nNormal Distribution");
        System.out.println("Mean: " + normal().mean());
        System.out.println("StdDev: " + normal().stdDev());
    }

    /**
     * Uniform distribution over the half-open interval [0, 1):
     * each draw is the result of {@link Random#nextDouble()}.
     */
    static Distribution<Double> uniform = new Distribution<Double>() {
        private final Random r = new Random();

        @Override
        Double get()
        {
            return r.nextDouble();
        }
    };

    /**
     * Boolean distribution.
     *
     * @param p probability of true
     * @return distribution whose draws are {@code true} when a uniform draw is below {@code p}
     */
    static Distribution<Boolean> tf(double p) {
        return uniform.map(n -> n < p);
    }

    /**
     * Bernoulli distribution.
     * 1 is success or a hit and 0 is failure or a miss.
     *
     * @param p probability of 1
     * @return distribution of 1s and 0s
     */
    static Distribution<Integer> bernoulli(double p) {
        return tf(p).map(b -> b ? 1 : 0);
    }

    /**
     * Normal distribution backed by {@link Random#nextGaussian()}.
     * Every call creates a new distribution with its own {@link Random}.
     *
     * @return distribution of Gaussian-distributed doubles
     */
    static Distribution<Double> normal() {
        return new Distribution<Double>() {
            private final Random r = new Random();

            @Override
            Double get()
            {
                return r.nextGaussian();
            }
        };
    }

    /**
     * Discrete uniform distribution: every element of {@code values} is
     * equally likely to be drawn. The collection is copied, so later changes
     * to it do not affect the returned distribution.
     *
     * @param values random values the distribution can take; must not be empty
     * @param <A> type of the values
     * @return distribution that picks one of {@code values} uniformly at random
     * @throws IllegalArgumentException if {@code values} is empty
     */
    static <A> Distribution<A> discreteUniform(Collection<A> values) {
        List<A> vec = new ArrayList<>(values);
        if (vec.isEmpty()) {
            // Fail here instead of with an IndexOutOfBoundsException on the first draw.
            throw new IllegalArgumentException("values must not be empty");
        }
        int size = vec.size();
        // A uniform draw x lies in [0, 1), so (int) (x * size) is an index in [0, size - 1].
        return uniform.map(x -> vec.get((int) (x * size)));
    }

    /**
     * Uniform distribution over the doors that remain after removing
     * {@code p} and {@code c} from a copy of {@code doors}; the given set is
     * not modified. In {@link #montyHall()} this is called once with the prize
     * and the first choice, and once with the first choice and the opened door.
     *
     * @param doors all doors
     * @param p first door to exclude
     * @param c second door to exclude
     * @return uniform distribution over the remaining doors
     * @throws IllegalArgumentException if no door remains
     */
    static Distribution<Integer> removePriceAndChoice(Set<Integer> doors, int p, int c) {
        Set<Integer> remaining = new HashSet<>(doors);
        remaining.remove(p);
        remaining.remove(c);
        return discreteUniform(remaining);
    }

    /**
     * Models the Monty Hall game with three doors where the contestant always
     * switches.
     *
     * @return distribution of (prize door, door held after switching) pairs
     */
    static Distribution<Tuple<Integer, Integer>> montyHall()
    {
        Set<Integer> doors = new HashSet<>(Arrays.asList(1, 2, 3));
        Distribution<Integer> prize = discreteUniform(doors);
        Distribution<Integer> choice = discreteUniform(doors);
        Distribution<Tuple<Integer, Integer>> mh =
                prize.flatMap(p -> choice. // random prize location
                        flatMap(c -> removePriceAndChoice(doors, p, c). // random choice
                                flatMap(o -> removePriceAndChoice(doors, c, o). // open one of other doors
                                        map(s -> new Tuple<Integer, Integer>(p, s))))); // switch
        return mh;
    }
}
/**
 * Probability distribution over values of type {@code A}, defined by how a
 * single value is drawn ({@link #get()}). All other operations are built on
 * top of repeated draws.
 *
 * @param <A> type of the random variable
 */
abstract class Distribution<A>
{
    /** Number of draws used by the Monte Carlo estimates (prob, mean, variance). */
    private static final int SAMPLE_SIZE = 10000;

    /**
     * Draws one random value of type A.
     *
     * @return the drawn value
     */
    abstract A get();

    /**
     * Generates a list of n random values of type A, in draw order.
     *
     * @param n number of random values; must not be negative
     * @return list of {@code n} independent draws
     * @throws IllegalArgumentException if {@code n} is negative
     */
    List<A> sample(Integer n)
    {
        int count = n;
        if (count < 0)
        {
            throw new IllegalArgumentException("n must not be negative: " + count);
        }
        List<A> draws = new ArrayList<>(count);
        for (int i = 0; i < count; i++)
        {
            draws.add(get());
        }
        return draws;
    }

    /**
     * Maps one Distribution into another by applying {@code f} to every draw.
     *
     * @param f mapping function
     * @param <B> type of variables in output distribution
     * @return distribution of {@code f(x)} where {@code x} is drawn from this distribution
     */
    <B> Distribution<B> map(Function<A, B> f)
    {
        Distribution<A> source = this;
        return new Distribution<B>()
        {
            @Override
            B get()
            {
                return f.apply(source.get());
            }
        };
    }

    /**
     * FlatMaps one Distribution into another: draws a value from this
     * distribution, uses {@code f} to pick a second distribution, and draws
     * from that one.
     *
     * @param f function mapping one value to a distribution
     * @param <B> type of variables in output distribution
     * @return the composed distribution
     */
    <B> Distribution<B> flatMap(Function<A, Distribution<B>> f)
    {
        Distribution<A> source = this;
        return new Distribution<B>()
        {
            @Override
            B get()
            {
                return f.apply(source.get()).get();
            }
        };
    }

    /**
     * Estimates the probability of the predicate being true as the fraction
     * of {@code SAMPLE_SIZE} draws that satisfy it.
     *
     * @param predicate event to test each draw against
     * @return estimated probability in [0, 1]
     */
    double prob(Predicate<A> predicate)
    {
        long hits = sample(SAMPLE_SIZE).stream().filter(predicate).count();
        return (double) hits / (double) SAMPLE_SIZE;
    }

    /**
     * Conditions this distribution on the predicate using rejection sampling:
     * every draw of the returned distribution keeps drawing from this
     * distribution until a value satisfies the predicate.
     *
     * Note that a draw never terminates if no value of this distribution
     * satisfies the predicate.
     *
     * @param predicate condition every returned value must satisfy
     * @return distribution restricted to values matching {@code predicate}
     */
    Distribution<A> given(Predicate<A> predicate)
    {
        Distribution<A> source = this;
        return new Distribution<A>() {
            @Override
            A get() {
                // Draw a fresh candidate on every call and reject it until it matches.
                A candidate = source.get();
                while (!predicate.test(candidate)) {
                    candidate = source.get();
                }
                return candidate;
            }
        };
    }

    /**
     * Creates a distribution of lists of samples of length n.
     *
     * @param n length of each list
     * @return distribution whose draws are lists of {@code n} draws from this distribution
     */
    Distribution<List<A>> repeat(int n)
    {
        Distribution<A> source = this;
        return new Distribution<List<A>>() {
            @Override
            List<A> get()
            {
                return source.sample(n);
            }
        };
    }

    /**
     * Estimates the mean from {@code SAMPLE_SIZE} draws.
     *
     * @return sample mean
     * @throws NumberFormatException if a drawn value is neither a {@link Number}
     *         nor has a string form parseable as a double
     */
    double mean()
    {
        double sum = 0;
        for (A v : sample(SAMPLE_SIZE))
        {
            sum += toDouble(v);
        }
        return sum / (double) SAMPLE_SIZE;
    }

    /**
     * Estimates the variance (unbiased, divisor n - 1) from
     * {@code SAMPLE_SIZE} draws.
     *
     * @return sample variance; never negative
     * @throws NumberFormatException if a drawn value is neither a {@link Number}
     *         nor has a string form parseable as a double
     */
    double variance()
    {
        // Welford's running update: unlike sqrSum - sum^2 / n it does not subtract
        // two large, nearly equal numbers, so the result cannot become negative
        // through rounding (which would make stdDev() return NaN).
        double runningMean = 0;
        double squaredDeviations = 0;
        int count = 0;
        for (A v : sample(SAMPLE_SIZE))
        {
            double x = toDouble(v);
            count++;
            double delta = x - runningMean;
            runningMean += delta / count;
            squaredDeviations += delta * (x - runningMean);
        }
        return squaredDeviations / (double) (count - 1);
    }

    /**
     * Estimates the standard deviation as the square root of {@link #variance()}.
     *
     * @return sample standard deviation
     */
    double stdDev()
    {
        return Math.sqrt(variance());
    }

    /**
     * Converts a drawn value to a double: numbers are converted directly,
     * anything else is parsed from its string form.
     */
    private static double toDouble(Object value)
    {
        if (value instanceof Number)
        {
            return ((Number) value).doubleValue();
        }
        return Double.parseDouble(value.toString());
    }
}
class Tuple<T, U>
{
public final T _1;
public final U _2;
public Tuple(T arg1, U arg2) {
super();
this._1 = arg1;
this._2 = arg2;
}
@Override
public String toString() {
return String.format("(%s, %s)", _1, _2);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Tuple<?, ?> tuple = (Tuple<?, ?>) o;
if (!_1.equals(tuple._1)) return false;
return _2.equals(tuple._2);
}
@Override
public int hashCode() {
int result = _1.hashCode();
result = 31 * result + _2.hashCode();
return result;
}
}
//class Histogram
//{
// public static <T> Map<T, Long> frequencies(Stream<T> stream)
// {
// return stream.
// collect(Collectors.groupingBy(Function.identity(), Collectors.counting()));
// }
//
// public static <T> Map<T, Double> histogram(Stream<T> stream)
// {
// int N = (int)stream.count();
// Map<T, Double> hist = new HashMap<>();
// Map<T, Long> freqs = frequencies(stream);
//
// for (Map.Entry<T, Long> entry : freqs.entrySet())
// {
// hist.put(entry.getKey(), (double)entry.getValue() / (double)N);
// }
//
// return hist;
// }
//}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment