Skip to content

Instantly share code, notes, and snippets.

@bagrow
Created November 7, 2012 00:23
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bagrow/4028656 to your computer and use it in GitHub Desktop.
Save bagrow/4028656 to your computer and use it in GitHub Desktop.
Sample non-uniformly from a set of choices
#!/usr/bin/env python
# weighted_choice.py
# Jim Bagrow
# Last Modified: 2012-11-06
import random, bisect
import pylab
def weighted_choice(choices, num_draws=1):
    """Make biased draws (without replacement) from ``choices``.

    ``choices`` is a sequence of 2-tuples ``[(a, w_a), (b, w_b), ...]`` where
    ``w_x`` is the "weight" of choosing value ``x``. Weights are normalized
    automatically, so they need not sum to one, but they must be
    non-negative.

    Draws are collected in a set, so values must be hashable, each distinct
    value is returned at most once, and the order of the result is arbitrary.

    Special cases:
      * empty ``choices`` or ``num_draws <= 0``: returns ``[]``.
      * all weights are zero: returns a uniform random sample of
        ``min(num_draws, len(choices))`` values.
      * ``num_draws >= len(choices)``: returns every value, in input order.
      * fewer distinct positive-weight values than ``num_draws``: returns
        only the values that can actually be drawn.
      * sampling is best-effort: after roughly 5000 attempts the draws
        collected so far are returned, even if fewer than requested.

    Raises ValueError if any weight is negative.

    Example:
    >>> choices = [ ("H", 0.8), ("T", 0.2) ] # an unfair coin
    >>> for flip in range(100):
    ...     print(weighted_choice(choices)[0])
    """
    if not choices:
        # zip(*[]) yields nothing to unpack; there is nothing to draw from.
        return []

    values, weights = zip(*choices)

    # Running totals let bisect map a uniform number in [0, total) to an
    # index with probability proportional to that index's weight.
    total = 0
    cumulative_weights = []
    for w in weights:
        if w < 0:
            raise ValueError("weights must be non-negative, got %r" % (w,))
        total += w
        cumulative_weights.append(total)

    if num_draws <= 0:
        return []

    if total == 0:
        # No weight information at all: fall back to a uniform sample, capped
        # so random.sample() is never asked for more than the population.
        return random.sample(values, min(num_draws, len(values)))

    if num_draws >= len(cumulative_weights):
        return list(values)

    # Zero-weight values can never be hit by the bisect lookup below, so
    # stop as soon as every reachable value has been collected instead of
    # spinning through all the attempts.
    reachable = {v for v, w in zip(values, weights) if w > 0}
    target = min(num_draws, len(reachable))

    draws = set()  # a set: repeated hits on the same value are discarded
    max_attempts = 5002  # same bound as the original `attempt > 5000` guard
    for _ in range(max_attempts):
        if len(draws) >= target:
            break
        x = random.random() * total
        i = bisect.bisect(cumulative_weights, x)
        draws.add(values[i])
    return list(draws)
if __name__ == '__main__':
    # Demo: draw i with probability ~ 1/i many times and plot the empirical
    # counts against a 1/i reference line on log-log axes.

    # choose i with probability ~ 1/i:
    choices = [(i, 1.0 / i) for i in range(1, 15)]

    # do the sampling, count draws:
    draw2count = {}
    for _ in range(10000):  # range (not xrange) runs on Python 2 and 3
        draw = weighted_choice(choices)[0]
        draw2count[draw] = draw2count.get(draw, 0) + 1

    # plot the distribution:
    draws = sorted(draw2count)
    counts = [draw2count[d] for d in draws]
    # pylab.hold() was removed in Matplotlib 3.0; only call it on versions
    # that still provide it.
    if hasattr(pylab, 'hold'):
        pylab.hold(True)
    pylab.loglog(draws, counts, 'o-')
    pylab.plot(draws, [5000.0 / d for d in draws], 'r')  # should have slope -1
    pylab.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment