Skip to content

Instantly share code, notes, and snippets.

@Mr4k

Mr4k/speedtest.py

Created Jul 10, 2017
Embed
What would you like to do?
import numpy as np
from bintrees import FastAVLTree as AVLTree
import random
import time
def weightedRandom(weights):
    """
    Draw from a general discrete distribution by a linear scan.

    :param weights: A dictionary mapping each weight to its color; the weights
        must sum to one.
    :return: A color drawn with probability equal to its weight, or None if
        ``weights`` is empty.
    """
    # Uniform random number in [0, 1).
    remainder = random.random()
    color = None
    for value, color in weights.items():
        remainder -= value
        if remainder <= 0:
            return color
    # Rounding can leave the weights summing to slightly less than one, so a
    # draw very close to 1 may survive the whole scan.  Attribute it to the
    # last color rather than falling off the end and returning None.
    return color
def partitionWeights(weights):
    """
    The preprocessing step: split the distribution into equal-width boxes.

    Builds one box per color (Vose's alias method).  Every box has width
    ``1.0 / len(weights)`` and is shared by at most two colors.  A box
    ``(value, color1, color2)`` gives mass ``value`` to ``color1`` and the
    rest of its width to ``color2``.  A box whose ``color2`` is ``"none"``
    stores the full width as ``value``, so it always yields ``color1``.

    Colors are paired with two plain worklists: weights below the box width,
    and weights at or above it.  This takes O(n) time.  No float weight is
    ever used as a lookup key, so equal weights cannot overwrite each other.
    No fixed tolerance is needed either, so the result stays correct for any
    number of weights.

    :param weights: A dictionary mapping each non-negative weight to its
        color; the weights must sum to one.
    :return: A list of ``len(weights)`` boxes (empty for empty input).
    """
    numWeights = len(weights)
    if numWeights == 0:
        return []
    capacity = 1.0 / numWeights

    small = []  # (value, color) pairs too light to fill a box on their own
    large = []  # (value, color) pairs that fill a box and have mass to spare
    for value, color in weights.items():
        if value < capacity:
            small.append((value, color))
        else:
            large.append((value, color))

    boxes = []
    while small and large:
        smallValue, smallColor = small.pop()
        largeValue, largeColor = large.pop()
        # The large color tops this box up to its full width...
        boxes.append((smallValue, smallColor, largeColor))
        # ...and keeps whatever mass it has left for later boxes.  Since
        # largeValue >= capacity >= capacity - smallValue, this is never < 0.
        largeValue -= capacity - smallValue
        if largeValue < capacity:
            small.append((largeValue, largeColor))
        else:
            large.append((largeValue, largeColor))

    # In exact arithmetic any leftover weighs exactly one box width.  In
    # floating point it differs only by rounding error, so give it the whole
    # box.  Storing the exact width means the "none" color is never drawn.
    for _, color in large + small:
        boxes.append((capacity, color, "none"))
    return boxes
def drawFromPartition(partition):
    """
    The draw step: sample one color from a box partition.

    A box is picked uniformly at random.  Within it, a uniform offset scaled
    to the box width ``1 / len(partition)`` is compared to the box's stored
    value: at or below it yields the box's first color, above it the second.

    :param partition: A list of ``(value, color1, color2)`` boxes.
    :return: ``color1`` or ``color2`` of the chosen box.
    """
    boxCount = len(partition)
    threshold, first, second = partition[random.randrange(0, boxCount)]
    offset = random.random() / boxCount
    return first if offset <= threshold else second
# Compare the two samplers in a speed test.
numWeights = 1000
numDraws = 100000

# Random weights normalized to sum to one.  A flat 1-D array makes each
# element a NumPy scalar, so float() needs no array-to-scalar conversion
# (deprecated for arrays with ndim > 0 in recent NumPy).
nweights = np.random.rand(numWeights)
nweights /= nweights.sum()

# Both samplers take a dict keyed by weight.  Equal weights would collapse
# into one entry, which is vanishingly unlikely for random floats.
weights = {}
for i in range(numWeights):
    weights[float(nweights[i])] = i

# Linear scan over the weights on every draw.
start = time.time()
for _ in range(numDraws):
    weightedRandom(weights)
end = time.time()
print(end - start)

# Partition sampler; the preprocessing cost is included in the timing.
start = time.time()
partition = partitionWeights(weights)
for _ in range(numDraws):
    drawFromPartition(partition)
end = time.time()
print(end - start)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.