Skip to content
{{ message }}

Instantly share code, notes, and snippets.

# Mr4k/speedtest.py

Created Jul 10, 2017
 import numpy as np from bintrees import FastAVLTree as AVLTree import random import time def weightedRandom(weights): """ Draw from a general discrete distribution. :param weights: A dictionary of weights that must sum to one. :return: A random sample from it the distribution defined by the weights. """ #generate a uniform random number from 0 - 1 remainder = random.random() for weight in weights.iteritems(): value, color = weight remainder -= value if remainder <= 0: return color def partitionWeights(weights): """ The preprocessing step. :param weights: A dictionary of weights that must sum to one. :return: A partition used to draw quickly from the distribution. """ boxes = [] numWeights = len(weights) # We use a AVLTree to make our pull/push operations O(log n) tree = AVLTree(weights) for i in xrange(numWeights): smallestValue, smallestColor = tree.pop_min() # O(log n) overfill = 1.0 / numWeights - smallestValue if overfill > 0.00001: largestValue, largestColor = tree.pop_max() # O(log n) largestValue -= overfill if largestValue > 0.00001: tree.insert(largestValue, largestColor) # O(log n) boxes.append((smallestValue, smallestColor, largestColor)) else: boxes.append((smallestValue, smallestColor, "none")) return boxes def drawFromPartition(partition): """ The draw step. :param partition: partition A partition of a distribution into boxes. :return: A sample from the distribution represented by the partition. 
""" numBoxes = len(partition) i = random.randint(0, numBoxes - 1) value, color1, color2 = partition[i] if random.random() / numBoxes <= value: return color1 else: return color2 #compare in a speed test weights = {} numWeights = 1000 nweights = np.random.rand(numWeights, 1) nweights /= sum(nweights) for i in xrange(numWeights): weights[float(nweights[i])] = i start = time.time() for i in xrange(100000): weightedRandom(weights) end = time.time() print end - start start = time.time() partition = partitionWeights(weights) for i in xrange(100000): drawFromPartition(partition) end = time.time() print end - start
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment