Skip to content

Instantly share code, notes, and snippets.

@matthewfeickert
Last active November 4, 2017 20:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matthewfeickert/8c6bcf26462b6e06983aee08597da2db to your computer and use it in GitHub Desktop.
Save matthewfeickert/8c6bcf26462b6e06983aee08597da2db to your computer and use it in GitHub Desktop.
Quick Python demonstration of Uniform sampling problem (https://twitter.com/fermatslibrary/status/924263998589145090)
#!/usr/bin/env python
"""
Problem Statement: Sample from the Uniform distribution over the range [0,1]
until the sum of the numbers sampled is greater than 1. On average, how
many samples are taken?
Answer: e
Problem Source: https://twitter.com/fermatslibrary/status/924263998589145090
Author: Matthew Feickert
Date: 2016-10-28
"""
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import itertools # for fast looping
def simulate_sampling(n_trials, threshold=1, rng=None):
    """
    Simulate experiments

    Each trial repeatedly draws a number with `uniform(0, 1)` and adds it to
    a running sum until that sum is strictly greater than `threshold`; the
    number of draws the trial needed is recorded.

    Args:
        n_trials: `int` The number of trials
        threshold: `float` The value the running sum has to exceed before a
            trial stops. Defaults to 1, the value in the problem statement.
        rng: Optional random source exposing `uniform(low, high)`, e.g.
            `np.random.RandomState(seed)`, for reproducible runs. Defaults
            to the global `np.random` state.

    Returns:
        sum(counts): `int` The total number of samples drawn in all trials
        counts: `list of ints` The number of samples drawn in each trial

    Raises:
        ValueError: If `threshold` is not finite (an infinite threshold
            could never be exceeded, so the trial would never terminate).
    """
    if not np.isfinite(threshold):
        raise ValueError(
            'threshold must be finite, got {}'.format(threshold))

    if rng is None:
        # Fall back to the module-level generator so callers that seed via
        # np.random.seed() keep getting the behaviour they had before.
        rng = np.random
    draw = rng.uniform  # hoist the attribute lookup out of the hot loop

    counts = []
    for _ in itertools.repeat(None, n_trials):
        n_samples = 0
        sum_ = 0
        # "<=" because the trial only ends once the sum is *greater* than
        # the threshold, not when it merely reaches it.
        while sum_ <= threshold:
            sum_ += draw(0, 1)
            n_samples += 1
        counts.append(n_samples)
    return sum(counts), counts
def main(n_trials=100000):
    """
    Run the simulation, report how close the mean is to e, and plot it.

    Prints the average number of samples per trial together with its
    absolute and relative difference from e, then saves a plot of the
    fraction of trials that needed 2..9 samples to 'sample_uniform.png'.

    Args:
        n_trials: `int` The number of trials to simulate (must be >= 1)

    Raises:
        ValueError: If `n_trials` is smaller than 1, since no average can
            be computed from zero trials.
    """
    if n_trials < 1:
        raise ValueError(
            'n_trials must be a positive integer, got {}'.format(n_trials))

    n_samples, counts = simulate_sampling(n_trials)
    # float() keeps this a true division even when run under Python 2.
    result = float(n_samples) / n_trials
    print('Average number of samples taken over {} trials: {}'.format(
        n_trials, result))
    print('Difference between result and e (~{:.6f}): {:.6f}'.format(
        np.exp(1),
        np.absolute(result - np.exp(1))))
    print('Relative difference between result and e: {:.6f}'.format(
        np.absolute(1 - (result / np.exp(1)))))

    # Plot results
    x = list(range(2, 10))
    # Tally every count in a single pass; minlength guarantees that the
    # 2..9 slice exists even if no trial needed that many samples.
    tallies = np.bincount(np.asarray(counts, dtype=int), minlength=10)
    # Normalise by the number of trials so the values really are the
    # relative counts that the y-axis label announces.
    relative_counts = tallies[2:10] / float(n_trials)
    # Marker-only format string: the colour is given once, via keywords,
    # instead of being defined by both the format string and `color`.
    plt.plot(x, relative_counts, 'o',
             color='black', markerfacecolor='blue')
    line_result = plt.axvline(
        x=result, color='black',
        label='mean number of samples: {}'.format(result))
    plt.xlabel('Number of samples')
    plt.ylabel('Relative count')

    # Legend
    handles = [line_result]
    # Invisible patch: only there to get the value of e into the legend.
    handles.append(mpatches.Patch(
        color='none', label='e ~ {:.6f}'.format(np.exp(1))))
    plt.legend(handles=handles)
    plt.savefig('sample_uniform.png')
if __name__ == '__main__':
    # An optional first command-line argument overrides the default
    # number of trials; without it main() falls back to its own default.
    cli_args = sys.argv[1:]
    if cli_args:
        main(int(cli_args[0]))
    else:
        main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment