Last active
October 15, 2017 23:19
-
-
Save unixpickle/2150b7bfaf531604640892b77319fea9 to your computer and use it in GitHub Desktop.
Probability contest
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
An implementation of "probability contests".

In the most basic probability contest, you have two random
variables X and Y and want to know P(X > Y), i.e. the
probability that X "wins".

We would like exactly one of X or Y to win every contest,
even when X = Y. Thus, we define the win probability as:

    P(X wins) = P(X > Y) + 0.5*P(X = Y)

In other words, ties are broken uniformly at random.

To define a probability contest over more than two random
variables, we say that X wins if it is greater than every
other random variable. If several variables tie for the
maximum, the winner is chosen uniformly at random among them.
"""
from collections import Counter | |
import random | |
import numpy as np | |
def sample_winner(*dist_samples):
    """
    Sample a winner in a probability contest.

    Each distribution contributes one value drawn uniformly from
    its samples. The largest value wins, and ties for the maximum
    are broken uniformly at random.

    All randomness comes from numpy's global RNG, so a single call
    to np.random.seed() makes the result reproducible.

    Arguments:
      dist_samples: for each distribution, a set of random
        samples from that distribution (any indexable sequence).
        The distribution is defined by its samples, so the
        more samples you provide, the more accurate the
        results.

    Returns:
      The index (an int) of the random variable that won.
    """
    assert len(dist_samples) > 1
    # Draw with np.random rather than the stdlib `random` module so
    # that seeding numpy controls both the draws and tie-breaking.
    values = np.array([x[np.random.randint(len(x))] for x in dist_samples])
    winners = np.flatnonzero(values == np.amax(values))
    if len(winners) == 1:
        return int(winners[0])
    return int(np.random.choice(winners))
def win_probabilities(*dist_samples):
    """
    Compute the exact win probability of every distribution in
    a probability contest.

    Arguments:
      dist_samples: for each distribution, a set of random
        samples from that distribution.
        The distribution is defined by its samples, so the
        more samples you provide, the more accurate the
        results.

    Returns:
      A tuple with one probability per distribution, in argument
      order. The probabilities sum to 1 (up to floating-point
      rounding).
    """
    assert len(dist_samples) > 1
    dists = [_SampleDistribution(samples) for samples in dist_samples]
    results = []
    for idx, contender in enumerate(dists):
        # Everyone except the current contender is a rival.
        rivals = [d for j, d in enumerate(dists) if j != idx]
        results.append(_single_win_prob(contender, rivals))
    return tuple(results)
def _single_win_prob(dist, other_dists):
    """
    Compute the probability that `dist` beats every distribution
    in `other_dists`, with ties split evenly.

    All distributions are _SampleDistributions.
    """
    total = 0
    # dist.sorted is ascending, which matches the increasing-argument
    # access pattern that prob_less_equal() is optimized for.
    for value, weight in zip(dist.sorted, dist.probs):
        # One (less, equal) pair per opponent, transposed into two
        # aligned tuples.
        less_probs, equal_probs = zip(
            *(other.prob_less_equal(value) for other in other_dists))
        total += weight * _value_win_prob(less_probs, equal_probs)
    return total
def _value_win_prob(less_probs, equal_probs, cur_prob=1, num_equal=1): | |
""" | |
Compute a probability of winning, given the | |
comparisons to some other distributions. | |
""" | |
if not less_probs: | |
return cur_prob * (1 / num_equal) | |
elif cur_prob == 0: | |
return 0.0 | |
less_branch = _value_win_prob(less_probs[1:], equal_probs[1:], | |
cur_prob=cur_prob*less_probs[0], | |
num_equal=num_equal) | |
equal_branch = _value_win_prob(less_probs[1:], equal_probs[1:], | |
cur_prob=cur_prob*equal_probs[0], | |
num_equal=num_equal+1) | |
return less_branch + equal_branch | |
# pylint: disable=R0903 | |
class _SampleDistribution: | |
""" | |
A probability distribution over real numbers. | |
""" | |
def __init__(self, samples): | |
# pylint: disable=C1801 | |
assert len(samples) > 0 | |
self.sorted = np.array(sorted(set(samples))) | |
scale = 1 / len(samples) | |
count = Counter(samples) | |
self.probs = np.array([count[x]*scale for x in self.sorted]) | |
self._cumulative = np.cumsum(self.probs) | |
# Used to optimize prob_less(). | |
self._cur_offset = 0 | |
def prob_less_equal(self, value): | |
""" | |
Compute two probabilities: the probability that we | |
sample a value less than the argument, and the | |
probability that we sample exactly the argument. | |
Returns: | |
A tuple (less_prob, equal_prob). | |
This is optimized to be called repeatedly with | |
successively larger values. | |
""" | |
if self.sorted[-1] < value: | |
return 1.0, 0.0 | |
elif self.sorted[0] == value: | |
return 0.0, self.probs[0] | |
elif self.sorted[0] > value: | |
return 0.0, 0.0 | |
elif self.sorted[self._cur_offset] >= value: | |
# Arguments weren't monotonically increasing. | |
self._cur_offset = 0 | |
while self.sorted[self._cur_offset+1] < value: | |
self._cur_offset += 1 | |
less_prob = self._cumulative[self._cur_offset] | |
equal_prob = 0.0 | |
if self.sorted[self._cur_offset+1] == value: | |
equal_prob = self.probs[self._cur_offset+1] | |
return less_prob, equal_prob |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Tests for probability contests.
"""
from collections import Counter | |
import unittest | |
import numpy as np | |
from contest import sample_winner, win_probabilities | |
class WinProbabilitiesTest(unittest.TestCase):
    """
    Tests for win_probabilities().
    """

    def test_known_cases(self):
        """
        Test cases where the answers are easy to intuit.
        """
        cases = [
            (([1, 2, 3], [4, 5, 6]), (0.0, 1.0)),
            (([1, 2, 5], [4, 6, 7]), (1 / 9, 8 / 9)),
            (([1, 2, 5], [5, 6, 7]), (1 / 18, 17 / 18)),
        ]
        for case, answer in cases:
            with self.subTest(case=case):
                actual = win_probabilities(*case)
                np.testing.assert_allclose(answer, actual,
                                           rtol=1e-5, atol=1e-8)

    def test_sample_equiv(self):
        """
        Test that the probabilities from win_probabilities
        align with those from sample_winner.
        """
        import random

        # The test is non-deterministic, so seed with values that
        # pass (most do). sample_winner() may draw from both numpy's
        # RNG and the stdlib `random` module, so seed both; seeding
        # only numpy would leave the run unreproducible.
        random.seed(1)
        np.random.seed(1)
        for test_case in _test_cases():
            with self.subTest(test_case=test_case):
                actual = np.array(win_probabilities(*test_case))
                sampled = np.array(_approx_probabilities(test_case))
                np.testing.assert_allclose(actual, sampled,
                                           rtol=1e-2, atol=1e-2)

    def test_self_comparisons(self):
        """
        Test that comparisons between a distribution and
        itself always yield even results.
        """
        for samples in [x for y in _test_cases() for x in y]:
            for num_repeats in range(2, 5):
                with self.subTest(samples=samples, num_repeats=num_repeats):
                    probs = win_probabilities(*([samples] * num_repeats))
                    expected = [1.0 / num_repeats] * num_repeats
                    np.testing.assert_allclose(probs, expected,
                                               rtol=1e-5, atol=1e-8)

    def test_sum_1(self):
        """
        Test that win probabilities always sum up to 1.
        """
        for test_case in _test_cases():
            with self.subTest(test_case=test_case):
                total = sum(win_probabilities(*test_case))
                np.testing.assert_allclose(total, 1.0, rtol=1e-5, atol=1e-8)
def test_win_probabilities_speed(benchmark):
    """
    Benchmark win_probabilities() on a realistic workload: three
    distributions of 100 integer samples drawn from [0, 200).

    `benchmark` is presumably the pytest-benchmark fixture; it is
    called with a zero-argument callable to time.
    """
    dists = [np.random.randint(0, 200, size=(100,)) for _ in range(3)]

    def run_contest():
        return win_probabilities(*dists)

    benchmark(run_contest)
def _test_cases(): | |
""" | |
Generate tuples of distribution samples for testing. | |
""" | |
return [ | |
([1, 2, 3, 4, 5], [1, 2, 3, 4, 5]), | |
([1, 2, 3, 4, 5], [1, 2, 3, 4, 5, 6]), | |
([2, 3, 4, 5], [1, 2, 3, 4, 5, 6]), | |
([2, 3, 4, 5, 6], [1, 2, 3, 4, 5]), | |
([2, 3, 4, 5, 6], [2, 3, 4, 5]), | |
([1, 2, 2, 3, 4], [1, 3, 1, 1, 1]), | |
([1, 1, 1], [2, 1, 2, 1], [3, 1, 2, 3, 2]), | |
(np.random.randint(0, 20, size=(20,)), | |
np.random.randint(1, 22, size=(15,)), | |
np.random.randint(15, 30, size=(32,))), | |
(np.random.normal(size=(15,)), | |
np.random.normal(size=(30,)), | |
np.random.normal(size=(5,)), | |
np.random.normal(size=(22,))) | |
] | |
def _approx_probabilities(dists, num_samples=20000):
    """
    Estimate win probabilities by Monte Carlo.

    Runs `num_samples` contests with sample_winner() and returns a
    tuple holding, for each distribution in `dists`, the fraction of
    contests it won.
    """
    wins = Counter(sample_winner(*dists) for _ in range(num_samples))
    return tuple(wins[idx] / num_samples for idx, _ in enumerate(dists))
if __name__ == '__main__':
    # Allow running this suite directly as a script.
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment