import numpy as np | |
from bintrees import FastAVLTree as AVLTree | |
import random | |
import time | |
def weightedRandom(weights):
    """
    Draw from a general discrete distribution by a linear scan.

    :param weights: A dictionary mapping each weight to its color; the weights
        must sum to one.
    :return: A random color sampled from the distribution defined by the
        weights, or ``None`` if ``weights`` is empty.
    """
    # Generate a uniform random number in [0, 1).
    remainder = random.random()
    color = None
    # .items() works on both Python 2 and 3 (iteritems() is Python 2 only).
    for value, color in weights.items():
        remainder -= value
        if remainder <= 0:
            return color
    # Float round-off can make the weights sum to slightly less than one,
    # leaving remainder positive after the scan; return the last color rather
    # than silently falling off the end with None.
    return color
def partitionWeights(weights):
    """
    The preprocessing step (Vose's alias method).

    Splits the distribution into ``len(weights)`` equal-width boxes of width
    ``1 / len(weights)``. Each box holds at most two colors: ``color1`` owns
    ``value`` of the box's mass and ``color2`` (the alias) owns the rest, so a
    draw picks a box uniformly and then settles it with a single comparison.

    :param weights: A dictionary mapping each weight to its color; the weights
        must sum to one.
    :return: A list of exactly ``len(weights)`` ``(value, color1, color2)``
        boxes used to draw quickly from the distribution. A box that belongs
        entirely to one color stores that color in both slots with
        ``value == 1 / len(weights)``. An empty ``weights`` gives ``[]``.
    """
    numWeights = len(weights)
    if numWeights == 0:
        return []
    share = 1.0 / numWeights
    # Worklists of (remaining mass, color): under-full colors fit inside one
    # box, over-full colors must donate mass to top up other boxes. Plain
    # lists (rather than a container keyed by weight) keep colors with equal
    # remaining mass distinct, and the whole pass is O(n).
    small = []
    large = []
    for value, color in weights.items():
        if value < share:
            small.append((value, color))
        else:
            large.append((value, color))
    boxes = []
    while small and large:
        smallValue, smallColor = small.pop()
        largeValue, largeColor = large.pop()
        boxes.append((smallValue, smallColor, largeColor))
        # The large color fills the rest of this box, so it loses that much mass.
        largeValue -= share - smallValue
        if largeValue < share:
            small.append((largeValue, largeColor))
        else:
            large.append((largeValue, largeColor))
    # Whatever is left holds (up to float round-off) exactly one box of mass.
    # Give each leftover a full box of its own color so no draw can land on a
    # placeholder; no absolute epsilon is needed, so this keeps working even
    # when 1 / n is tiny.
    for _, color in small + large:
        boxes.append((share, color, color))
    return boxes
def drawFromPartition(partition):
    """
    The draw step: sample one color from a box partition.

    A box is chosen uniformly at random. Within it, a uniform draw is scaled
    down to one box's width (divided by the number of boxes); the first color
    wins when that scaled draw is at or below the box's stored value,
    otherwise the second color is returned. The first color is therefore
    returned with probability about ``value * len(partition)``.

    :param partition: A sequence of ``(value, color1, color2)`` boxes.
    :return: ``color1`` or ``color2`` of the selected box.
    :raises ValueError: If ``partition`` is empty (from ``random.randint``).
    """
    boxCount = len(partition)
    threshold, primary, alias = partition[random.randint(0, boxCount - 1)]
    accept = random.random() / boxCount <= threshold
    return primary if accept else alias
# Speed comparison: linear-scan sampling (weightedRandom) versus the
# precomputed box partition (partitionWeights + drawFromPartition).
# Each elapsed time is printed in seconds; the second timing includes the
# one-off partitioning step.
weights = {}
numWeights = 1000
# A 1-D array makes each element a NumPy scalar, so float() below is a plain
# scalar conversion (float() on a 1-element array is deprecated in NumPy 1.25+).
nweights = np.random.rand(numWeights)
nweights /= nweights.sum()
# weights maps probability -> color (here the integer index), the layout both
# weightedRandom and partitionWeights unpack.
for i, value in enumerate(nweights):
    weights[float(value)] = i
start = time.time()
for i in range(100000):
    weightedRandom(weights)
end = time.time()
# print(...) with a single argument behaves the same on Python 2 and 3.
print(end - start)
start = time.time()
partition = partitionWeights(weights)
for i in range(100000):
    drawFromPartition(partition)
end = time.time()
print(end - start)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment